From 9a711a86f4c59c1886c2f2c858c5af08d4f70579 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Fri, 6 Sep 2024 14:09:18 -0700
Subject: [PATCH] Affine input transforms should error with data of incorrect
 dimension, even in eval mode (#2510)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2510

Context: https://github.com/pytorch/botorch/issues/2509 gives a clear overview.

This PR:
* Checks the shape of the `X` provided to an `AffineInputTransform` when it
  transforms the data, regardless of whether it is updating the coefficients.

It also makes some unrelated changes:
* Fixes the example in the docstring for `batched_multi_output_to_single_output`
* Fixes an incorrect shape in `test_approximate_gp`
* Makes data and transform batch shapes match in `TestConverters`, since those
  usages will now (appropriately) error

Differential Revision: D62318530
---
 botorch/models/converter.py          |  4 ++--
 botorch/models/transforms/input.py   |  4 ++--
 test/models/test_approximate_gp.py   |  2 +-
 test/models/test_converter.py        | 15 +++++++++++----
 test/models/transforms/test_input.py | 17 ++++++++++++-----
 5 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/botorch/models/converter.py b/botorch/models/converter.py
index 1c75e6b2cc..9d898289f0 100644
--- a/botorch/models/converter.py
+++ b/botorch/models/converter.py
@@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
     Example:
         >>> train_X = torch.rand(5, 2)
        >>> train_Y = torch.rand(5, 2)
-        >>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
-        >>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
+        >>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
+        >>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
     """
     warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
     was_training = batch_mo_model.training
diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index d56a1db2e3..d235b8005a 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x d`-dim tensor of transformed inputs.
         """
+        self._check_shape(X)
         if self.learn_coefficients and self.training:
-            self._check_shape(X)
             self._update_coefficients(X)
         self._to(X)
         return (X - self.offset) / self.coefficient
@@ -467,7 +467,7 @@ def _check_shape(self, X: Tensor) -> None:
         if X.ndim < n:
             raise ValueError(
                 f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
-                f" , but has {X.ndim}."
+                f" , but has {X.ndim}. Expected shape (..., *{self.batch_shape}, n, {self.coefficient.shape[-1]}), got `{X.shape}`."
             )

        torch.broadcast_shapes(self.coefficient.shape, self.offset.shape, X.shape)
diff --git a/test/models/test_approximate_gp.py b/test/models/test_approximate_gp.py
index e7d0b7f980..ed7cba18cf 100644
--- a/test/models/test_approximate_gp.py
+++ b/test/models/test_approximate_gp.py
@@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
                 model.likelihood, model.model, num_data=train_X.shape[-2]
             )
             fit_gpytorch_mll(mll)
-            post = model.posterior(torch.tensor([train_X.mean()]))
+            post = model.posterior(torch.tensor([[train_X.mean()]]))
             self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)
diff --git a/test/models/test_converter.py b/test/models/test_converter.py
index 47d43a53b6..ef6fad033b 100644
--- a/test/models/test_converter.py
+++ b/test/models/test_converter.py
@@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
                 batch_shape=torch.Size([3]),
             )
             gp1_ = SingleTaskGP(
-                train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
+                train_X=train_X.unsqueeze(0),
+                train_Y=train_Y1.unsqueeze(0),
+                input_transform=input_tf2,
+                outcome_transform=None,
             )
             gp2_ = SingleTaskGP(
-                train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
+                train_X=train_X.unsqueeze(0),
+                train_Y=train_Y2.unsqueeze(0),
+                input_transform=input_tf2,
+                outcome_transform=None,
             )
             list_gp = ModelListGP(gp1_, gp2_)
-            with self.assertRaises(UnsupportedError):
+            with self.assertRaisesRegex(
+                UnsupportedError, "Batched input_transforms are not supported."
+            ):
                 model_list_to_batched(list_gp)

             # test outcome transform
@@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
                 bounds=torch.tensor(
                     [[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
                 ),
-                batch_shape=torch.Size([2]),
             )
             batched_mo_model = SingleTaskGP(
                 train_X, train_Y, input_transform=input_tf, outcome_transform=None
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 8efd24c1f9..5537c72cd3 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -228,7 +228,10 @@ def test_normalize(self) -> None:
             self.assertTrue(nlz.mins.dtype == other_dtype)
             # test incompatible dimensions of specified bounds
             bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
-            with self.assertRaises(BotorchTensorDimensionError):
+            with self.assertRaisesRegex(
+                BotorchTensorDimensionError,
+                "Dimensions of provided `bounds` are incompatible",
+            ):
                 Normalize(d=2, bounds=bounds)

             # test jitter
@@ -266,7 +269,12 @@ def test_normalize(self) -> None:
             # test errors on wrong shape
             nlz = Normalize(d=2, batch_shape=batch_shape)
             X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
-            with self.assertRaises(BotorchTensorDimensionError):
+            expected_msg = "Wrong input dimension. Received 1, expected 2."
+            with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
+                nlz(X)
+            # Same error in eval mode
+            nlz.eval()
+            with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
                 nlz(X)

             # fixed bounds
@@ -328,9 +336,8 @@ def test_normalize(self) -> None:
                 [X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
                 dim=-2,
             )[..., indices]
-            self.assertTrue(
-                torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
-            )
+            self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
+
             # test errors on wrong shape
             nlz = Normalize(d=2, batch_shape=batch_shape)
             X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
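Note (not part of the patch): the following is a minimal sketch of the behavior change, assuming BoTorch with this patch applied. It uses the `Normalize` transform and the "Wrong input dimension" error exercised in the updated `test_input.py`; before the patch, the eval-mode call below skipped `_check_shape` and silently broadcast against the learned coefficients.

import torch

from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.models.transforms.input import Normalize

# A Normalize transform over d=2 features; a training-mode call learns its bounds.
nlz = Normalize(d=2)
nlz(torch.rand(10, 2))

# Switch to eval mode, as a fitted model would at prediction time.
nlz.eval()

# X has one feature column instead of the expected two. With this patch,
# the shape check also runs in eval mode and raises instead of broadcasting.
X_bad = torch.rand(4, 1)
try:
    nlz(X_bad)
except BotorchTensorDimensionError as e:
    print(e)  # Wrong input dimension. Received 1, expected 2.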