From 73e205208c6d6830c9e9c01e042faf60081425f2 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Tue, 1 Oct 2024 10:32:07 -0700 Subject: [PATCH] InputTransfrom list broadcasted over batch shapes (#2558) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2558 This commit adds `BatchBroadcastedInputTransform`, which broadcasts a list of input transforms across the first batch dimension of the input X, thereby enabling batch models in cases where only the input transforms are structurally different for each batch. Reviewed By: Balandat Differential Revision: D63660807 --- botorch/models/transforms/input.py | 126 +++++++++++++++++++++++++- test/models/transforms/test_input.py | 127 +++++++++++++++++++++++++++ 2 files changed, 252 insertions(+), 1 deletion(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 329e407c50..74bbafc191 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union from warnings import warn import numpy as np @@ -155,6 +155,130 @@ def preprocess_transform(self, X: Tensor) -> Tensor: return X +class BatchBroadcastedInputTransform(InputTransform, ModuleDict): + r"""An input transform representing a list of transforms to be broadcasted.""" + + def __init__( + self, + transforms: List[InputTransform], + broadcast_index: int = -3, + ) -> None: + r"""A transform list that is broadcasted across a batch dimension specified by + `broadcast_index`. This is allows using a batched Gaussian process model when + the input transforms are different for different batch dimensions. + + Args: + transforms: The transforms to broadcast across the first batch dimension. + The transform at position i in the list will be applied to `X[i]` for + a given input tensor `X` in the forward pass. + broadcast_index: The tensor index at which the transforms are broadcasted. + + Example: + >>> tf1 = Normalize(d=2) + >>> tf2 = InputStandardize(d=2) + >>> tf = BatchBroadcastedTransformList(transforms=[tf1, tf2]) + """ + super().__init__() + self.transform_on_train = False + self.transform_on_eval = False + self.transform_on_fantasize = False + self.transforms = transforms + if broadcast_index in (-2, -1): + raise ValueError( + "The broadcast index cannot be -2 and -1, as these indices are reserved" + " for non-batch, data and input dimensions." + ) + self.broadcast_index = broadcast_index + self.is_one_to_many = self.transforms[0].is_one_to_many + if not all(tf.is_one_to_many == self.is_one_to_many for tf in self.transforms): + raise ValueError( # output shapes of transforms must be the same + "All transforms must have the same is_one_to_many property." + ) + for tf in self.transforms: + self.transform_on_train |= tf.transform_on_train + self.transform_on_eval |= tf.transform_on_eval + self.transform_on_fantasize |= tf.transform_on_fantasize + + def transform(self, X: Tensor) -> Tensor: + r"""Transform the inputs to a model. + + Individual transforms are applied in sequence and results are returned as + a batched tensor. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of transformed inputs. + """ + return torch.stack( + [t.forward(Xi) for Xi, t in self._Xs_and_transforms(X)], + dim=self.broadcast_index, + ) + + def untransform(self, X: Tensor) -> Tensor: + r"""Un-transform the inputs to a model. + + Un-transforms of the individual transforms are applied in reverse sequence. + + Args: + X: A `batch_shape x n x d`-dim tensor of transformed inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of un-transformed inputs. + """ + return torch.stack( + [t.untransform(Xi) for Xi, t in self._Xs_and_transforms(X)], + dim=self.broadcast_index, + ) + + def equals(self, other: InputTransform) -> bool: + r"""Check if another input transform is equivalent. + + Args: + other: Another input transform. + + Returns: + A boolean indicating if the other transform is equivalent. + """ + return ( + super().equals(other=other) + and all(t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms)) + and (self.broadcast_index == other.broadcast_index) + ) + + def preprocess_transform(self, X: Tensor) -> Tensor: + r"""Apply transforms for preprocessing inputs. + + The main use cases for this method are 1) to preprocess training data + before calling `set_train_data` and 2) preprocess `X_baseline` for noisy + acquisition functions so that `X_baseline` is "preprocessed" with the + same transformations as the cached training inputs. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + A `batch_shape x n x d`-dim tensor of (transformed) inputs. + """ + return torch.stack( + [t.preprocess_transform(Xi) for Xi, t in self._Xs_and_transforms(X)], + dim=self.broadcast_index, + ) + + def _Xs_and_transforms(self, X: Tensor) -> Iterable[Tuple[Tensor, InputTransform]]: + r"""Returns an iterable of sub-tensors of X and their associated transforms. + + Args: + X: A `batch_shape x n x d`-dim tensor of inputs. + + Returns: + An iterable containing tuples of sub-tensors of X and their transforms. + """ + Xs = X.unbind(dim=self.broadcast_index) + return zip(Xs, self.transforms) + + class ChainedInputTransform(InputTransform, ModuleDict): r"""An input transform representing the chaining of individual transforms.""" diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 0d09dab09e..1e17fb64f0 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -14,6 +14,7 @@ from botorch.models.transforms.input import ( AffineInputTransform, AppendFeatures, + BatchBroadcastedInputTransform, ChainedInputTransform, FilterFeatures, InputPerturbation, @@ -652,6 +653,132 @@ def test_chained_input_transform(self) -> None: tf = ChainedInputTransform(stz=tf1, pert=tf2) self.assertTrue(tf.is_one_to_many) + def test_batch_broadcasted_input_transform(self) -> None: + ds = (1, 2) + batch_args = [ + (torch.Size([2]), {}), + (torch.Size([3, 2]), {}), + (torch.Size([2, 3]), {"broadcast_index": 0}), + (torch.Size([5, 2, 3]), {"broadcast_index": 1}), + ] + dtypes = (torch.float, torch.double) + # set seed to range where this is known to not be flaky + torch.manual_seed(randint(0, 1000)) + + for d, (batch_shape, kwargs), dtype in itertools.product( + ds, batch_args, dtypes + ): + bounds = torch.tensor( + [[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype + ) + # when the batch_shape is (2, 3), the transform list is broadcasted across + # the first dimension, whereas each individual transform gets broadcasted + # over the remaining batch dimensions. + if "broadcast_index" not in kwargs: + broadcast_index = -3 + tf_batch_shape = batch_shape[:-1] + else: + broadcast_index = kwargs["broadcast_index"] + # if the broadcast index is negative, we need to adjust the index + # when indexing into the batch shape tuple + i = broadcast_index + 2 if broadcast_index < 0 else broadcast_index + tf_batch_shape = list(batch_shape[:i]) + tf_batch_shape.extend(list(batch_shape[i + 1 :])) + tf_batch_shape = torch.Size(tf_batch_shape) + + tf1 = Normalize(d=d, bounds=bounds, batch_shape=tf_batch_shape) + tf2 = InputStandardize(d=d, batch_shape=tf_batch_shape) + transforms = [tf1, tf2] + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + # make copies for validation below + transforms_ = [deepcopy(tf_i) for tf_i in transforms] + self.assertTrue(tf.training) + # self.assertEqual(sorted(tf.keys()), ["stz_fixed", "stz_learned"]) + self.assertEqual(tf.transforms[0], tf1) + self.assertEqual(tf.transforms[1], tf2) + self.assertFalse(tf.is_one_to_many) + + X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype) + X_tf = tf(X) + Xs = X.unbind(dim=broadcast_index) + + X_tf_ = torch.stack( + [tf_i_(Xi) for tf_i_, Xi in zip(transforms_, Xs)], dim=broadcast_index + ) + self.assertTrue(tf1.training) + self.assertTrue(tf2.training) + self.assertTrue(torch.equal(X_tf, X_tf_)) + X_utf = tf.untransform(X_tf) + self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4) + + # test not transformed on eval + for tf_i in transforms: + tf_i.transform_on_eval = False + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.eval() + self.assertTrue(torch.equal(tf(X), X)) + + # test transformed on eval + for tf_i in transforms: + tf_i.transform_on_eval = True + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.eval() + self.assertTrue(torch.equal(tf(X), X_tf)) + + # test not transformed on train + for tf_i in transforms: + tf_i.transform_on_train = False + + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + tf.train() + self.assertTrue(torch.equal(tf(X), X)) + + # test __eq__ + other_tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + self.assertTrue(tf.equals(other_tf)) + # change order + other_tf = BatchBroadcastedInputTransform( + transforms=list(reversed(transforms)) + ) + self.assertFalse(tf.equals(other_tf)) + # Identical transforms but different objects. + other_tf = BatchBroadcastedInputTransform( + transforms=deepcopy(transforms), **kwargs + ) + self.assertTrue(tf.equals(other_tf)) + + # test preprocess_transform + transforms[-1].transform_on_train = False + transforms[0].transform_on_train = True + tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs) + self.assertTrue( + torch.equal( + tf.preprocess_transform(X).unbind(dim=broadcast_index)[0], + transforms[0].transform(Xs[0]), + ) + ) + + # test one-to-many + tf2 = InputPerturbation(perturbation_set=2 * bounds) + with self.assertRaisesRegex(ValueError, r".*one_to_many.*"): + tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs) + + # these could technically be batched internally, but we're testing the generic + # batch broadcasted transform list here. Could change test to use AppendFeatures + tf1 = InputPerturbation(perturbation_set=bounds) + tf2 = InputPerturbation(perturbation_set=2 * bounds) + tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs) + self.assertTrue(tf.is_one_to_many) + + with self.assertRaisesRegex( + ValueError, r"The broadcast index cannot be -2 and -1" + ): + tf = BatchBroadcastedInputTransform( + transforms=[tf1, tf2], broadcast_index=-2 + ) + def test_round_transform_init(self) -> None: # basic init int_idcs = [0, 4]