From ee91b55efc21d2c91c52ad25841e1e46ea97b55d Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Wed, 8 Feb 2023 10:54:46 -0800 Subject: [PATCH] acquisition function wrapper (#1532) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/1532 Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs Differential Revision: https://internalfb.com/D41629186 fbshipit-source-id: d39c44320682133dd09956d8eea9d41196652198 --- botorch/acquisition/fixed_feature.py | 26 ++++-------- botorch/acquisition/penalized.py | 24 +++-------- botorch/acquisition/proximal.py | 15 ++++--- botorch/acquisition/wrapper.py | 55 ++++++++++++++++++++++++++ sphinx/source/acquisition.rst | 9 ++++- test/acquisition/test_fixed_feature.py | 2 +- test/acquisition/test_proximal.py | 8 +++- test/acquisition/test_wrapper.py | 52 ++++++++++++++++++++++++ 8 files changed, 144 insertions(+), 47 deletions(-) create mode 100644 botorch/acquisition/wrapper.py create mode 100644 test/acquisition/test_wrapper.py diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 0f3b85faa7..763226799e 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -16,11 +16,11 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor -from torch.nn import Module -class FixedFeatureAcquisitionFunction(AcquisitionFunction): +class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AquisitionFunctions to fix a subset of features. Example: @@ -56,8 +56,7 @@ def __init__( combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ - Module.__init__(self) - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) dtype = torch.float device = torch.device("cpu") self.d = d @@ -126,24 +125,13 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Optional[Tensor]): + def set_X_pending(self, X_pending: Optional[Tensor]): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: - self.acq_func.X_pending = self._construct_X_full(X_pending) + full_X_pending = self._construct_X_full(X_pending) else: - self.acq_func.X_pending = X_pending + full_X_pending = None + self.acq_func.set_X_pending(full_X_pending) def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index b114362ea9..9ee8f1fee5 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -15,9 +15,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective -from botorch.exceptions import UnsupportedError +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor: return regularization_term -class PenalizedAcquisitionFunction(AcquisitionFunction): +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: @@ -161,29 +160,16 @@ def __init__( penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ - super().__init__(model=raw_acqf.model) - self.raw_acqf = raw_acqf + AcquisitionFunction.__init__(self, model=raw_acqf.model) + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter def forward(self, X: Tensor) -> Tensor: - raw_value = self.raw_acqf(X=X) + raw_value = self.acq_func(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Optional[Tensor]: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. diff --git a/botorch/acquisition/proximal.py b/botorch/acquisition/proximal.py index 9cd4aed7ad..b1d68edef1 100644 --- a/botorch/acquisition/proximal.py +++ b/botorch/acquisition/proximal.py @@ -15,6 +15,8 @@ import torch from botorch.acquisition import AcquisitionFunction + +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models import ModelListGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -25,7 +27,7 @@ from torch.nn import Module -class ProximalAcquisitionFunction(AcquisitionFunction): +class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AcquisitionFunctions to add proximal weighting of the acquisition function. The acquisition function is weighted via a squared exponential centered at the last training point, @@ -70,9 +72,7 @@ def __init__( beta: If not None, apply a softplus transform to the base acquisition function, allows negative base acquisition function values. """ - Module.__init__(self) - - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) model = self.acq_func.model if hasattr(acq_function, "X_pending"): @@ -80,7 +80,6 @@ def __init__( raise UnsupportedError( "Proximal acquisition function requires `X_pending` to be None." ) - self.X_pending = acq_function.X_pending self.register_buffer("proximal_weights", proximal_weights) self.register_buffer( @@ -91,6 +90,12 @@ def __init__( _validate_model(model, proximal_weights) + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + raise UnsupportedError( + "Proximal acquisition function does not support `X_pending`." + ) + @t_batch_mode_transform(expected_q=1, assert_output_shape=False) def forward(self, X: Tensor) -> Tensor: r"""Evaluate base acquisition function with proximal weighting. diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..08dfbd2849 --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor +from torch.nn import Module + + +class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.acq_func = acq_function + + @property + def X_pending(self) -> Optional[Tensor]: + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.set_X_pending(X_pending) + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the wrapped acquisition function on the candidate set X. + + Args: + X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim + design points each. + + Returns: + A `(b)`-dim Tensor of acquisition function values at the given + design points `X`. + """ + pass # pragma: no cover diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..a3c5eaeb5a 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -21,6 +21,11 @@ Analytic Acquisition Function API .. autoclass:: AnalyticAcquisitionFunction :members: +Acquisition Function Wrapper API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.wrapper + :members: + Cached Cholesky Acquisition Function API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.cached_cholesky @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index 8dcc02f1df..b8f570e7e1 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -87,7 +87,7 @@ def test_fixed_features(self): qEI_ff.set_X_pending(X_pending[..., :-1]) self.assertAllClose(qEI.X_pending, X_pending) # test setting to None - qEI_ff.X_pending = None + qEI_ff.set_X_pending(None) self.assertIsNone(qEI_ff.X_pending) # test gradient diff --git a/test/acquisition/test_proximal.py b/test/acquisition/test_proximal.py index 795daa1b34..e17536ddd0 100644 --- a/test/acquisition/test_proximal.py +++ b/test/acquisition/test_proximal.py @@ -209,9 +209,15 @@ def test_proximal(self): # test for x_pending points pending_acq = DummyAcquisitionFunction(model) - pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype)) + X_pending = torch.rand(3, 3, device=self.device, dtype=dtype) + pending_acq.set_X_pending(X_pending) with self.assertRaises(UnsupportedError): ProximalAcquisitionFunction(pending_acq, proximal_weights) + # test setting pending points + pending_acq.set_X_pending(None) + af = ProximalAcquisitionFunction(pending_acq, proximal_weights) + with self.assertRaises(UnsupportedError): + af.set_X_pending(X_pending) # test model with multi-batch training inputs train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype) diff --git a/test/acquisition/test_wrapper.py b/test/acquisition/test_wrapper.py new file mode 100644 index 0000000000..e35175fb9b --- /dev/null +++ b/test/acquisition/test_wrapper.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.exceptions.errors import UnsupportedError +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyWrapper(AbstractAcquisitionFunctionWrapper): + def forward(self, X): + return self.acq_func(X) + + +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): + def test_abstract_acquisition_function_wrapper(self): + for dtype in (torch.float, torch.double): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, dtype=dtype, device=self.device), + variance=torch.ones(1, 1, dtype=dtype, device=self.device), + ) + ) + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIs(wrapped_af.acq_func, acq_func) + # test forward + X = torch.rand(1, 1, dtype=dtype, device=self.device) + with torch.no_grad(): + wrapped_val = wrapped_af(X) + af_val = acq_func(X) + self.assertEqual(wrapped_val.item(), af_val.item()) + + # test X_pending + with self.assertRaises(ValueError): + self.assertIsNone(wrapped_af.X_pending) + with self.assertRaises(UnsupportedError): + wrapped_af.set_X_pending(X) + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIsNone(wrapped_af.X_pending) + wrapped_af.set_X_pending(X) + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) + self.assertTrue(torch.equal(X, acq_func.X_pending)) + wrapped_af.set_X_pending(None) + self.assertIsNone(wrapped_af.X_pending) + self.assertIsNone(acq_func.X_pending)