diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index c44616e0d9..708f4b8ec2 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -446,17 +446,22 @@ def posterior( mvn = self(X) mvn = self._apply_noise(X=X, mvn=mvn, observation_noise=observation_noise) if self._num_outputs > 1: - mean_x = mvn.mean - covar_x = mvn.lazy_covariance_matrix - output_indices = output_indices or range(self._num_outputs) - mvns = [ - MultivariateNormal( - mean_x.select(dim=output_dim_idx, index=t), - covar_x[(slice(None),) * output_dim_idx + (t,)], + if torch.jit.is_tracing(): + mvn = MultitaskMultivariateNormal.from_batch_mvn( + mvn, task_dim=output_dim_idx ) - for t in output_indices - ] - mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) + else: + mean_x = mvn.mean + covar_x = mvn.lazy_covariance_matrix + output_indices = output_indices or range(self._num_outputs) + mvns = [ + MultivariateNormal( + mean_x.select(dim=output_dim_idx, index=t), + covar_x[(slice(None),) * output_dim_idx + (t,)], + ) + for t in output_indices + ] + mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) posterior = GPyTorchPosterior(distribution=mvn) if hasattr(self, "outcome_transform"): diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index 5b5ba99180..b1adeb6ea6 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -37,6 +37,7 @@ from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.models import ExactGP, IndependentModelList +from gpytorch.settings import trace_mode from torch import Tensor @@ -410,6 +411,32 @@ def test_posterior_transform(self): post = model.posterior(torch.rand(3, 2, **tkwargs), posterior_transform=post_tf) self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs))) + def test_posterior_in_trace_mode(self): + tkwargs = {"device": self.device, "dtype": torch.double} + train_X = torch.rand(5, 1, **tkwargs) + train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1) + model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y) + + class MeanVarModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + # get the model posterior + posterior = self.model.posterior(x, observation_noise=True) + mean = posterior.mean.detach() + std = posterior.variance.sqrt().detach() + return mean, std + + wrapped_model = MeanVarModelWrapper(model) + with torch.no_grad(), trace_mode(): + X_test = torch.rand(3, 1, **tkwargs) + wrapped_model(X_test) # Compute caches + traced_model = torch.jit.trace(wrapped_model, X_test) + mean, std = traced_model(X_test) + self.assertEqual(mean.shape, torch.Size([3, 2])) + class TestModelListGPyTorchModel(BotorchTestCase): def test_model_list_gpytorch_model(self):