Skip to content

Commit

Permalink
Affine input transforms should error with data of incorrect dimension…
Browse files Browse the repository at this point in the history
…, even in eval mode (#2510)

Summary:
Pull Request resolved: #2510

Context: #2509 gives a clear overview

This PR:
* Checks the shape of the `X` provided to an `AffineInputTransform` when it transforms the data, regardless of whether it is updating the coefficients.

Makes some unrelated changes:
* Fixes the example in the docstring for `batched_multi_output_to_single_output`
* fixes an incorrect shape in `test_approximate_gp`
* Makes data and transform batch shapes match in `TestConverters`, since those usages will now (appropriately) error

Differential Revision: D62318530
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 6, 2024
1 parent 1417189 commit 9a711a8
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions botorch/models/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
Example:
>>> train_X = torch.rand(5, 2)
>>> train_Y = torch.rand(5, 2)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
>>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
"""
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
was_training = batch_mo_model.training
Expand Down
4 changes: 2 additions & 2 deletions botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
Returns:
A `batch_shape x n x d`-dim tensor of transformed inputs.
"""
self._check_shape(X)
if self.learn_coefficients and self.training:
self._check_shape(X)
self._update_coefficients(X)
self._to(X)
return (X - self.offset) / self.coefficient
Expand Down Expand Up @@ -467,7 +467,7 @@ def _check_shape(self, X: Tensor) -> None:
if X.ndim < n:
raise ValueError(
f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
f" , but has {X.ndim}."
f" , but has {X.ndim}. Expected shape (..., *{self.batch_shape}, n, {self.coefficient.shape[-1]}), got `{X.shape}`."
)

torch.broadcast_shapes(self.coefficient.shape, self.offset.shape, X.shape)
Expand Down
2 changes: 1 addition & 1 deletion test/models/test_approximate_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
model.likelihood, model.model, num_data=train_X.shape[-2]
)
fit_gpytorch_mll(mll)
post = model.posterior(torch.tensor([train_X.mean()]))
post = model.posterior(torch.tensor([[train_X.mean()]]))
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)
15 changes: 11 additions & 4 deletions test/models/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
batch_shape=torch.Size([3]),
)
gp1_ = SingleTaskGP(
train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y1.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
gp2_ = SingleTaskGP(
train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
train_X=train_X.unsqueeze(0),
train_Y=train_Y2.unsqueeze(0),
input_transform=input_tf2,
outcome_transform=None,
)
list_gp = ModelListGP(gp1_, gp2_)
with self.assertRaises(UnsupportedError):
with self.assertRaisesRegex(
UnsupportedError, "Batched input_transforms are not supported."
):
model_list_to_batched(list_gp)

# test outcome transform
Expand Down Expand Up @@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
bounds=torch.tensor(
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
),
batch_shape=torch.Size([2]),
)
batched_mo_model = SingleTaskGP(
train_X, train_Y, input_transform=input_tf, outcome_transform=None
Expand Down
17 changes: 12 additions & 5 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ def test_normalize(self) -> None:
self.assertTrue(nlz.mins.dtype == other_dtype)
# test incompatible dimensions of specified bounds
bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
with self.assertRaisesRegex(
BotorchTensorDimensionError,
"Dimensions of provided `bounds` are incompatible",
):
Normalize(d=2, bounds=bounds)

# test jitter
Expand Down Expand Up @@ -266,7 +269,12 @@ def test_normalize(self) -> None:
# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
expected_msg = "Wrong input dimension. Received 1, expected 2."
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)
# Same error in eval mode
nlz.eval()
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
nlz(X)

# fixed bounds
Expand Down Expand Up @@ -328,9 +336,8 @@ def test_normalize(self) -> None:
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
dim=-2,
)[..., indices]
self.assertTrue(
torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
)
self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)

# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
Expand Down

0 comments on commit 9a711a8

Please sign in to comment.