Skip to content

Commit

Permalink
Allow broadcasting across dimensions in eval mode; always require X t…
Browse files Browse the repository at this point in the history
…o be at least 2d

Summary:
Context:

A discussion on allowable shapes for transforms concluded:
* We should not allow for broadcasting across the -1 dimension, so the first check in _check_shape should always happen.
* The shapes always need to be broadcastable, so the torch.broadcast_shapes check in _check_shape should always happen.
* We want to allow for broadcasting across the batch dimension in eval model, so the check that X has dimension of at least len(batch_shape) + 2 should only happen in training mode.
* For clarity, we should disallow 1d X, even if broadcastable. BoTorch tends to be strict about requiring explicit dimensions, e.g. GPyTorchModel._validate_tensor_args, and that's a good thing because confusion about tensor dimensions causes a lot of pain.

This diff:
* Only checks that X has number of dimensions equal to 2 + the number of batch dimensions in training mode.
* Disallows <2d X.

Differential Revision: D62404492
  • Loading branch information
esantorella authored and facebook-github-bot committed Sep 9, 2024
1 parent 33e11f4 commit 1618797
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 4 deletions.
6 changes: 5 additions & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,9 +462,13 @@ def _check_shape(self, X: Tensor) -> None:
f"Wrong input dimension. Received {X.size(-1)}, "
f"expected {self.offset.size(-1)}."
)
if X.ndim < 2:
raise BotorchTensorDimensionError(
f"`X` must have at least 2 dimensions, but has {X.ndim}."
)

n = len(self.batch_shape) + 2
if X.ndim < n:
if self.training and X.ndim < n:
raise ValueError(
f"`X` must have at least {n} dimensions, {n - 2} batch and 2 innate"
f" , but has {X.ndim}."
Expand Down
37 changes: 34 additions & 3 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,19 @@ def test_normalize(self) -> None:
X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
X = X.to(self.device)
self.assertEqual(torch.isfinite(nlz(X)).sum(), X.numel())
with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
with self.assertRaisesRegex(
BotorchTensorDimensionError, r"must have at least 2 dimensions"
):
nlz(torch.randn(X.shape[-1], dtype=dtype))

# using unbatched X to train batched transform
nlz = Normalize(d=2, min_range=1e-4, batch_shape=torch.Size([3]))
X = torch.rand(4, 2)
with self.assertRaisesRegex(
ValueError, "must have at least 3 dimensions, 1 batch and 2 innate"
):
nlz(X)

# basic usage
for batch_shape in (torch.Size(), torch.Size([3])):
# learned bounds
Expand Down Expand Up @@ -341,7 +351,10 @@ def test_normalize(self) -> None:
# test errors on wrong shape
nlz = Normalize(d=2, batch_shape=batch_shape)
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
with self.assertRaises(BotorchTensorDimensionError):
with self.assertRaisesRegex(
BotorchTensorDimensionError,
"Wrong input dimension. Received 1, expected 2.",
):
nlz(X)

# test equals
Expand Down Expand Up @@ -403,6 +416,22 @@ def test_normalize(self) -> None:
expected_X = torch.tensor([[1.5, 0.75]], device=self.device, dtype=dtype)
self.assertAllClose(nlzd_X, expected_X)

# Test broadcasting across batch dimensions in eval mode
x = torch.tensor(
[[0.0, 2.0], [3.0, 5.0]], device=self.device, dtype=dtype
).unsqueeze(-1)
self.assertEqual(x.shape, torch.Size([2, 2, 1]))
nlz = Normalize(d=1, batch_shape=torch.Size([2]))
nlz(x)
nlz.eval()
x2 = torch.tensor([[1.0]], device=self.device, dtype=dtype)
nlzd_x2 = nlz.transform(x2)
self.assertEqual(nlzd_x2.shape, torch.Size([2, 1, 1]))
self.assertAllClose(
nlzd_x2.squeeze(),
torch.tensor([0.5, -1.0], dtype=dtype, device=self.device),
)

def test_standardize(self) -> None:
for dtype in (torch.float, torch.double):
# basic init
Expand Down Expand Up @@ -459,7 +488,9 @@ def test_standardize(self) -> None:
X = torch.cat((torch.randn(4, 1), torch.zeros(4, 1)), dim=-1)
X = X.to(self.device, dtype=dtype)
self.assertEqual(torch.isfinite(stdz(X)).sum(), X.numel())
with self.assertRaisesRegex(ValueError, r"must have at least \d+ dim"):
with self.assertRaisesRegex(
BotorchTensorDimensionError, r"must have at least \d+ dim"
):
stdz(torch.randn(X.shape[-1], dtype=dtype))

# basic usage
Expand Down

0 comments on commit 1618797

Please sign in to comment.