diff --git a/botorch/acquisition/fixed_feature.py b/botorch/acquisition/fixed_feature.py index 0f3b85faa7..763226799e 100644 --- a/botorch/acquisition/fixed_feature.py +++ b/botorch/acquisition/fixed_feature.py @@ -16,11 +16,11 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor -from torch.nn import Module -class FixedFeatureAcquisitionFunction(AcquisitionFunction): +class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AquisitionFunctions to fix a subset of features. Example: @@ -56,8 +56,7 @@ def __init__( combination of `Tensor`s and numbers which can be broadcasted to form a tensor with trailing dimension size of `d_f`. """ - Module.__init__(self) - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) dtype = torch.float device = torch.device("cpu") self.d = d @@ -126,24 +125,13 @@ def forward(self, X: Tensor): X_full = self._construct_X_full(X) return self.acq_func(X_full) - @property - def X_pending(self): - r"""Return the `X_pending` of the base acquisition function.""" - try: - return self.acq_func.X_pending - except (ValueError, AttributeError): - raise ValueError( - f"Base acquisition function {type(self.acq_func).__name__} " - "does not have an `X_pending` attribute." - ) - - @X_pending.setter - def X_pending(self, X_pending: Optional[Tensor]): + def set_X_pending(self, X_pending: Optional[Tensor]): r"""Sets the `X_pending` of the base acquisition function.""" if X_pending is not None: - self.acq_func.X_pending = self._construct_X_full(X_pending) + full_X_pending = self._construct_X_full(X_pending) else: - self.acq_func.X_pending = X_pending + full_X_pending = None + self.acq_func.set_X_pending(full_X_pending) def _construct_X_full(self, X: Tensor) -> Tensor: r"""Constructs the full input for the base acquisition function. diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py index b114362ea9..9ee8f1fee5 100644 --- a/botorch/acquisition/penalized.py +++ b/botorch/acquisition/penalized.py @@ -15,9 +15,8 @@ import torch from botorch.acquisition.acquisition import AcquisitionFunction -from botorch.acquisition.analytic import AnalyticAcquisitionFunction from botorch.acquisition.objective import GenericMCObjective -from botorch.exceptions import UnsupportedError +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from torch import Tensor @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor: return regularization_term -class PenalizedAcquisitionFunction(AcquisitionFunction): +class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper): r"""Single-outcome acquisition function regularized by the given penalty. The usage is similar to: @@ -161,29 +160,16 @@ def __init__( penalty_func: The regularization function. regularization_parameter: Regularization parameter used in optimization. """ - super().__init__(model=raw_acqf.model) - self.raw_acqf = raw_acqf + AcquisitionFunction.__init__(self, model=raw_acqf.model) + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf) self.penalty_func = penalty_func self.regularization_parameter = regularization_parameter def forward(self, X: Tensor) -> Tensor: - raw_value = self.raw_acqf(X=X) + raw_value = self.acq_func(X=X) penalty_term = self.penalty_func(X) return raw_value - self.regularization_parameter * penalty_term - @property - def X_pending(self) -> Optional[Tensor]: - return self.raw_acqf.X_pending - - def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: - if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): - self.raw_acqf.set_X_pending(X_pending=X_pending) - else: - raise UnsupportedError( - "The raw acquisition function is Analytic and does not account " - "for X_pending yet." - ) - def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: r"""Computes the group lasso regularization function for the given point. diff --git a/botorch/acquisition/proximal.py b/botorch/acquisition/proximal.py index 9cd4aed7ad..b1d68edef1 100644 --- a/botorch/acquisition/proximal.py +++ b/botorch/acquisition/proximal.py @@ -15,6 +15,8 @@ import torch from botorch.acquisition import AcquisitionFunction + +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models import ModelListGP from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel @@ -25,7 +27,7 @@ from torch.nn import Module -class ProximalAcquisitionFunction(AcquisitionFunction): +class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper): """A wrapper around AcquisitionFunctions to add proximal weighting of the acquisition function. The acquisition function is weighted via a squared exponential centered at the last training point, @@ -70,9 +72,7 @@ def __init__( beta: If not None, apply a softplus transform to the base acquisition function, allows negative base acquisition function values. """ - Module.__init__(self) - - self.acq_func = acq_function + AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function) model = self.acq_func.model if hasattr(acq_function, "X_pending"): @@ -80,7 +80,6 @@ def __init__( raise UnsupportedError( "Proximal acquisition function requires `X_pending` to be None." ) - self.X_pending = acq_function.X_pending self.register_buffer("proximal_weights", proximal_weights) self.register_buffer( @@ -91,6 +90,12 @@ def __init__( _validate_model(model, proximal_weights) + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + raise UnsupportedError( + "Proximal acquisition function does not support `X_pending`." + ) + @t_batch_mode_transform(expected_q=1, assert_output_shape=False) def forward(self, X: Tensor) -> Tensor: r"""Evaluate base acquisition function with proximal weighting. diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py index 486fdd0cff..ccbbf471b2 100644 --- a/botorch/acquisition/utils.py +++ b/botorch/acquisition/utils.py @@ -11,7 +11,7 @@ from __future__ import annotations import math -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from botorch.acquisition import analytic, monte_carlo, multi_objective # noqa F401 @@ -22,6 +22,7 @@ MCAcquisitionObjective, PosteriorTransform, ) +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper from botorch.exceptions.errors import UnsupportedError from botorch.models.fully_bayesian import MCMC_DIM from botorch.models.model import Model @@ -253,6 +254,18 @@ def objective(Y: Tensor, X: Optional[Tensor] = None): return -(lb.clamp_max(0.0)) +def isinstance_af( + __obj: object, + __class_or_tuple: Union[type, tuple[Union[type, tuple[Any, ...]], ...]], +) -> bool: + r"""A variant of isinstance first checks for the acq_func attribute on wrapped acquisition functions.""" + if isinstance(__obj, AbstractAcquisitionFunctionWrapper): + isinstance_base_af = isinstance(__obj.acq_func, __class_or_tuple) + else: + isinstance_base_af = False + return isinstance_base_af or isinstance(__obj, __class_or_tuple) + + def is_nonnegative(acq_function: AcquisitionFunction) -> bool: r"""Determine whether a given acquisition function is non-negative. @@ -267,7 +280,7 @@ def is_nonnegative(acq_function: AcquisitionFunction) -> bool: >>> qEI = qExpectedImprovement(model, best_f=0.1) >>> is_nonnegative(qEI) # returns True """ - return isinstance( + return isinstance_af( acq_function, ( analytic.ExpectedImprovement, diff --git a/botorch/acquisition/wrapper.py b/botorch/acquisition/wrapper.py new file mode 100644 index 0000000000..08dfbd2849 --- /dev/null +++ b/botorch/acquisition/wrapper.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +A wrapper classes around AcquisitionFunctions to modify inputs and outputs. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from botorch.acquisition.acquisition import AcquisitionFunction +from torch import Tensor +from torch.nn import Module + + +class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC): + r"""Abstract acquisition wrapper.""" + + def __init__(self, acq_function: AcquisitionFunction) -> None: + Module.__init__(self) + self.acq_func = acq_function + + @property + def X_pending(self) -> Optional[Tensor]: + r"""Return the `X_pending` of the base acquisition function.""" + try: + return self.acq_func.X_pending + except (ValueError, AttributeError): + raise ValueError( + f"Base acquisition function {type(self.acq_func).__name__} " + "does not have an `X_pending` attribute." + ) + + def set_X_pending(self, X_pending: Optional[Tensor]) -> None: + r"""Sets the `X_pending` of the base acquisition function.""" + self.acq_func.set_X_pending(X_pending) + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + r"""Evaluate the wrapped acquisition function on the candidate set X. + + Args: + X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim + design points each. + + Returns: + A `(b)`-dim Tensor of acquisition function values at the given + design points `X`. + """ + pass # pragma: no cover diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index 79f529826a..a3c5eaeb5a 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -21,6 +21,11 @@ Analytic Acquisition Function API .. autoclass:: AnalyticAcquisitionFunction :members: +Acquisition Function Wrapper API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.wrapper + :members: + Cached Cholesky Acquisition Function API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.cached_cholesky @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions .. automodule:: botorch.acquisition.multi_objective.analytic :members: :exclude-members: MultiObjectiveAnalyticAcquisitionFunction - + Multi-Objective Joint Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.joint_entropy_search @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.multi_fidelity :members: - + Multi-Objective Predictive Entropy Search Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search diff --git a/test/acquisition/test_fixed_feature.py b/test/acquisition/test_fixed_feature.py index 8dcc02f1df..b8f570e7e1 100644 --- a/test/acquisition/test_fixed_feature.py +++ b/test/acquisition/test_fixed_feature.py @@ -87,7 +87,7 @@ def test_fixed_features(self): qEI_ff.set_X_pending(X_pending[..., :-1]) self.assertAllClose(qEI.X_pending, X_pending) # test setting to None - qEI_ff.X_pending = None + qEI_ff.set_X_pending(None) self.assertIsNone(qEI_ff.X_pending) # test gradient diff --git a/test/acquisition/test_proximal.py b/test/acquisition/test_proximal.py index 795daa1b34..e17536ddd0 100644 --- a/test/acquisition/test_proximal.py +++ b/test/acquisition/test_proximal.py @@ -209,9 +209,15 @@ def test_proximal(self): # test for x_pending points pending_acq = DummyAcquisitionFunction(model) - pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype)) + X_pending = torch.rand(3, 3, device=self.device, dtype=dtype) + pending_acq.set_X_pending(X_pending) with self.assertRaises(UnsupportedError): ProximalAcquisitionFunction(pending_acq, proximal_weights) + # test setting pending points + pending_acq.set_X_pending(None) + af = ProximalAcquisitionFunction(pending_acq, proximal_weights) + with self.assertRaises(UnsupportedError): + af.set_X_pending(X_pending) # test model with multi-batch training inputs train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype) diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index d12b5f6da4..39b8017ea2 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -8,7 +8,8 @@ from unittest import mock import torch -from botorch.acquisition import monte_carlo +from botorch.acquisition import analytic, monte_carlo, multi_objective +from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction from botorch.acquisition.multi_objective import ( MCMultiOutputObjective, monte_carlo as moo_monte_carlo, @@ -18,10 +19,13 @@ MCAcquisitionObjective, ScalarizedPosteriorTransform, ) +from botorch.acquisition.proximal import ProximalAcquisitionFunction from botorch.acquisition.utils import ( expand_trace_observations, get_acquisition_function, get_infeasible_cost, + is_nonnegative, + isinstance_af, project_to_sample_points, project_to_target_fidelity, prune_inferior_points, @@ -606,6 +610,61 @@ def test_get_infeasible_cost(self): self.assertAllClose(M4, torch.tensor([1.0], **tkwargs)) +class TestIsNonnegative(BotorchTestCase): + def test_is_nonnegative(self): + nonneg_afs = ( + analytic.ExpectedImprovement, + analytic.ConstrainedExpectedImprovement, + analytic.ProbabilityOfImprovement, + analytic.NoisyExpectedImprovement, + monte_carlo.qExpectedImprovement, + monte_carlo.qNoisyExpectedImprovement, + monte_carlo.qProbabilityOfImprovement, + multi_objective.analytic.ExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qExpectedHypervolumeImprovement, + multi_objective.monte_carlo.qNoisyExpectedHypervolumeImprovement, + ) + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + with mock.patch( + "botorch.acquisition.utils.isinstance_af", return_value=True + ) as mock_isinstance_af: + self.assertTrue(is_nonnegative(acq_function=acq_func)) + mock_isinstance_af.assert_called_once() + cargs, _ = mock_isinstance_af.call_args + self.assertIs(cargs[0], acq_func) + self.assertEqual(cargs[1], nonneg_afs) + acq_func = analytic.UpperConfidenceBound(model=mm, beta=2.0) + self.assertFalse(is_nonnegative(acq_function=acq_func)) + + +class TestIsinstanceAf(BotorchTestCase): + def test_isinstance_af(self): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, device=self.device), + variance=torch.ones(1, 1, device=self.device), + ) + ) + acq_func = analytic.ExpectedImprovement(model=mm, best_f=-1.0) + self.assertTrue(isinstance_af(acq_func, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(acq_func, analytic.UpperConfidenceBound)) + wrapped_af = FixedFeatureAcquisitionFunction( + acq_function=acq_func, d=2, columns=[1], values=[0.0] + ) + # test base af class + self.assertTrue(isinstance_af(wrapped_af, analytic.ExpectedImprovement)) + self.assertFalse(isinstance_af(wrapped_af, analytic.UpperConfidenceBound)) + # test wrapper class + self.assertTrue(isinstance_af(wrapped_af, FixedFeatureAcquisitionFunction)) + self.assertFalse(isinstance_af(wrapped_af, ProximalAcquisitionFunction)) + + class TestPruneInferiorPoints(BotorchTestCase): def test_prune_inferior_points(self): for dtype in (torch.float, torch.double): diff --git a/test/acquisition/test_wrapper.py b/test/acquisition/test_wrapper.py new file mode 100644 index 0000000000..e35175fb9b --- /dev/null +++ b/test/acquisition/test_wrapper.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper +from botorch.exceptions.errors import UnsupportedError +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class DummyWrapper(AbstractAcquisitionFunctionWrapper): + def forward(self, X): + return self.acq_func(X) + + +class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase): + def test_abstract_acquisition_function_wrapper(self): + for dtype in (torch.float, torch.double): + mm = MockModel( + MockPosterior( + mean=torch.rand(1, 1, dtype=dtype, device=self.device), + variance=torch.ones(1, 1, dtype=dtype, device=self.device), + ) + ) + acq_func = ExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIs(wrapped_af.acq_func, acq_func) + # test forward + X = torch.rand(1, 1, dtype=dtype, device=self.device) + with torch.no_grad(): + wrapped_val = wrapped_af(X) + af_val = acq_func(X) + self.assertEqual(wrapped_val.item(), af_val.item()) + + # test X_pending + with self.assertRaises(ValueError): + self.assertIsNone(wrapped_af.X_pending) + with self.assertRaises(UnsupportedError): + wrapped_af.set_X_pending(X) + acq_func = qExpectedImprovement(model=mm, best_f=-1.0) + wrapped_af = DummyWrapper(acq_function=acq_func) + self.assertIsNone(wrapped_af.X_pending) + wrapped_af.set_X_pending(X) + self.assertTrue(torch.equal(X, wrapped_af.X_pending)) + self.assertTrue(torch.equal(X, acq_func.X_pending)) + wrapped_af.set_X_pending(None) + self.assertIsNone(wrapped_af.X_pending) + self.assertIsNone(acq_func.X_pending)