diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index 329e407c50..4cd397577e 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
 from warnings import warn
 
 import numpy as np
@@ -155,6 +155,130 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
         return X
 
 
+class BatchBroadcastedInputTransform(InputTransform, ModuleDict):
+    r"""An input transform representing a list of transforms to be broadcasted."""
+
+    def __init__(
+        self,
+        transforms: List[InputTransform],
+        broadcast_index: int = -3,
+    ) -> None:
+        r"""A transform list that is broadcasted across a batch dimension specified by
+        `broadcast_index`. This allows using a batched Gaussian process model when
+        the input transforms are different for different batch dimensions.
+
+        Args:
+            transforms: The transforms to broadcast across the batch dimension
+                specified by `broadcast_index`. The transform at position i in the
+                list will be applied to `X[i]` for a given input tensor `X` in the
+                forward pass.
+            broadcast_index: The tensor index at which the transforms are broadcasted.
+
+        Example:
+            >>> tf1 = Normalize(d=2)
+            >>> tf2 = InputStandardize(d=2)
+            >>> tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
+        """
+        super().__init__()
+        self.transform_on_train = False
+        self.transform_on_eval = False
+        self.transform_on_fantasize = False
+        self.transforms = transforms
+        if broadcast_index in (-2, -1):
+            raise ValueError(
+                "The broadcast index cannot be -2 or -1, as these indices are "
+                "reserved for the non-batch, data and input dimensions."
+            )
+        self.broadcast_index = broadcast_index
+        self.is_one_to_many = self.transforms[0].is_one_to_many
+        if not all(
+            tf.is_one_to_many == self.is_one_to_many for tf in self.transforms
+        ):
+            raise ValueError(  # output shapes of transforms must be the same
+                "All transforms must have the same is_one_to_many property."
+            )
+        for tf in self.transforms:
+            self.transform_on_train |= tf.transform_on_train
+            self.transform_on_eval |= tf.transform_on_eval
+            self.transform_on_fantasize |= tf.transform_on_fantasize
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the inputs to a model.
+
+        The individual transforms are applied to the corresponding sub-tensors
+        along the broadcast dimension, and the results are stacked back into a
+        batched tensor.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of transformed inputs.
+        """
+        return torch.stack(
+            [t.forward(Xi) for Xi, t in self._Xs_and_transforms(X)],
+            dim=self.broadcast_index,
+        )
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Un-transform the inputs to a model.
+
+        The un-transforms of the individual transforms are applied to the
+        corresponding sub-tensors along the broadcast dimension.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of transformed inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
+        """
+        return torch.stack(
+            [t.untransform(Xi) for Xi, t in self._Xs_and_transforms(X)],
+            dim=self.broadcast_index,
+        )
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+ """ + return ( + super().equals(other=other) + and all(t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms)) + and (self.broadcast_index == other.broadcast_index) + ) + + def preprocess_transform(self, X: Tensor) -> Tensor: + r"""Apply transforms for preprocessing inputs. + + The main use cases for this method are 1) to preprocess training data + before calling `set_train_data` and 2) preprocess `X_baseline` for noisy + acquisition functions so that `X_baseline` is "preprocessed" with the + same transformations as the cached training inputs. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of (transformed) inputs. + """ + return torch.stack( + [t.preprocess_transform(Xi) for Xi, t in self._Xs_and_transforms(X)], + dim=self.broadcast_index, + ) + + def _Xs_and_transforms(self, X: Tensor) -> Iterable[Tuple[Tensor, InputTransform]]: + r"""Returns a list of sub-tensors of X and their associated input transforms. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A iterable containing tuples of sub-tensors of X and their input transforms. + """ + Xs = X.unbind(dim=self.broadcast_index) + return zip(Xs, self.transforms) + + class ChainedInputTransform(InputTransform, ModuleDict): r"""An input transform representing the chaining of individual transforms.""" diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 0d09dab09e..1e17fb64f0 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -14,6 +14,7 @@ from botorch.models.transforms.input import ( AffineInputTransform, AppendFeatures, + BatchBroadcastedInputTransform, ChainedInputTransform, FilterFeatures, InputPerturbation, @@ -652,6 +653,132 @@ def test_chained_input_transform(self) -> None: tf = ChainedInputTransform(stz=tf1, pert=tf2) self.assertTrue(tf.is_one_to_many) + def test_batch_broadcasted_input_transform(self) -> None: + ds = (1, 2) + batch_args = [ + (torch.Size([2]), {}), + (torch.Size([3, 2]), {}), + (torch.Size([2, 3]), {"broadcast_index": 0}), + (torch.Size([5, 2, 3]), {"broadcast_index": 1}), + ] + dtypes = (torch.float, torch.double) + # set seed to range where this is known to not be flaky + torch.manual_seed(randint(0, 1000)) + + for d, (batch_shape, kwargs), dtype in itertools.product( + ds, batch_args, dtypes + ): + bounds = torch.tensor( + [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype + ) + # when the batch_shape is (2, 3), the transform list is broadcasted across + # the first dimension, whereas each individual transform gets broadcasted + # over the remaining batch dimensions. 
+ if "broadcast_index" not in kwargs: + broadcast_index = -3 + tf_batch_shape = batch_shape[:-1] + else: + broadcast_index = kwargs["broadcast_index"] + # if the broadcast index is negative, we need to adjust the index + # when indexing into the batch shape tuple + i = broadcast_index + 2 if broadcast_index < 0 else broadcast_index + tf_batch_shape = list(batch_shape[:i]) + tf_batch_shape.extend(list(batch_shape[i + 1 :])) + tf_batch_shape = torch.Size(tf_batch_shape) + + tf1 = Normalize(d=d, bounds=bounds, batch_shape=tf_batch_shape) + tf2 = InputStandardize(d=d, batch_shape=tf_batch_shape) + transforms = [tf1, tf2] + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + # make copies for validation below + transforms_ = [deepcopy(tf_i) for tf_i in transforms] + self.assertTrue(tf.training) + # self.assertEqual(sorted(tf.keys()), ["stz_fixed", "stz_learned"]) + self.assertEqual(tf.transforms[0], tf1) + self.assertEqual(tf.transforms[1], tf2) + self.assertFalse(tf.is_one_to_many) + + X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype) + X_tf = tf(X) + Xs = X.unbind(dim=broadcast_index) + + X_tf_ = torch.stack( + [tf_i_(Xi) for tf_i_, Xi in zip(transforms_, Xs)], dim=broadcast_index + ) + self.assertTrue(tf1.training) + self.assertTrue(tf2.training) + self.assertTrue(torch.equal(X_tf, X_tf_)) + X_utf = tf.untransform(X_tf) + self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4) + + # test not transformed on eval + for tf_i in transforms: + tf_i.transform_on_eval = False + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.eval() + self.assertTrue(torch.equal(tf(X), X)) + + # test transformed on eval + for tf_i in transforms: + tf_i.transform_on_eval = True + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.eval() + self.assertTrue(torch.equal(tf(X), X_tf)) + + # test not transformed on train + for tf_i in transforms: + tf_i.transform_on_train = False + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.train() + self.assertTrue(torch.equal(tf(X), X)) + + # test __eq__ + other_tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + self.assertTrue(tf.equals(other_tf)) + # change order + other_tf = BatchBroadcastedInputTransform( + transforms=list(reversed(transforms)) + ) + self.assertFalse(tf.equals(other_tf)) + # Identical transforms but different objects. + other_tf = BatchBroadcastedInputTransform( + transforms=deepcopy(transforms), **kwargs + ) + self.assertTrue(tf.equals(other_tf)) + + # test preprocess_transform + transforms[-1].transform_on_train = False + transforms[0].transform_on_train = True + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + self.assertTrue( + torch.equal( + tf.preprocess_transform(X).unbind(dim=broadcast_index)[0], + transforms[0].transform(Xs[0]), + ) + ) + + # test one-to-many + tf2 = InputPerturbation(perturbation_set=2 * bounds) + with self.assertRaisesRegex(ValueError, r".*one_to_many.*"): + tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs) + + # these could technically be batched internally, but we're testing the generic + # batch broadcasted transform list here. 
+            # We could change the test to use AppendFeatures instead.
+            tf1 = InputPerturbation(perturbation_set=bounds)
+            tf2 = InputPerturbation(perturbation_set=2 * bounds)
+            tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs)
+            self.assertTrue(tf.is_one_to_many)
+
+            with self.assertRaisesRegex(
+                ValueError, r"The broadcast index cannot be -2 or -1"
+            ):
+                tf = BatchBroadcastedInputTransform(
+                    transforms=[tf1, tf2], broadcast_index=-2
+                )
+
     def test_round_transform_init(self) -> None:
         # basic init
         int_idcs = [0, 4]
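
For completeness, a minimal usage sketch of the new transform outside of the test suite. It assumes the diff above has been applied; the dimensions, bounds, and the choice of Normalize/InputStandardize are illustrative and mirror the docstring example rather than any prescribed usage:

    import torch
    from botorch.models.transforms.input import (
        BatchBroadcastedInputTransform,
        InputStandardize,
        Normalize,
    )

    d = 2
    bounds = torch.tensor([[-2.0] * d, [2.0] * d])
    # one transform per slice along the broadcast dimension (default -3),
    # so X[0] is normalized to the unit cube and X[1] is standardized
    tf = BatchBroadcastedInputTransform(
        transforms=[Normalize(d=d, bounds=bounds), InputStandardize(d=d)]
    )
    X = torch.rand(2, 4, d)  # batch_shape = (2,), n = 4, d = 2
    X_tf = tf(X)  # each transform is applied to its batch slice, then re-stacked
    assert X_tf.shape == X.shape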