Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InteractionFeatures input transform #2560

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.transforms.utils import subset_transform
from botorch.models.transforms.utils import interaction_features, subset_transform
from botorch.models.utils import fantasize
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
from gpytorch import Module as GPyTorchModule
Expand Down Expand Up @@ -1370,6 +1370,30 @@ def transform(self, X: Tensor) -> Tensor:
return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])


class InteractionFeatures(AppendFeatures):
r"""A transform that appends the first-order interaction terms $x_i * x_j, i < j$,
for all or a subset of the input variables."""

def __init__(
self,
indices: Optional[list[int]] = None,
) -> None:
r"""Initializes the InteractionFeatures transform.

Args:
indices: Indices of the subset of dimensions to compute interaction
features on.
"""

super().__init__(
f=interaction_features,
indices=indices,
transform_on_train=True,
transform_on_eval=True,
transform_on_fantasize=True,
)


class FilterFeatures(InputTransform, Module):
r"""A transform that filters the input with a given set of features indices.

Expand Down
15 changes: 15 additions & 0 deletions botorch/models/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,18 @@ def f(self, X: Tensor) -> Tensor:
return Y

return f


def interaction_features(X: Tensor) -> Tensor:
"""Computes the interaction features between the inputs.

Args:
X: A `batch_shape x q x d`-dim tensor of inputs.
indices: The input dimensions to generate interaction features for.

Returns:
A `n x q x 1 x (d * (d-1) / 2))`-dim tensor of interaction features.
"""
dim = X.shape[-1]
row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)
40 changes: 40 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
InputPerturbation,
InputStandardize,
InputTransform,
InteractionFeatures,
Log10,
Normalize,
OneHotToNumeric,
Expand Down Expand Up @@ -1629,6 +1630,45 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))


class TestInteractionFeatures(BotorchTestCase):
def test_interaction_features(self) -> None:
interaction = InteractionFeatures()
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
X_tf = interaction(X)
self.assertTrue(X_tf.shape, torch.Size([2, 6]))

# test correct output values
self.assertTrue(
torch.equal(
X_tf,
torch.tensor(
[[0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [3.0, 4.0, 5.0, 12.0, 15.0, 20.0]]
),
)
)
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
interaction = InteractionFeatures(indices=[1, 2])
X_tf = interaction(X)
self.assertTrue(
torch.equal(
X_tf,
torch.tensor([[0.0, 1.0, 2.0, 2.0], [3.0, 4.0, 5.0, 20.0]]),
)
)
with self.assertRaisesRegex(
IndexError, "index 2 is out of bounds for dimension 0 with size 2"
):
interaction(torch.rand(4, 2))

# test batched evaluation
interaction = InteractionFeatures()
X_tf = interaction(torch.rand(4, 2, 4))
self.assertTrue(X_tf.shape, torch.Size([4, 2, 10]))

X_tf = interaction(torch.rand(5, 7, 3, 4))
self.assertTrue(X_tf.shape, torch.Size([5, 7, 3, 10]))


class TestFilterFeatures(BotorchTestCase):
def test_filter_features(self) -> None:
with self.assertRaises(ValueError):
Expand Down
Loading