From 47ea8bd51f39b7e5ba93ccd491512b8b45ffbe4b Mon Sep 17 00:00:00 2001 From: SaiAakash Date: Wed, 23 Oct 2024 16:48:40 +0100 Subject: [PATCH 1/3] check for tracing adn generate mtmvn differently --- botorch/models/gpytorch.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 3bc2059c14..3c2df163ec 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -48,6 +48,7 @@ from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator from torch import Tensor +from torch._C import _get_tracing_state if TYPE_CHECKING: from botorch.posteriors.posterior_list import PosteriorList # pragma: no cover @@ -446,17 +447,20 @@ def posterior( mvn = self(X) mvn = self._apply_noise(X=X, mvn=mvn, observation_noise=observation_noise) if self._num_outputs > 1: - mean_x = mvn.mean - covar_x = mvn.lazy_covariance_matrix - output_indices = output_indices or range(self._num_outputs) - mvns = [ - MultivariateNormal( - mean_x.select(dim=output_dim_idx, index=t), - covar_x[(slice(None),) * output_dim_idx + (t,)], - ) - for t in output_indices - ] - mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) + if _get_tracing_state(): + mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0) + else: + mean_x = mvn.mean + covar_x = mvn.lazy_covariance_matrix + output_indices = output_indices or range(self._num_outputs) + mvns = [ + MultivariateNormal( + mean_x.select(dim=output_dim_idx, index=t), + covar_x[(slice(None),) * output_dim_idx + (t,)], + ) + for t in output_indices + ] + mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns) posterior = GPyTorchPosterior(distribution=mvn) if hasattr(self, "outcome_transform"): From c7db5542685f22a59758f17cdf5986615a633fbb Mon Sep 17 00:00:00 2001 From: SaiAakash Date: Wed, 23 Oct 2024 18:33:54 +0100 Subject: [PATCH 2/3] added unit test --- test/models/test_gpytorch.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/models/test_gpytorch.py b/test/models/test_gpytorch.py index 6727b26e91..39d6aaf9bf 100644 --- a/test/models/test_gpytorch.py +++ b/test/models/test_gpytorch.py @@ -38,6 +38,7 @@ from gpytorch.likelihoods import GaussianLikelihood from gpytorch.means import ConstantMean from gpytorch.models import ExactGP, IndependentModelList +from gpytorch.settings import trace_mode from torch import Tensor @@ -411,6 +412,32 @@ def test_posterior_transform(self): post = model.posterior(torch.rand(3, 2, **tkwargs), posterior_transform=post_tf) self.assertTrue(torch.equal(post.mean, torch.zeros(3, 1, **tkwargs))) + def test_posterior_in_trace_mode(self): + tkwargs = {"device": self.device, "dtype": torch.double} + train_X = torch.rand(5, 1, **tkwargs) + train_Y = torch.cat([torch.sin(train_X), torch.cos(train_X)], dim=-1) + model = SimpleBatchedMultiOutputGPyTorchModel(train_X, train_Y) + + class MeanVarModelWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + # get the model posterior + posterior = self.model.posterior(x, observation_noise=True) + mean = posterior.mean.detach() + std = posterior.variance.sqrt().detach() + return mean, std + + wrapped_model = MeanVarModelWrapper(model) + with torch.no_grad(), trace_mode(): + X_test = torch.rand(3, 1, **tkwargs) + wrapped_model(X_test) # Compute caches + traced_model = torch.jit.trace(wrapped_model, X_test) + mean, std = traced_model(X_test) + self.assertEqual(mean.shape, torch.Size([3, 2])) + class TestModelListGPyTorchModel(BotorchTestCase): def test_model_list_gpytorch_model(self): From b40b548203b325914d69e7b66c3f89c4eb113f86 Mon Sep 17 00:00:00 2001 From: SaiAakash Date: Thu, 24 Oct 2024 15:51:11 +0100 Subject: [PATCH 3/3] set task_dim to output_dim_idx; use public API to check for trace mode --- botorch/models/gpytorch.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/botorch/models/gpytorch.py b/botorch/models/gpytorch.py index 3c2df163ec..a7278d9b69 100644 --- a/botorch/models/gpytorch.py +++ b/botorch/models/gpytorch.py @@ -48,7 +48,6 @@ from gpytorch.likelihoods.gaussian_likelihood import FixedNoiseGaussianLikelihood from linear_operator.operators import BlockDiagLinearOperator, CatLinearOperator from torch import Tensor -from torch._C import _get_tracing_state if TYPE_CHECKING: from botorch.posteriors.posterior_list import PosteriorList # pragma: no cover @@ -447,8 +446,10 @@ def posterior( mvn = self(X) mvn = self._apply_noise(X=X, mvn=mvn, observation_noise=observation_noise) if self._num_outputs > 1: - if _get_tracing_state(): - mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0) + if torch.jit.is_tracing(): + mvn = MultitaskMultivariateNormal.from_batch_mvn( + mvn, task_dim=output_dim_idx + ) else: mean_x = mvn.mean covar_x = mvn.lazy_covariance_matrix