From 9c7e4714c41036e9888b6c337738f98c8218a5e0 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 2 Oct 2024 06:53:06 -0700 Subject: [PATCH] InteractionFeatures input transform (#2560) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2560 InteractionFeatures input transform to compute first-order interactions between inputs. Used for feature importance work in conjunction with (warped) linear models. Reviewed By: sdaulton Differential Revision: D63673008 --- botorch/models/transforms/input.py | 26 +++++++++++++++++- botorch/models/transforms/utils.py | 15 +++++++++++ test/models/transforms/test_input.py | 40 ++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 1 deletion(-) diff --git a/botorch/models/transforms/input.py b/botorch/models/transforms/input.py index 74bbafc191..9ca8e24f59 100644 --- a/botorch/models/transforms/input.py +++ b/botorch/models/transforms/input.py @@ -24,7 +24,7 @@ import torch from botorch.exceptions.errors import BotorchTensorDimensionError from botorch.exceptions.warnings import UserInputWarning -from botorch.models.transforms.utils import subset_transform +from botorch.models.transforms.utils import interaction_features, subset_transform from botorch.models.utils import fantasize from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE from gpytorch import Module as GPyTorchModule @@ -1370,6 +1370,30 @@ def transform(self, X: Tensor) -> Tensor: return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1]) +class InteractionFeatures(AppendFeatures): + r"""A transform that appends the first-order interaction terms $x_i * x_j, i < j$, + for all or a subset of the input variables.""" + + def __init__( + self, + indices: Optional[list[int]] = None, + ) -> None: + r"""Initializes the InteractionFeatures transform. + + Args: + indices: Indices of the subset of dimensions to compute interaction + features on. + """ + + super().__init__( + f=interaction_features, + indices=indices, + transform_on_train=True, + transform_on_eval=True, + transform_on_fantasize=True, + ) + + class FilterFeatures(InputTransform, Module): r"""A transform that filters the input with a given set of features indices. diff --git a/botorch/models/transforms/utils.py b/botorch/models/transforms/utils.py index e8bda88625..17901d2efb 100644 --- a/botorch/models/transforms/utils.py +++ b/botorch/models/transforms/utils.py @@ -126,3 +126,18 @@ def f(self, X: Tensor) -> Tensor: return Y return f + + +def interaction_features(X: Tensor) -> Tensor: + """Computes the interaction features between the inputs. + + Args: + X: A `batch_shape x q x d`-dim tensor of inputs. + indices: The input dimensions to generate interaction features for. + + Returns: + A `n x q x 1 x (d * (d-1) / 2))`-dim tensor of interaction features. + """ + dim = X.shape[-1] + row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1) + return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2) diff --git a/test/models/transforms/test_input.py b/test/models/transforms/test_input.py index 1e17fb64f0..b3b20fa025 100644 --- a/test/models/transforms/test_input.py +++ b/test/models/transforms/test_input.py @@ -20,6 +20,7 @@ InputPerturbation, InputStandardize, InputTransform, + InteractionFeatures, Log10, Normalize, OneHotToNumeric, @@ -1629,6 +1630,45 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor: self.assertEqual(X_transformed.shape, torch.Size((10, 4))) +class TestInteractionFeatures(BotorchTestCase): + def test_interaction_features(self) -> None: + interaction = InteractionFeatures() + X = torch.arange(6, dtype=torch.float).reshape(2, 3) + X_tf = interaction(X) + self.assertTrue(X_tf.shape, torch.Size([2, 6])) + + # test correct output values + self.assertTrue( + torch.equal( + X_tf, + torch.tensor( + [[0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [3.0, 4.0, 5.0, 12.0, 15.0, 20.0]] + ), + ) + ) + X = torch.arange(6, dtype=torch.float).reshape(2, 3) + interaction = InteractionFeatures(indices=[1, 2]) + X_tf = interaction(X) + self.assertTrue( + torch.equal( + X_tf, + torch.tensor([[0.0, 1.0, 2.0, 2.0], [3.0, 4.0, 5.0, 20.0]]), + ) + ) + with self.assertRaisesRegex( + IndexError, "index 2 is out of bounds for dimension 0 with size 2" + ): + interaction(torch.rand(4, 2)) + + # test batched evaluation + interaction = InteractionFeatures() + X_tf = interaction(torch.rand(4, 2, 4)) + self.assertTrue(X_tf.shape, torch.Size([4, 2, 10])) + + X_tf = interaction(torch.rand(5, 7, 3, 4)) + self.assertTrue(X_tf.shape, torch.Size([5, 7, 3, 10])) + + class TestFilterFeatures(BotorchTestCase): def test_filter_features(self) -> None: with self.assertRaises(ValueError):