InputTransform list broadcasted over batch shapes (#2558)
Summary:
Pull Request resolved: #2558

This commit adds `BatchBroadcastedInputTransform`, which broadcasts a list of input transforms across the first batch dimension of the input X, thereby enabling batch models in cases where only the input transforms are structurally different for each batch.
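
For illustration, a minimal usage sketch (not part of the commit; the transform choices and shapes are illustrative assumptions):

import torch
from botorch.models.transforms.input import (
    BatchBroadcastedInputTransform,
    InputStandardize,
    Normalize,
)

# One transform per element of the broadcast dimension (default: -3).
tf = BatchBroadcastedInputTransform(
    transforms=[Normalize(d=2), InputStandardize(d=2)]
)
X = torch.rand(2, 8, 2)  # 2 (broadcast) x 8 (n) x 2 (d)
X_tf = tf(X)  # X[0] is normalized, X[1] is standardized; shape 2 x 8 x 2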

Differential Revision: D63660807
SebastianAment authored and facebook-github-bot committed Oct 1, 2024
1 parent e29e30a commit 002deae
Showing 2 changed files with 252 additions and 1 deletion.
126 changes: 125 additions & 1 deletion botorch/models/transforms/input.py
@@ -17,7 +17,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, Union
from warnings import warn

import numpy as np
@@ -155,6 +155,130 @@ def preprocess_transform(self, X: Tensor) -> Tensor:
return X


class BatchBroadcastedInputTransform(InputTransform, ModuleDict):
r"""An input transform representing a list of transforms to be broadcasted."""

def __init__(
self,
transforms: List[InputTransform],
broadcast_index: int = -3,
) -> None:
r"""A transform list that is broadcasted across a batch dimension specified by
`broadcast_index`. This is allows using a batched Gaussian process model when
the input transforms are different for different batch dimensions.
Args:
transforms: The transforms to broadcast across the batch dimension given
by `broadcast_index`. The transform at position i in the list will be
applied to the i-th sub-tensor of a given input tensor `X` along that
dimension in the forward pass.
broadcast_index: The tensor index at which the transforms are broadcasted.
Example:
>>> tf1 = Normalize(d=2)
>>> tf2 = InputStandardize(d=2)
>>> tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2])
"""
super().__init__()
self.transform_on_train = False
self.transform_on_eval = False
self.transform_on_fantasize = False
self.transforms = transforms
if broadcast_index in (-2, -1):
raise ValueError(
"The broadcast index cannot be -2 and -1, as these indices are reserved"
" for non-batch, data and input dimensions."
)
self.broadcast_index = broadcast_index
self.is_one_to_many = self.transforms[0].is_one_to_many
if not all(tf.is_one_to_many == self.is_one_to_many for tf in self.transforms):
raise ValueError( # output shapes of transforms must be the same
"All transforms must have the same is_one_to_many property."
)
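# The composite transform is active in any mode (train / eval / fantasize)
# in which at least one of its component transforms is active.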
for tf in self.transforms:
self.transform_on_train |= tf.transform_on_train
self.transform_on_eval |= tf.transform_on_eval
self.transform_on_fantasize |= tf.transform_on_fantasize

def transform(self, X: Tensor) -> Tensor:
r"""Transform the inputs to a model.
Each transform is applied to its sub-tensor of the input along the broadcast
dimension, and the results are stacked into a batched tensor.
Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
Returns:
A `batch_shape x n x d`-dim tensor of transformed inputs.
"""
return torch.stack(
[t.forward(Xi) for Xi, t in self._Xs_and_transforms(X)],
dim=self.broadcast_index,
)

def untransform(self, X: Tensor) -> Tensor:
r"""Un-transform the inputs to a model.
The un-transforms of the individual transforms are applied to the
corresponding sub-tensors of the input along the broadcast dimension.
Args:
X: A `batch_shape x n x d`-dim tensor of transformed inputs.
Returns:
A `batch_shape x n x d`-dim tensor of un-transformed inputs.
"""
return torch.stack(
[t.untransform(Xi) for Xi, t in self._Xs_and_transforms(X)],
dim=self.broadcast_index,
)

def equals(self, other: InputTransform) -> bool:
r"""Check if another input transform is equivalent.
Args:
other: Another input transform.
Returns:
A boolean indicating if the other transform is equivalent.
"""
return (
super().equals(other=other)
and all(t1.equals(t2) for t1, t2 in zip(self.transforms, other.transforms))
and (self.broadcast_index == other.broadcast_index)
)

def preprocess_transform(self, X: Tensor) -> Tensor:
r"""Apply transforms for preprocessing inputs.
The main use cases for this method are 1) to preprocess training data
before calling `set_train_data` and 2) to preprocess `X_baseline` for noisy
acquisition functions so that `X_baseline` is "preprocessed" with the
same transformations as the cached training inputs.
Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
Returns:
A `batch_shape x n x d`-dim tensor of (transformed) inputs.
"""
return torch.stack(
[t.preprocess_transform(Xi) for Xi, t in self._Xs_and_transforms(X)],
dim=self.broadcast_index,
)

def _Xs_and_transforms(self, X: Tensor) -> Iterable[Tuple[Tensor, InputTransform]]:
r"""Returns a list of sub-tensors of X and their associated input transforms.
Args:
X: A `batch_shape x n x d`-dim tensor of inputs.
Returns:
An iterable of tuples of sub-tensors of X and their input transforms.
"""
Xs = X.unbind(dim=self.broadcast_index)
return zip(Xs, self.transforms)


class ChainedInputTransform(InputTransform, ModuleDict):
r"""An input transform representing the chaining of individual transforms."""

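
For intuition, the broadcasting semantics of `transform` and `_Xs_and_transforms` above amount to an unbind / apply / stack round trip. A sketch of that equivalence (not part of the commit; the shapes and `broadcast_index=0` are illustrative assumptions):

import torch
from botorch.models.transforms.input import (
    BatchBroadcastedInputTransform,
    InputStandardize,
    Normalize,
)

transforms = [Normalize(d=2), InputStandardize(d=2)]
tf = BatchBroadcastedInputTransform(transforms=transforms, broadcast_index=0)
X = torch.rand(2, 3, 8, 2)  # 2 (broadcast) x 3 (batch) x 8 (n) x 2 (d)

# The forward pass unbinds X at `broadcast_index`, applies the i-th transform
# to the i-th sub-tensor, and stacks the results back together.
X_tf = tf(X)
Xs = X.unbind(dim=0)  # two sub-tensors of shape 3 x 8 x 2
X_tf_manual = torch.stack([t(Xi) for t, Xi in zip(transforms, Xs)], dim=0)
assert torch.equal(X_tf, X_tf_manual)
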
127 changes: 127 additions & 0 deletions test/models/transforms/test_input.py
@@ -14,6 +14,7 @@
from botorch.models.transforms.input import (
AffineInputTransform,
AppendFeatures,
BatchBroadcastedInputTransform,
ChainedInputTransform,
FilterFeatures,
InputPerturbation,
@@ -652,6 +653,132 @@ def test_chained_input_transform(self) -> None:
tf = ChainedInputTransform(stz=tf1, pert=tf2)
self.assertTrue(tf.is_one_to_many)

def test_batch_broadcasted_input_transform(self) -> None:
ds = (1, 2)
batch_args = [
(torch.Size([2]), {}),
(torch.Size([3, 2]), {}),
(torch.Size([2, 3]), {"broadcast_index": 0}),
(torch.Size([5, 2, 3]), {"broadcast_index": 1}),
]
dtypes = (torch.float, torch.double)
# set seed to a range where this is known not to be flaky
torch.manual_seed(randint(0, 1000))

for d, (batch_shape, kwargs), dtype in itertools.product(
ds, batch_args, dtypes
):
bounds = torch.tensor(
[[-2.0] * d, [2.0] * d], device=self.device, dtype=dtype
)
# when the batch_shape is (2, 3), the transform list is broadcasted across
# the first dimension, whereas each individual transform gets broadcasted
# over the remaining batch dimensions.
if "broadcast_index" not in kwargs:
broadcast_index = -3
tf_batch_shape = batch_shape[:-1]
else:
broadcast_index = kwargs["broadcast_index"]
# if the broadcast index is negative, we need to adjust the index
# when indexing into the batch shape tuple
i = broadcast_index + 2 if broadcast_index < 0 else broadcast_index
tf_batch_shape = list(batch_shape[:i])
tf_batch_shape.extend(list(batch_shape[i + 1 :]))
tf_batch_shape = torch.Size(tf_batch_shape)

tf1 = Normalize(d=d, bounds=bounds, batch_shape=tf_batch_shape)
tf2 = InputStandardize(d=d, batch_shape=tf_batch_shape)
transforms = [tf1, tf2]
tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
# make copies for validation below
transforms_ = [deepcopy(tf_i) for tf_i in transforms]
self.assertTrue(tf.training)
self.assertEqual(tf.transforms[0], tf1)
self.assertEqual(tf.transforms[1], tf2)
self.assertFalse(tf.is_one_to_many)

X = torch.rand(*batch_shape, 4, d, device=self.device, dtype=dtype)
X_tf = tf(X)
Xs = X.unbind(dim=broadcast_index)

X_tf_ = torch.stack(
[tf_i_(Xi) for tf_i_, Xi in zip(transforms_, Xs)], dim=broadcast_index
)
self.assertTrue(tf1.training)
self.assertTrue(tf2.training)
self.assertTrue(torch.equal(X_tf, X_tf_))
X_utf = tf.untransform(X_tf)
self.assertAllClose(X_utf, X, atol=1e-4, rtol=1e-4)

# test not transformed on eval
for tf_i in transforms:
tf_i.transform_on_eval = False

tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
tf.eval()
self.assertTrue(torch.equal(tf(X), X))

# test transformed on eval
for tf_i in transforms:
tf_i.transform_on_eval = True

tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
tf.eval()
self.assertTrue(torch.equal(tf(X), X_tf))

# test not transformed on train
for tf_i in transforms:
tf_i.transform_on_train = False

tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
tf.train()
self.assertTrue(torch.equal(tf(X), X))

# test __eq__
other_tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
self.assertTrue(tf.equals(other_tf))
# change order
other_tf = BatchBroadcastedInputTransform(
transforms=list(reversed(transforms))
)
self.assertFalse(tf.equals(other_tf))
# Identical transforms but different objects.
other_tf = BatchBroadcastedInputTransform(
transforms=deepcopy(transforms), **kwargs
)
self.assertTrue(tf.equals(other_tf))

# test preprocess_transform
transforms[-1].transform_on_train = False
transforms[0].transform_on_train = True
tf = BatchBroadcastedInputTransform(transforms=transforms, **kwargs)
self.assertTrue(
torch.equal(
tf.preprocess_transform(X).unbind(dim=broadcast_index)[0],
transforms[0].transform(Xs[0]),
)
)

# test one-to-many
tf2 = InputPerturbation(perturbation_set=2 * bounds)
with self.assertRaisesRegex(ValueError, r".*one_to_many.*"):
tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs)

# These could technically be batched internally, but here we are testing the
# generic batch-broadcasted transform list. The test could be changed to use
# AppendFeatures instead.
tf1 = InputPerturbation(perturbation_set=bounds)
tf2 = InputPerturbation(perturbation_set=2 * bounds)
tf = BatchBroadcastedInputTransform(transforms=[tf1, tf2], **kwargs)
self.assertTrue(tf.is_one_to_many)

with self.assertRaisesRegex(
ValueError, r"The broadcast index cannot be -2 and -1"
):
tf = BatchBroadcastedInputTransform(
transforms=[tf1, tf2], broadcast_index=-2
)

def test_round_transform_init(self) -> None:
# basic init
int_idcs = [0, 4]
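
As a sketch of the motivating use case from the summary, a batched model whose batches require structurally different input transforms (not part of the commit; assuming this composes with `SingleTaskGP` as the summary describes, with illustrative data, shapes, and transform choices):

import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import (
    BatchBroadcastedInputTransform,
    InputStandardize,
    Normalize,
)

train_X = torch.rand(2, 10, 2)  # 2 (batch) x 10 (n) x 2 (d)
train_Y = torch.rand(2, 10, 1)
# Each model batch gets its own input transform along the default
# broadcast dimension (-3), matching the leading dimension of train_X.
tf = BatchBroadcastedInputTransform(
    transforms=[Normalize(d=2), InputStandardize(d=2)]
)
model = SingleTaskGP(train_X, train_Y, input_transform=tf)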
