Skip to content

Commit 1618797

Browse files
esantorella authored
and facebook-github-bot committed
Allow broadcasting across dimensions in eval mode; always require X to be at least 2d
Summary: Context: A discussion on allowable shapes for transforms concluded: * We should not allow for broadcasting across the -1 dimension, so the first check in _check_shape should always happen. * The shapes always need to be broadcastable, so the torch.broadcast_shapes check in _check_shape should always happen. * We want to allow for broadcasting across the batch dimension in eval mode, so the check that X has dimension of at least len(batch_shape) + 2 should only happen in training mode. * For clarity, we should disallow 1d X, even if broadcastable. BoTorch tends to be strict about requiring explicit dimensions, e.g. GPyTorchModel._validate_tensor_args, and that's a good thing because confusion about tensor dimensions causes a lot of pain. This diff: * Only checks that X has number of dimensions equal to 2 + the number of batch dimensions in training mode. * Disallows <2d X. Differential Revision: D62404492
1 parent 33e11f4 commit 1618797

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

botorch/models/transforms/input.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,13 @@ def _check_shape(self, X: Tensor) -> None:
462462
f"Wrong input dimension. Received {X.size(-1)}, "
463463
f"expected {self.offset.size(-1)}."
464464
)
465+
if X.ndim < 2:
466+
raise BotorchTensorDimensionError(
467+
f"`X` must have at least 2 dimensions, but has {X.ndim}."
468+
)
465469

466470
n = len(self.batch_shape) + 2
467-
if X.ndim < n:
471+
if self.training and X.ndim < n:
468472
raise ValueError(
469473
f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
470474
f" , but has {X.ndim}."

test/models/transforms/test_input.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,19 @@ def test_normalize(self) -> None:
240240
X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
241241
X = X.to(self.device)
242242
self.assertEqual(torch.isfinite(nlz(X)).sum(), X.numel())
243-
with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
243+
with self.assertRaisesRegex(
244+
BotorchTensorDimensionError, r"must have at least 2 dimensions"
245+
):
244246
nlz(torch.randn(X.shape[-1], dtype=dtype))
245247

248+
# using unbatched X to train batched transform
249+
nlz = Normalize(d=2, min_range=1e-4, batch_shape=torch.Size([3]))
250+
X = torch.rand(4, 2)
251+
with self.assertRaisesRegex(
252+
ValueError, "must have at least 3 dimensions, 1 batch and 2 innate"
253+
):
254+
nlz(X)
255+
246256
# basic usage
247257
for batch_shape in (torch.Size(), torch.Size([3])):
248258
# learned bounds
@@ -341,7 +351,10 @@ def test_normalize(self) -> None:
341351
# test errors on wrong shape
342352
nlz = Normalize(d=2, batch_shape=batch_shape)
343353
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
344-
with self.assertRaises(BotorchTensorDimensionError):
354+
with self.assertRaisesRegex(
355+
BotorchTensorDimensionError,
356+
"Wrong input dimension. Received 1, expected 2.",
357+
):
345358
nlz(X)
346359

347360
# test equals
@@ -403,6 +416,22 @@ def test_normalize(self) -> None:
403416
expected_X = torch.tensor([[1.5, 0.75]], device=self.device, dtype=dtype)
404417
self.assertAllClose(nlzd_X, expected_X)
405418

419+
# Test broadcasting across batch dimensions in eval mode
420+
x = torch.tensor(
421+
[[0.0, 2.0], [3.0, 5.0]], device=self.device, dtype=dtype
422+
).unsqueeze(-1)
423+
self.assertEqual(x.shape, torch.Size([2, 2, 1]))
424+
nlz = Normalize(d=1, batch_shape=torch.Size([2]))
425+
nlz(x)
426+
nlz.eval()
427+
x2 = torch.tensor([[1.0]], device=self.device, dtype=dtype)
428+
nlzd_x2 = nlz.transform(x2)
429+
self.assertEqual(nlzd_x2.shape, torch.Size([2, 1, 1]))
430+
self.assertAllClose(
431+
nlzd_x2.squeeze(),
432+
torch.tensor([0.5, -1.0], dtype=dtype, device=self.device),
433+
)
434+
406435
def test_standardize(self) -> None:
407436
for dtype in (torch.float, torch.double):
408437
# basic init
@@ -459,7 +488,9 @@ def test_standardize(self) -> None:
459488
X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
460489
X = X.to(self.device, dtype=dtype)
461490
self.assertEqual(torch.isfinite(stdz(X)).sum(), X.numel())
462-
with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
491+
with self.assertRaisesRegex(
492+
BotorchTensorDimensionError, r"must have at least \d+ dim"
493+
):
463494
stdz(torch.randn(X.shape[-1], dtype=dtype))
464495

465496
# basic usage

0 commit comments

Comments
 (0)