Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

acquisition function wrapper #1532

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 7 additions & 19 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AquisitionFunctions to fix a subset of features.

Example:
Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d
Expand Down Expand Up @@ -126,24 +125,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
def set_X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.

The usage is similar to:
Expand All @@ -161,29 +160,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Optional[Tensor]:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
15 changes: 10 additions & 5 deletions botorch/acquisition/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models import ModelListGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
Expand All @@ -25,7 +27,7 @@
from torch.nn import Module


class ProximalAcquisitionFunction(AcquisitionFunction):
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function. The acquisition function is
weighted via a squared exponential centered at the last training point,
Expand Down Expand Up @@ -70,17 +72,14 @@ def __init__(
beta: If not None, apply a softplus transform to the base acquisition
function, allows negative base acquisition function values.
"""
Module.__init__(self)

self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
model = self.acq_func.model

if hasattr(acq_function, "X_pending"):
if acq_function.X_pending is not None:
raise UnsupportedError(
"Proximal acquisition function requires `X_pending` to be None."
)
self.X_pending = acq_function.X_pending

self.register_buffer("proximal_weights", proximal_weights)
self.register_buffer(
Expand All @@ -91,6 +90,12 @@ def __init__(

_validate_model(model, proximal_weights)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
raise UnsupportedError(
"Proximal acquisition function does not support `X_pending`."
)

@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate base acquisition function with proximal weighting.
Expand Down
55 changes: 55 additions & 0 deletions botorch/acquisition/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module


class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
r"""Abstract acquisition wrapper."""

def __init__(self, acq_function: AcquisitionFunction) -> None:
Module.__init__(self)
self.acq_func = acq_function

@property
def X_pending(self) -> Optional[Tensor]:
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
self.acq_func.set_X_pending(X_pending)

@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the wrapped acquisition function on the candidate set X.

Args:
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
design points each.

Returns:
A `(b)`-dim Tensor of acquisition function values at the given
design points `X`.
"""
pass # pragma: no cover
9 changes: 7 additions & 2 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Analytic Acquisition Function API
.. autoclass:: AnalyticAcquisitionFunction
:members:

Acquisition Function Wrapper API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.wrapper
:members:

Cached Cholesky Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.cached_cholesky
Expand Down Expand Up @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.analytic
:members:
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction

Multi-Objective Joint Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
Expand All @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
:members:

Multi-Objective Predictive Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/test_fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fixed_features(self):
qEI_ff.set_X_pending(X_pending[..., :-1])
self.assertAllClose(qEI.X_pending, X_pending)
# test setting to None
qEI_ff.X_pending = None
qEI_ff.set_X_pending(None)
self.assertIsNone(qEI_ff.X_pending)

# test gradient
Expand Down
8 changes: 7 additions & 1 deletion test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,15 @@ def test_proximal(self):

# test for x_pending points
pending_acq = DummyAcquisitionFunction(model)
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
pending_acq.set_X_pending(X_pending)
with self.assertRaises(UnsupportedError):
ProximalAcquisitionFunction(pending_acq, proximal_weights)
# test setting pending points
pending_acq.set_X_pending(None)
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
with self.assertRaises(UnsupportedError):
af.set_X_pending(X_pending)

# test model with multi-batch training inputs
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)
Expand Down
52 changes: 52 additions & 0 deletions test/acquisition/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


class DummyWrapper(AbstractAcquisitionFunctionWrapper):
def forward(self, X):
return self.acq_func(X)


class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase):
def test_abstract_acquisition_function_wrapper(self):
for dtype in (torch.float, torch.double):
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, dtype=dtype, device=self.device),
variance=torch.ones(1, 1, dtype=dtype, device=self.device),
)
)
acq_func = ExpectedImprovement(model=mm, best_f=-1.0)
wrapped_af = DummyWrapper(acq_function=acq_func)
self.assertIs(wrapped_af.acq_func, acq_func)
# test forward
X = torch.rand(1, 1, dtype=dtype, device=self.device)
with torch.no_grad():
wrapped_val = wrapped_af(X)
af_val = acq_func(X)
self.assertEqual(wrapped_val.item(), af_val.item())

# test X_pending
with self.assertRaises(ValueError):
self.assertIsNone(wrapped_af.X_pending)
with self.assertRaises(UnsupportedError):
wrapped_af.set_X_pending(X)
acq_func = qExpectedImprovement(model=mm, best_f=-1.0)
wrapped_af = DummyWrapper(acq_function=acq_func)
self.assertIsNone(wrapped_af.X_pending)
wrapped_af.set_X_pending(X)
self.assertTrue(torch.equal(X, wrapped_af.X_pending))
self.assertTrue(torch.equal(X, acq_func.X_pending))
wrapped_af.set_X_pending(None)
self.assertIsNone(wrapped_af.X_pending)
self.assertIsNone(acq_func.X_pending)