Multi-output Gaussian Processes in PyMC [GSoC Week 07-09]
A personal note on the progress of incorporating Multi-output Gaussian Processes (MOGPs) into PyMC. Weeks 07-09 focus on implementing the Intrinsic Coregionalization Model (ICM) and the Linear Coregionalization Model (LCM) using the Hadamard (element-wise) product.
This work is supported by GSoC, NumFOCUS, and the PyMC team.
Given input data $x$ and output indices $o$, the ICM kernel $K$ is the Hadamard (element-wise) product of a kernel over the inputs and a kernel over the outputs: $$ K = K_1(x, x') \circ K_2(o, o') $$
where $K_2(o, o')$ is broadcast to the shape of $K_1(x, x')$ using the Coregion kernel.
NOTE: this Hadamard-product construction works whether the outputs share the same inputs or each output has its own inputs.
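For a concrete picture, here is a minimal sketch of this product built from PyMC's existing `ExpQuad` and `Coregion` covariances. The fixed `W`, `kappa`, and lengthscale values are illustrative assumptions; the convention is that the last column of the augmented inputs holds the output index:

import numpy as np
import pymc as pm

n_outputs, rank = 3, 2
W = np.random.rand(n_outputs, rank)  # illustrative coregionalization weights
kappa = np.random.rand(n_outputs)    # illustrative per-output diagonal terms

# column 0 of the augmented inputs holds x, column 1 holds the output index o
K1 = pm.gp.cov.ExpQuad(input_dim=2, ls=0.5, active_dims=[0])
K2 = pm.gp.cov.Coregion(input_dim=2, W=W, kappa=kappa, active_dims=[1])
K = K1 * K2  # a product of PyMC covariances is evaluated element-wise (Hadamard)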
import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
# set the seed
np.random.seed(1)
# helpers from the GSoC multi-output branch
from multi_ouputs import build_XY, ICM, LCM, MultiMarginal
from mo import MultiOutputMarginal
import math
%matplotlib inline
%load_ext autoreload
%autoreload 2
N = 50  # training points per output
train_x = np.linspace(0, 1, N)
# three noisy outputs observed at the same inputs
train_y = np.stack([
    np.sin(train_x * (2 * math.pi)) + np.random.randn(len(train_x)) * 0.2,
    np.cos(train_x * (2 * math.pi)) + np.random.randn(len(train_x)) * 0.2,
    np.cos(train_x * (1 * math.pi)) + np.random.randn(len(train_x)) * 0.1,
], -1)
train_x.shape, train_y.shape
fig, ax = plt.subplots(1, 1, figsize=(12, 5))
for idx in range(3):
    ax.scatter(train_x, train_y[:, idx], label=f"y{idx + 1}")
ax.legend()
np.vstack([train_y[:,0], train_y[:,1], train_y[:,2]]).shape
x = train_x.reshape(-1, 1)
X, Y, _ = build_XY(
    [x, x, x],
    [train_y[:, 0].reshape(-1, 1),
     train_y[:, 1].reshape(-1, 1),
     train_y[:, 2].reshape(-1, 1)],
)
x.shape, X.shape, Y.shape
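`build_XY` is modeled on GPy's `GPy.util.multioutput.build_XY`: it stacks the per-output inputs vertically and appends a column holding each row's output index. A rough sketch of the idea (the real helper may differ in details):

def build_XY_sketch(input_list, output_list=None):
    # stack the inputs of all outputs on top of each other
    stacked = np.vstack(input_list)
    # tag every row with the index of the output it belongs to
    index = np.vstack([np.full((xi.shape[0], 1), i) for i, xi in enumerate(input_list)])
    X = np.hstack([stacked, index])
    Y = np.vstack(output_list) if output_list is not None else None
    return X, Y, index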
M = 100  # prediction points per output
x_new = np.linspace(-0.5, 1.5, M)[:, None]
X_new, _, _ = build_XY([x_new, x_new, x_new])
n_points = M  # each output owns a contiguous block of n_points rows in X_new
X_new.shape
import aesara.tensor as at
with pm.Model() as model:
    # Priors on the data-kernel hyperparameters
    ell = pm.Gamma("ell", alpha=2, beta=0.5)
    eta = pm.Gamma("eta", alpha=3, beta=1)
    cov = eta**2 * pm.gp.cov.ExpQuad(input_dim=2, ls=ell, active_dims=[0])

    # Fixed coregionalization weights for this demo; they could also be given priors
    W = np.random.rand(3, 2)  # (n_outputs, w_rank)
    kappa = np.random.rand(3)
    B = pm.Deterministic("B", at.dot(W, W.T) + at.diag(kappa))  # B = W Wᵀ + diag(κ)

    sigma = pm.HalfNormal("sigma", sigma=3)
    mogp = MultiOutputMarginal(means=0, kernels=[cov], input_dim=2, active_dims=[1], num_outputs=3, B=B)
    y_ = mogp.marginal_likelihood("f", X, Y.squeeze(), noise=sigma)
pm.model_to_graphviz(model)
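Since the three outputs are stacked into one vector, `marginal_likelihood` treats them as a single GP over the augmented inputs, so the stacked observations follow $$ Y \sim \mathcal{N}\left(0,\; K + \sigma^2 I\right) $$ where $K$ is the ICM kernel evaluated on $X$.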
%%time
with model:
    gp_trace = pm.sample(500, chains=1)
%%time
with model:
    preds = mogp.conditional("preds", X_new)
    gp_samples = pm.sample_posterior_predictive(gp_trace, var_names=["preds"], random_seed=42)
pm.model_to_graphviz(model)
from pymc.gp.util import plot_gp_dist
f_pred = gp_samples.posterior_predictive["preds"].sel(chain=0)
fig, axes = plt.subplots(3, 1, figsize=(10, 10))
for idx in range(3):
    # each output occupies a contiguous block of n_points rows in the predictions
    plot_gp_dist(axes[idx], f_pred[:, n_points * idx : n_points * (idx + 1)],
                 X_new[n_points * idx : n_points * (idx + 1), 0],
                 palette="Blues", fill_alpha=0.5, samples_alpha=0.1)
    axes[idx].plot(x, train_y[:, idx], "ok", ms=3, alpha=0.5, label=f"Data {idx + 1}")
az.summary(gp_trace)
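So far we have used a single latent kernel. The LCM generalizes the ICM by mixing several latent kernels, each paired with its own coregionalization matrix: $$ K = \sum_{q=1}^{Q} K_q(x, x') \circ B_q(o, o') $$ The model below mixes an `ExpQuad` and a `Matern32` kernel; since no `B` is passed, `MultiOutputMarginal` constructs the coregionalization matrices internally.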
with pm.Model() as model:
    # Priors: one lengthscale and amplitude per latent kernel
    ell = pm.Gamma("ell", alpha=2, beta=0.5, shape=2)
    eta = pm.Gamma("eta", alpha=3, beta=1, shape=2)
    kernels = [pm.gp.cov.ExpQuad, pm.gp.cov.Matern32]
    sigma = pm.HalfNormal("sigma", sigma=3)

    # Define a list of covariance functions, one per latent process
    cov_list = [
        eta[idx] ** 2 * kernel(2, ls=ell[idx], active_dims=[0])
        for idx, kernel in enumerate(kernels)
    ]

    # Define a multi-output GP
    mogp = MultiOutputMarginal(means=0, kernels=cov_list, input_dim=2, active_dims=[1], num_outputs=3)
    y_ = mogp.marginal_likelihood("f", X, Y.squeeze(), noise=sigma)
pm.model_to_graphviz(model)
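Since no `B` matrices are passed here, `MultiOutputMarginal` has to build one coregionalization term per latent kernel internally. A minimal sketch of how that sum-of-products covariance could be assembled; the priors on `W_q` and `kappa_q` are illustrative assumptions, not the GSoC implementation:

n_outputs, rank = 3, 1
with pm.Model():
    ks = [pm.gp.cov.ExpQuad(2, ls=0.5, active_dims=[0]),
          pm.gp.cov.Matern32(2, ls=0.5, active_dims=[0])]
    cov_lcm = None
    for q, k_q in enumerate(ks):
        # one coregionalization matrix B_q = W_q W_qᵀ + diag(κ_q) per latent kernel
        W_q = pm.Normal(f"W_{q}", mu=0, sigma=1, shape=(n_outputs, rank))
        kappa_q = pm.Gamma(f"kappa_{q}", alpha=2, beta=1, shape=n_outputs)
        B_q = pm.gp.cov.Coregion(input_dim=2, W=W_q, kappa=kappa_q, active_dims=[1])
        term = k_q * B_q  # Hadamard product for latent process q
        cov_lcm = term if cov_lcm is None else cov_lcm + term  # LCM: sum over q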
%%time
with model:
    gp_trace = pm.sample(500, chains=1)
%%time
with model:
    preds = mogp.conditional("preds", X_new)
    gp_samples = pm.sample_posterior_predictive(gp_trace, var_names=["preds"], random_seed=42)
pm.model_to_graphviz(model)
from pymc.gp.util import plot_gp_dist
f_pred = gp_samples.posterior_predictive["preds"].sel(chain=0)
fig, axes = plt.subplots(3, 1, figsize=(10, 10))
for idx in range(3):
    plot_gp_dist(axes[idx], f_pred[:, n_points * idx : n_points * (idx + 1)],
                 X_new[n_points * idx : n_points * (idx + 1), 0],
                 palette="Blues", fill_alpha=0.5, samples_alpha=0.1)
    axes[idx].plot(x, train_y[:, idx], "ok", ms=3, alpha=0.5, label=f"Data {idx + 1}")
az.summary(gp_trace)
az.plot_trace(gp_trace);
plt.tight_layout()
%load_ext watermark
%watermark -n -u -v -iv -w