Skip to content

Commit

Permalink
acquisition function wrapper (pytorch#1532)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1532

Add a wrapper for modifying inputs/outputs. This is useful for not only probabilistic reparameterization, but will also simplify other integrated AFs (e.g. MCMC) as well as fixed feature AFs and things like prior-guided AFs

Differential Revision: D41629186

fbshipit-source-id: c52722b2946207e219ad5f49e6fa314706cdd953
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 8, 2023
1 parent 0a0c2fb commit 3bd7129
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 47 deletions.
26 changes: 7 additions & 19 deletions botorch/acquisition/fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor
from torch.nn import Module


class FixedFeatureAcquisitionFunction(AcquisitionFunction):
class FixedFeatureAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AquisitionFunctions to fix a subset of features.
Example:
Expand Down Expand Up @@ -56,8 +56,7 @@ def __init__(
combination of `Tensor`s and numbers which can be broadcasted
to form a tensor with trailing dimension size of `d_f`.
"""
Module.__init__(self)
self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
dtype = torch.float
device = torch.device("cpu")
self.d = d
Expand Down Expand Up @@ -126,24 +125,13 @@ def forward(self, X: Tensor):
X_full = self._construct_X_full(X)
return self.acq_func(X_full)

@property
def X_pending(self):
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

@X_pending.setter
def X_pending(self, X_pending: Optional[Tensor]):
def set_X_pending(self, X_pending: Optional[Tensor]):
r"""Sets the `X_pending` of the base acquisition function."""
if X_pending is not None:
self.acq_func.X_pending = self._construct_X_full(X_pending)
full_X_pending = self._construct_X_full(X_pending)
else:
self.acq_func.X_pending = X_pending
full_X_pending = None
self.acq_func.set_X_pending(full_X_pending)

def _construct_X_full(self, X: Tensor) -> Tensor:
r"""Constructs the full input for the base acquisition function.
Expand Down
24 changes: 5 additions & 19 deletions botorch/acquisition/penalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@

import torch
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.objective import GenericMCObjective
from botorch.exceptions import UnsupportedError
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from torch import Tensor


Expand Down Expand Up @@ -139,7 +138,7 @@ def forward(self, X: Tensor) -> Tensor:
return regularization_term


class PenalizedAcquisitionFunction(AcquisitionFunction):
class PenalizedAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
r"""Single-outcome acquisition function regularized by the given penalty.
The usage is similar to:
Expand All @@ -161,29 +160,16 @@ def __init__(
penalty_func: The regularization function.
regularization_parameter: Regularization parameter used in optimization.
"""
super().__init__(model=raw_acqf.model)
self.raw_acqf = raw_acqf
AcquisitionFunction.__init__(self, model=raw_acqf.model)
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=raw_acqf)
self.penalty_func = penalty_func
self.regularization_parameter = regularization_parameter

def forward(self, X: Tensor) -> Tensor:
raw_value = self.raw_acqf(X=X)
raw_value = self.acq_func(X=X)
penalty_term = self.penalty_func(X)
return raw_value - self.regularization_parameter * penalty_term

@property
def X_pending(self) -> Optional[Tensor]:
return self.raw_acqf.X_pending

def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
self.raw_acqf.set_X_pending(X_pending=X_pending)
else:
raise UnsupportedError(
"The raw acquisition function is Analytic and does not account "
"for X_pending yet."
)


def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
r"""Computes the group lasso regularization function for the given point.
Expand Down
15 changes: 10 additions & 5 deletions botorch/acquisition/proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.models import ModelListGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel
Expand All @@ -25,7 +27,7 @@
from torch.nn import Module


class ProximalAcquisitionFunction(AcquisitionFunction):
class ProximalAcquisitionFunction(AbstractAcquisitionFunctionWrapper):
"""A wrapper around AcquisitionFunctions to add proximal weighting of the
acquisition function. The acquisition function is
weighted via a squared exponential centered at the last training point,
Expand Down Expand Up @@ -70,17 +72,14 @@ def __init__(
beta: If not None, apply a softplus transform to the base acquisition
function, allows negative base acquisition function values.
"""
Module.__init__(self)

self.acq_func = acq_function
AbstractAcquisitionFunctionWrapper.__init__(self, acq_function=acq_function)
model = self.acq_func.model

if hasattr(acq_function, "X_pending"):
if acq_function.X_pending is not None:
raise UnsupportedError(
"Proximal acquisition function requires `X_pending` to be None."
)
self.X_pending = acq_function.X_pending

self.register_buffer("proximal_weights", proximal_weights)
self.register_buffer(
Expand All @@ -91,6 +90,12 @@ def __init__(

_validate_model(model, proximal_weights)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
raise UnsupportedError(
"Proximal acquisition function does not support `X_pending`."
)

@t_batch_mode_transform(expected_q=1, assert_output_shape=False)
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate base acquisition function with proximal weighting.
Expand Down
55 changes: 55 additions & 0 deletions botorch/acquisition/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

r"""
A wrapper classes around AcquisitionFunctions to modify inputs and outputs.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional

from botorch.acquisition.acquisition import AcquisitionFunction
from torch import Tensor
from torch.nn import Module


class AbstractAcquisitionFunctionWrapper(AcquisitionFunction, ABC):
r"""Abstract acquisition wrapper."""

def __init__(self, acq_function: AcquisitionFunction) -> None:
Module.__init__(self)
self.acq_func = acq_function

@property
def X_pending(self) -> Optional[Tensor]:
r"""Return the `X_pending` of the base acquisition function."""
try:
return self.acq_func.X_pending
except (ValueError, AttributeError):
raise ValueError(
f"Base acquisition function {type(self.acq_func).__name__} "
"does not have an `X_pending` attribute."
)

def set_X_pending(self, X_pending: Optional[Tensor]) -> None:
r"""Sets the `X_pending` of the base acquisition function."""
self.acq_func.set_X_pending(X_pending)

@abstractmethod
def forward(self, X: Tensor) -> Tensor:
r"""Evaluate the wrapped acquisition function on the candidate set X.
Args:
X: A `(b) x q x d`-dim Tensor of `(b)` t-batches with `q` `d`-dim
design points each.
Returns:
A `(b)`-dim Tensor of acquisition function values at the given
design points `X`.
"""
pass # pragma: no cover
9 changes: 7 additions & 2 deletions sphinx/source/acquisition.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ Analytic Acquisition Function API
.. autoclass:: AnalyticAcquisitionFunction
:members:

Acquisition Function Wrapper API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.wrapper
:members:

Cached Cholesky Acquisition Function API
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.cached_cholesky
Expand Down Expand Up @@ -65,7 +70,7 @@ Multi-Objective Analytic Acquisition Functions
.. automodule:: botorch.acquisition.multi_objective.analytic
:members:
:exclude-members: MultiObjectiveAnalyticAcquisitionFunction

Multi-Objective Joint Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.joint_entropy_search
Expand All @@ -86,7 +91,7 @@ Multi-Objective Multi-Fidelity Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.multi_fidelity
:members:

Multi-Objective Predictive Entropy Search Acquisition Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.acquisition.multi_objective.predictive_entropy_search
Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/test_fixed_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_fixed_features(self):
qEI_ff.set_X_pending(X_pending[..., :-1])
self.assertAllClose(qEI.X_pending, X_pending)
# test setting to None
qEI_ff.X_pending = None
qEI_ff.set_X_pending(None)
self.assertIsNone(qEI_ff.X_pending)

# test gradient
Expand Down
8 changes: 7 additions & 1 deletion test/acquisition/test_proximal.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,15 @@ def test_proximal(self):

# test for x_pending points
pending_acq = DummyAcquisitionFunction(model)
pending_acq.set_X_pending(torch.rand(3, 3, device=self.device, dtype=dtype))
X_pending = torch.rand(3, 3, device=self.device, dtype=dtype)
pending_acq.set_X_pending(X_pending)
with self.assertRaises(UnsupportedError):
ProximalAcquisitionFunction(pending_acq, proximal_weights)
# test setting pending points
pending_acq.set_X_pending(None)
af = ProximalAcquisitionFunction(pending_acq, proximal_weights)
with self.assertRaises(UnsupportedError):
af.set_X_pending(X_pending)

# test model with multi-batch training inputs
train_X = torch.rand(5, 2, 3, device=self.device, dtype=dtype)
Expand Down
52 changes: 52 additions & 0 deletions test/acquisition/test_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.wrapper import AbstractAcquisitionFunctionWrapper
from botorch.exceptions.errors import UnsupportedError
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


class DummyWrapper(AbstractAcquisitionFunctionWrapper):
def forward(self, X):
return self.acq_func(X)


class TestAbstractAcquisitionFunctionWrapper(BotorchTestCase):
def test_abstract_acquisition_function_wrapper(self):
for dtype in (torch.float, torch.double):
mm = MockModel(
MockPosterior(
mean=torch.rand(1, 1, dtype=dtype, device=self.device),
variance=torch.ones(1, 1, dtype=dtype, device=self.device),
)
)
acq_func = ExpectedImprovement(model=mm, best_f=-1.0)
wrapped_af = DummyWrapper(acq_function=acq_func)
self.assertIs(wrapped_af.acq_func, acq_func)
# test forward
X = torch.rand(1, 1, dtype=dtype, device=self.device)
with torch.no_grad():
wrapped_val = wrapped_af(X)
af_val = acq_func(X)
self.assertEqual(wrapped_val.item(), af_val.item())

# test X_pending
with self.assertRaises(ValueError):
self.assertIsNone(wrapped_af.X_pending)
with self.assertRaises(UnsupportedError):
wrapped_af.set_X_pending(X)
acq_func = qExpectedImprovement(model=mm, best_f=-1.0)
wrapped_af = DummyWrapper(acq_function=acq_func)
self.assertIsNone(wrapped_af.X_pending)
wrapped_af.set_X_pending(X)
self.assertTrue(torch.equal(X, wrapped_af.X_pending))
self.assertTrue(torch.equal(X, acq_func.X_pending))
wrapped_af.set_X_pending(None)
self.assertIsNone(wrapped_af.X_pending)
self.assertIsNone(acq_func.X_pending)

0 comments on commit 3bd7129

Please sign in to comment.