diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py
index 329e407c50..aea2823244 100644
--- a/botorch/models/transforms/input.py
+++ b/botorch/models/transforms/input.py
@@ -17,7 +17,7 @@
 
 from abc import ABC, abstractmethod
 from collections import OrderedDict
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 from warnings import warn
 
 import numpy as np
@@ -155,6 +155,101 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
         return X
 
 
+class BatchBroadcastedInputTransform(InputTransform, ModuleDict):
+    r"""An input transform representing a list of transforms to be broadcasted."""
+
+    def __init__(
+        self,
+        transforms: List[InputTransform],
+    ) -> None:
+        r"""A transform list that is broadcasted across the input's first dimension.
+        This allows using a batched Gaussian process model in cases where the input
+        transforms are different for different batch dimensions.
+
+        Args:
+            transforms: The transforms to broadcast across the first batch dimension.
+                The transform at position i in the list will be applied to `X[i]` for
+                a given input tensor `X` in the forward pass.
+
+        Example:
+            >>> tf1 = Normalize(d=2)
+            >>> tf2 = InputStandardize(d=2)
+            >>> tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
+        """
+        super().__init__()
+        self.transform_on_train = False
+        self.transform_on_eval = False
+        self.transform_on_fantasize = False
+        self.transforms = transforms
+        self.is_one_to_many = self.transforms[0].is_one_to_many
+        if not all(tf.is_one_to_many == self.is_one_to_many for tf in self.transforms):
+            raise ValueError(  # output shapes of transforms must be the same
+                "All transforms must have the same is_one_to_many property."
+            )
+        for tf in self.transforms:
+            self.transform_on_train |= tf.transform_on_train
+            self.transform_on_eval |= tf.transform_on_eval
+            self.transform_on_fantasize |= tf.transform_on_fantasize
+
+    def transform(self, X: Tensor) -> Tensor:
+        r"""Transform the inputs to a model.
+
+        The individual transforms are applied to the corresponding slices along the
+        first batch dimension and the results are stacked into a batched tensor.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of transformed inputs.
+        """
+        return torch.stack([tf.forward(Xi) for Xi, tf in zip(X, self.transforms)])
+
+    def untransform(self, X: Tensor) -> Tensor:
+        r"""Un-transform the inputs to a model.
+
+        The un-transform of each individual transform is applied to its batch slice.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of transformed inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of un-transformed inputs.
+        """
+        return torch.stack([tf.untransform(Xi) for Xi, tf in zip(X, self.transforms)])
+
+    def equals(self, other: InputTransform) -> bool:
+        r"""Check if another input transform is equivalent.
+
+        Args:
+            other: Another input transform.
+
+        Returns:
+            A boolean indicating if the other transform is equivalent.
+        """
+        return super().equals(other=other) and all(
+            t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms)
+        )
+
+    def preprocess_transform(self, X: Tensor) -> Tensor:
+        r"""Apply transforms for preprocessing inputs.
+
+        The main use cases for this method are 1) to preprocess training data
+        before calling `set_train_data` and 2) preprocess `X_baseline` for noisy
+        acquisition functions so that `X_baseline` is "preprocessed" with the
+        same transformations as the cached training inputs.
+
+        Args:
+            X: A `batch_shape x n x d`-dim tensor of inputs.
+
+        Returns:
+            A `batch_shape x n x d`-dim tensor of (transformed) inputs.
+        """
+        return torch.stack(
+            [tf.preprocess_transform(Xi) for Xi, tf in zip(X, self.transforms)]
+        )
+
+
 class ChainedInputTransform(InputTransform, ModuleDict):
     r"""An input transform representing the chaining of individual transforms."""
 
diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py
index 0d09dab09e..1d654ed997 100644
--- a/test/models/transforms/test_input.py
+++ b/test/models/transforms/test_input.py
@@ -14,6 +14,7 @@
 from botorch.models.transforms.input import (
     AffineInputTransform,
     AppendFeatures,
+    BatchBroadcastedInputTransform,
     ChainedInputTransform,
     FilterFeatures,
     InputPerturbation,
@@ -652,6 +653,98 @@ def test_chained_input_transform(self) -> None:
         tf = ChainedInputTransform(stz=tf1, pert=tf2)
         self.assertTrue(tf.is_one_to_many)
 
+    def test_batch_broadcasted_input_transform(self) -> None:
+        ds = (1, 2)
+        batch_shapes = [torch.Size([2]), torch.Size([2, 3])]
+        dtypes = (torch.float, torch.double)
+        # set seed to a range where this is known not to be flaky
+        torch.manual_seed(randint(0, 1000))
+
+        for d, batch_shape, dtype in itertools.product(ds, batch_shapes, dtypes):
+            bounds = torch.tensor(
+                [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype
+            )
+            # when the batch_shape is (2, 3), the transform list is broadcasted across
+            # the first dimension, whereas each individual transform gets broadcasted
+            # over the remaining batch dimensions.
+            tf1 = Normalize(d=d, bounds=bounds, batch_shape=batch_shape[1:])
+            tf2 = InputStandardize(d=d, batch_shape=batch_shape[1:])
+            transforms = [tf1, tf2]
+            tf = BatchBroadcastedInputTransform(transforms=transforms)
+            # make copies for validation below
+            transforms_ = [deepcopy(tf_i) for tf_i in transforms]
+            self.assertTrue(tf.training)
+            self.assertEqual(tf.transforms[0], tf1)
+            self.assertEqual(tf.transforms[1], tf2)
+            self.assertFalse(tf.is_one_to_many)
+
+            X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype)
+            X_tf = tf(X)
+            X_tf_ = torch.stack([tf_i_(Xi) for tf_i_, Xi in zip(transforms_, X)], dim=0)
+            self.assertTrue(tf1.training)
+            self.assertTrue(tf2.training)
+            self.assertTrue(torch.equal(X_tf, X_tf_))
+            X_utf = tf.untransform(X_tf)
+            self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4)
+
+            # test not transformed on eval
+            for tf_i in transforms:
+                tf_i.transform_on_eval = False
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms)
+            tf.eval()
+            self.assertTrue(torch.equal(tf(X), X))
+
+            # test transformed on eval
+            for tf_i in transforms:
+                tf_i.transform_on_eval = True
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms)
+            tf.eval()
+            self.assertTrue(torch.equal(tf(X), X_tf))
+
+            # test not transformed on train
+            for tf_i in transforms:
+                tf_i.transform_on_train = False
+
+            tf = BatchBroadcastedInputTransform(transforms=transforms)
+            tf.train()
+            self.assertTrue(torch.equal(tf(X), X))
+
+            # test equals
+            other_tf = BatchBroadcastedInputTransform(transforms=transforms)
+            self.assertTrue(tf.equals(other_tf))
+            # change order
+            other_tf = BatchBroadcastedInputTransform(
+                transforms=list(reversed(transforms))
+            )
+            self.assertFalse(tf.equals(other_tf))
+            # Identical transforms but different objects.
+            other_tf = BatchBroadcastedInputTransform(transforms=deepcopy(transforms))
+            self.assertTrue(tf.equals(other_tf))
+
+            # test preprocess_transform
+            transforms[-1].transform_on_train = False
+            transforms[0].transform_on_train = True
+            tf = BatchBroadcastedInputTransform(transforms=transforms)
+            self.assertTrue(
+                torch.equal(
+                    tf.preprocess_transform(X)[0], transforms[0].transform(X[0])
+                )
+            )
+
+            # test one-to-many
+            tf2 = InputPerturbation(perturbation_set=2 * bounds)
+            with self.assertRaisesRegex(ValueError, r".*one_to_many.*"):
+                tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
+
+            # these could technically be batched internally, but we're testing the
+            # generic broadcasted transform list here; could also use AppendFeatures.
+            tf1 = InputPerturbation(perturbation_set=bounds)
+            tf2 = InputPerturbation(perturbation_set=2 * bounds)
+            tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
+            self.assertTrue(tf.is_one_to_many)
+
     def test_round_transform_init(self) -> None:
         # basic init
         int_idcs = [0, 4]
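
Usage sketch (illustrative, not part of the diff): the snippet below shows how the new transform could be attached to a batched model, assuming the existing `SingleTaskGP`, `Normalize`, and `InputStandardize` APIs; the data, shapes, and variable names are made up for illustration.

    import torch
    from botorch.models import SingleTaskGP
    from botorch.models.transforms.input import (
        BatchBroadcastedInputTransform,
        InputStandardize,
        Normalize,
    )

    d, n = 2, 8
    bounds = torch.stack([torch.zeros(d), 2 * torch.ones(d)])
    # transform i in the list is applied to train_X[i], i.e. to the i-th batch model
    intf = BatchBroadcastedInputTransform(
        transforms=[Normalize(d=d, bounds=bounds), InputStandardize(d=d)]
    )
    train_X = 2 * torch.rand(2, n, d)  # batch of two models over the same d features
    train_Y = train_X.sin().sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y, input_transform=intf)

Since the transform list is indexed by the leading batch dimension, the number of transforms should match the size of that dimension (two in this sketch), as described in the `transforms` argument's docstring above.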