Skip to content

Commit 33e11f4

Browse files
esantorellafacebook-github-bot
authored andcommitted
Affine input transforms should error with data of incorrect dimension, even in eval mode (#2510)
Summary: Context: #2509 gives a clear overview This PR: * Checks the shape of the `X` provided to an `AffineInputTransform` when it transforms the data, regardless of whether it is updating the coefficients. Makes some unrelated changes: * Fixes the example in the docstring for `batched_multi_output_to_single_output` * fixes an incorrect shape in `test_approximate_gp` * Makes data and transform batch shapes match in `TestConverters`, since those usages will now (appropriately) error Pull Request resolved: #2510 Reviewed By: saitcakmak Differential Revision: D62318530 Pulled By: esantorella fbshipit-source-id: eaa8b0410c49b17d6abbe1391bbb0750313aea23
1 parent 4dc1271 commit 33e11f4

File tree

7 files changed

+36
-20
lines changed

7 files changed

+36
-20
lines changed

botorch/models/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def batched_multi_output_to_single_output(
388388
Example:
389389
>>> train_X = torch.rand(5, 2)
390390
>>> train_Y = torch.rand(5, 2)
391-
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y)
392-
>>> batch_so_gp = batched_multioutput_to_single_output(batch_gp)
391+
>>> batch_mo_gp = SingleTaskGP(train_X, train_Y, outcome_transform=None)
392+
>>> batch_so_gp = batched_multi_output_to_single_output(batch_mo_gp)
393393
"""
394394
warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning, stacklevel=2)
395395
was_training = batch_mo_model.training

botorch/models/transforms/input.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def _transform(self, X: Tensor) -> Tensor:
412412
Returns:
413413
A `batch_shape x n x d`-dim tensor of transformed inputs.
414414
"""
415+
self._check_shape(X)
415416
if self.learn_coefficients and self.training:
416-
self._check_shape(X)
417417
self._update_coefficients(X)
418418
self._to(X)
419419
return (X - self.offset) / self.coefficient

test/acquisition/test_proximal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def test_proximal(self):
7878
proximal_test_X = test_X.clone()
7979
if transformed_weighting:
8080
if input_transform is not None:
81-
last_X = input_transform(train_X[-1])
81+
last_X = input_transform(train_X[-1].unsqueeze(0))
8282
proximal_test_X = input_transform(test_X)
8383

8484
mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
@@ -105,7 +105,7 @@ def test_proximal(self):
105105
proximal_test_X = test_X.clone()
106106
if transformed_weighting:
107107
if input_transform is not None:
108-
last_X = input_transform(train_X[-1])
108+
last_X = input_transform(train_X[-1].unsqueeze(0))
109109
proximal_test_X = input_transform(test_X)
110110

111111
mv_normal = MultivariateNormal(last_X, torch.diag(proximal_weights))
@@ -122,7 +122,7 @@ def test_proximal(self):
122122
proximal_test_X = test_X.clone()
123123
if transformed_weighting:
124124
if input_transform is not None:
125-
last_X = input_transform(train_X[-1])
125+
last_X = input_transform(train_X[-1].unsqueeze(0))
126126
proximal_test_X = input_transform(test_X)
127127

128128
ei = EI(test_X)
@@ -143,7 +143,7 @@ def test_proximal(self):
143143
proximal_test_X = test_X.clone()
144144
if transformed_weighting:
145145
if input_transform is not None:
146-
last_X = input_transform(train_X[-1])
146+
last_X = input_transform(train_X[-1].unsqueeze(0))
147147
proximal_test_X = input_transform(test_X)
148148

149149
qEI_prox = ProximalAcquisitionFunction(

test/models/test_approximate_gp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,5 +327,5 @@ def test_input_transform(self) -> None:
327327
model.likelihood, model.model, num_data=train_X.shape[-2]
328328
)
329329
fit_gpytorch_mll(mll)
330-
post = model.posterior(torch.tensor([train_X.mean()]))
330+
post = model.posterior(torch.tensor([[train_X.mean()]]))
331331
self.assertAllClose(post.mean[0][0], y.mean(), atol=1e-3, rtol=1e-3)

test/models/test_converter.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,21 @@ def test_model_list_to_batched(self):
278278
batch_shape=torch.Size([3]),
279279
)
280280
gp1_ = SingleTaskGP(
281-
train_X, train_Y1, input_transform=input_tf2, outcome_transform=None
281+
train_X=train_X.unsqueeze(0),
282+
train_Y=train_Y1.unsqueeze(0),
283+
input_transform=input_tf2,
284+
outcome_transform=None,
282285
)
283286
gp2_ = SingleTaskGP(
284-
train_X, train_Y2, input_transform=input_tf2, outcome_transform=None
287+
train_X=train_X.unsqueeze(0),
288+
train_Y=train_Y2.unsqueeze(0),
289+
input_transform=input_tf2,
290+
outcome_transform=None,
285291
)
286292
list_gp = ModelListGP(gp1_, gp2_)
287-
with self.assertRaises(UnsupportedError):
293+
with self.assertRaisesRegex(
294+
UnsupportedError, "Batched input_transforms are not supported."
295+
):
288296
model_list_to_batched(list_gp)
289297

290298
# test outcome transform
@@ -457,7 +465,6 @@ def test_batched_multi_output_to_single_output(self):
457465
bounds=torch.tensor(
458466
[[-1.0, -1.0], [1.0, 1.0]], device=self.device, dtype=dtype
459467
),
460-
batch_shape=torch.Size([2]),
461468
)
462469
batched_mo_model = SingleTaskGP(
463470
train_X, train_Y, input_transform=input_tf, outcome_transform=None

test/models/transforms/test_input.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,10 @@ def test_normalize(self) -> None:
228228
self.assertTrue(nlz.mins.dtype == other_dtype)
229229
# test incompatible dimensions of specified bounds
230230
bounds = torch.zeros(2, 3, device=self.device, dtype=dtype)
231-
with self.assertRaises(BotorchTensorDimensionError):
231+
with self.assertRaisesRegex(
232+
BotorchTensorDimensionError,
233+
"Dimensions of provided `bounds` are incompatible",
234+
):
232235
Normalize(d=2, bounds=bounds)
233236

234237
# test jitter
@@ -266,7 +269,12 @@ def test_normalize(self) -> None:
266269
# test errors on wrong shape
267270
nlz = Normalize(d=2, batch_shape=batch_shape)
268271
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
269-
with self.assertRaises(BotorchTensorDimensionError):
272+
expected_msg = "Wrong input dimension. Received 1, expected 2."
273+
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
274+
nlz(X)
275+
# Same error in eval mode
276+
nlz.eval()
277+
with self.assertRaisesRegex(BotorchTensorDimensionError, expected_msg):
270278
nlz(X)
271279

272280
# fixed bounds
@@ -328,9 +336,8 @@ def test_normalize(self) -> None:
328336
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
329337
dim=-2,
330338
)[..., indices]
331-
self.assertTrue(
332-
torch.allclose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
333-
)
339+
self.assertAllClose(nlz.bounds, expected_bounds, atol=1e-4, rtol=1e-4)
340+
334341
# test errors on wrong shape
335342
nlz = Normalize(d=2, batch_shape=batch_shape)
336343
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)

test_community/models/test_gp_regression_multisource.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def _get_model_and_data(
7676
None if train_Yvar else get_gaussian_likelihood_with_gamma_prior()
7777
),
7878
}
79-
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
79+
with warnings.catch_warnings():
80+
warnings.simplefilter("ignore", category=OptimizationWarning)
81+
model = SingleTaskAugmentedGP(**model_kwargs, **extra_model_kwargs)
8082
return model, model_kwargs
8183

8284
def test_data_init(self):
@@ -139,8 +141,8 @@ def test_get_reliable_observation(self):
139141
self.assertListEqual(res.tolist(), true_res.tolist())
140142

141143
def test_gp(self):
142-
bounds = torch.tensor([[-1.0], [1.0]])
143144
d = 5
145+
bounds = torch.stack((torch.full((d - 1,), -1), torch.ones(d - 1)))
144146
for batch_shape, dtype, use_octf, use_intf, train_Yvar in itertools.product(
145147
(torch.Size(), torch.Size([2])),
146148
(torch.float, torch.double),
@@ -151,7 +153,7 @@ def test_gp(self):
151153
tkwargs = {"device": self.device, "dtype": dtype}
152154
octf = Standardize(m=1, batch_shape=torch.Size()) if use_octf else None
153155
intf = (
154-
Normalize(d=1, bounds=bounds.to(**tkwargs), transform_on_train=True)
156+
Normalize(d=d - 1, bounds=bounds.to(**tkwargs), transform_on_train=True)
155157
if use_intf
156158
else None
157159
)

0 commit comments

Comments
 (0)