From e7539db0fb2007236a3f75f41c47ed1e6427ae98 Mon Sep 17 00:00:00 2001
From: Sai Aakash <64794820+SaiAakash@users.noreply.github.com>
Date: Thu, 24 Oct 2024 10:21:17 -0700
Subject: [PATCH] Fix posterior method in `BatchedMultiOutputGPyTorchModel` for tracing JIT (#2592)

Summary:
## Motivation

Fixes https://github.com/pytorch/botorch/issues/2591. Generates the MTMVN for the independent task case slightly differently when jit traced.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: https://github.com/pytorch/botorch/pull/2592

Test Plan:
A unit test `test_posterior_in_trace_mode` has been added to test_gpytorch.py

## Related PRs

NA

Reviewed By: saitcakmak, Balandat

Differential Revision: D64903356

Pulled By: sdaulton

fbshipit-source-id: 32fa2f108e99683d92344e31123a6bd07cc4113b
---
 botorch/models/gpytorch.py   | 25 +++++++++++++++----------
 test/models/test_gpytorch.py | 27 +++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py
index c44616e0d9..708f4b8ec2 100644
--- a/botorch/models/gpytorch.py
+++ b/botorch/models/gpytorch.py
@@ -446,17 +446,22 @@ def posterior(
             mvn = self(X)
             mvn = self._apply_noise(X=X, mvn=mvn, observation_noise=observation_noise)
             if self._num_outputs > 1:
-                mean_x = mvn.mean
-                covar_x = mvn.lazy_covariance_matrix
-                output_indices = output_indices or range(self._num_outputs)
-                mvns = [
-                    MultivariateNormal(
-                        mean_x.select(dim=output_dim_idx, index=t),
-                        covar_x[(slice(None),) * output_dim_idx + (t,)],
+                if torch.jit.is_tracing():
+                    mvn = MultitaskMultivariateNormal.from_batch_mvn(
+                        mvn, task_dim=output_dim_idx
                     )
-                    for t in output_indices
-                ]
-                mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
+                else:
+                    mean_x = mvn.mean
+                    covar_x = mvn.lazy_covariance_matrix
+                    output_indices = output_indices or range(self._num_outputs)
+                    mvns = [
+                        MultivariateNormal(
+                            mean_x.select(dim=output_dim_idx, index=t),
+                            covar_x[(slice(None),) * output_dim_idx + (t,)],
+                        )
+                        for t in output_indices
+                    ]
+                    mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
 
         posterior = GPyTorchPosterior(distribution=mvn)
         if hasattr(self, "outcome_transform"):
diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py
index 5b5ba99180..b1adeb6ea6 100644
--- a/test/models/test_gpytorch.py
+++ b/test/models/test_gpytorch.py
@@ -37,6 +37,7 @@
 from gpytorch.likelihoods import GaussianLikelihood
 from gpytorch.means import ConstantMean
 from gpytorch.models import ExactGP, IndependentModelList
+from gpytorch.settings import trace_mode
 from torch import Tensor
 
 
@@ -410,6 +411,32 @@ def test_posterior_transform(self):
         post = model.posterior(torch.rand(3, 2, **tkwargs), posterior_transform=post_tf)
         self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs)))
 
+    def test_posterior_in_trace_mode(self):
+        tkwargs = {"device": self.device, "dtype": torch.double}
+        train_X = torch.rand(5, 1, **tkwargs)
+        train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1)
+        model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y)
+
+        class MeanVarModelWrapper(torch.nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+
+            def forward(self, x):
+                # get the model posterior
+                posterior = self.model.posterior(x, observation_noise=True)
+                mean = posterior.mean.detach()
+                std = posterior.variance.sqrt().detach()
+                return mean, std
+
+        wrapped_model = MeanVarModelWrapper(model)
+        with torch.no_grad(), trace_mode():
+            X_test = torch.rand(3, 1, **tkwargs)
+            wrapped_model(X_test)  # Compute caches
+            traced_model = torch.jit.trace(wrapped_model, X_test)
+            mean, std = traced_model(X_test)
+            self.assertEqual(mean.shape, torch.Size([3, 2]))
+
 
 class TestModelListGPyTorchModel(BotorchTestCase):
     def test_model_list_gpytorch_model(self):