
Pathwise Thompson sampling for ensemble models #2877

Open · wants to merge 1 commit into main
151 changes: 127 additions & 24 deletions botorch/acquisition/thompson_sampling.py
@@ -6,10 +6,16 @@
 
 import torch
 from botorch.acquisition.analytic import AcquisitionFunction
-from botorch.acquisition.objective import PosteriorTransform
+from botorch.acquisition.objective import (
+    IdentityMCObjective,
+    MCAcquisitionObjective,
+    PosteriorTransform,
+)
+from botorch.exceptions.errors import UnsupportedError
 from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.model import Model
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.transforms import t_batch_mode_transform
+from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
 from torch import Tensor
 

@@ -32,55 +38,152 @@ class PathwiseThompsonSampling(AcquisitionFunction):
     def __init__(
         self,
         model: Model,
+        objective: MCAcquisitionObjective | None = None,
         posterior_transform: PosteriorTransform | None = None,
     ) -> None:
         r"""Single-outcome TS.
 
+        If using a multi-output `model`, the acquisition function requires either an
+        `objective` or a `posterior_transform` that transforms the multi-output
+        posterior samples to single-output posterior samples.
+
         Args:
             model: A fitted GP model.
-            posterior_transform: A PosteriorTransform. If using a multi-output model,
-                a PosteriorTransform that transforms the multi-output posterior into a
-                single-output posterior is required.
+            objective: The MCAcquisitionObjective under which the samples are
+                evaluated. Defaults to `IdentityMCObjective()`.
+            posterior_transform: An optional PosteriorTransform.
         """
         if model._is_fully_bayesian:
             raise NotImplementedError(
                 "PathwiseThompsonSampling is not supported for fully Bayesian models",
             )
 
         super().__init__(model=model)
         self.batch_size: int | None = None
+        self.samples: GenericDeterministicModel | None = None
+        self.ensemble_indices: Tensor | None = None
+
+        # NOTE: This conditional block is copied from MCAcquisitionFunction; we should
+        # consider inheriting from it and e.g. getting the X_pending logic as well.
+        if objective is None and model.num_outputs != 1:
+            if posterior_transform is None:
+                raise UnsupportedError(
+                    "Must specify an objective or a posterior transform when using "
+                    "a multi-output model."
+                )
+            elif not posterior_transform.scalarize:
+                raise UnsupportedError(
+                    "If using a multi-output model without an objective, "
+                    "posterior_transform must scalarize the output."
+                )
+        if objective is None:
+            objective = IdentityMCObjective()
+        self.objective = objective
+        self.posterior_transform = posterior_transform
 
-    def redraw(self) -> None:
+    def redraw(self, batch_size: int) -> None:
+        sample_shape = (batch_size,)
         self.samples = get_matheron_path_model(
-            model=self.model, sample_shape=torch.Size([self.batch_size])
+            model=self.model, sample_shape=torch.Size(sample_shape)
         )
+        if is_ensemble(self.model):
+            # the ensembling dimension is assumed to be part of the batch shape
+            model_batch_shape = self.model.batch_shape
+            if len(model_batch_shape) > 1:
+                raise NotImplementedError(
+                    "Ensemble models with more than one ensemble dimension are not "
+                    "yet supported."
+                )
+            num_ensemble = model_batch_shape[0]
+            # ensemble_indices is cached here to ensure that the acquisition function
+            # becomes deterministic for the same input and can be optimized with
+            # L-BFGS. ensemble_indices is used in select_from_ensemble_models.
+            self.ensemble_indices = torch.randint(
+                0,
+                num_ensemble,
+                (*sample_shape, 1, self.model.num_outputs),
+            )
 
     @t_batch_mode_transform()
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate the pathwise posterior sample draws on the candidate set X.
 
         Args:
-            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
 
         Returns:
-            A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
-            evaluations on the posterior sample draws.
+            A `batch_shape`-dim tensor of evaluations on the posterior sample draws,
+            where the samples are summed over the q-batch dimension.
         """
-        batch_size = X.shape[-2]
-        q_dim = -2
+        objective_values = self._pathwise_forward(X)  # batch_shape x q
+        # NOTE: The current implementation sums over the q-batch dimension, which
+        # means that we are optimizing the sum of independent Thompson samples. In
+        # the future, we can leverage *batched* L-BFGS optimization, rather than
+        # summing over the q dimension, which will guarantee descent steps for all
+        # members of the batch through batch-member-specific learning rate selection.
+        return objective_values.sum(-1)  # batch_shape
+
+    def _pathwise_forward(self, X: Tensor) -> Tensor:
+        """Evaluate the pathwise posterior sample draws on the candidate set X.
+
+        Args:
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
+
+        Returns:
+            A `batch_shape x q`-dim tensor of evaluations on the posterior sample
+            draws.
+        """
+        batch_size = X.shape[-2]
         # batch_shape x q x 1 x d
         X = X.unsqueeze(-2)
-        if self.batch_size is None:
+        if self.samples is None:
             self.batch_size = batch_size
-            self.redraw()
-        elif self.batch_size != batch_size:
+            self.redraw(batch_size=batch_size)
+
+        if self.batch_size != batch_size:
             raise ValueError(
                 BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
             )
+        # batch_shape x q [x num_ensembles] x 1 x m
+        posterior_values = self.samples(X)
+        # batch_shape x q [x num_ensembles] x m
+        posterior_values = posterior_values.squeeze(-2)
 
-        # posterior_values.shape post-squeeze:
-        # batch_shape x q x m
-        posterior_values = self.samples(X).squeeze(-2)
-        # sum over batch dim and squeeze num_objectives dim (-1)
-        return posterior_values.sum(q_dim).squeeze(-1)
+        # batch_shape x q x m
+        posterior_values = self.select_from_ensemble_models(values=posterior_values)
+
+        if self.posterior_transform:
+            posterior_values = self.posterior_transform.evaluate(posterior_values)
+        # objective removes the `m` dimension
+        objective_values = self.objective(posterior_values)  # batch_shape x q
+        return objective_values
+
+    def select_from_ensemble_models(self, values: Tensor) -> Tensor:
+        """Subselect the value of a single, randomly chosen member of the ensemble
+        for each batch element of `values`, removing the ensemble dimension.
+
+        NOTE: 1) uses `self.model` and `is_ensemble` to determine whether or not an
+        ensembling dimension is present. 2) uses `self.ensemble_indices` to select
+        the value associated with a single member of the ensemble. `ensemble_indices`
+        contains uniformly randomly sampled indices for each element of the ensemble,
+        but is cached to make the evaluation of the acquisition function
+        deterministic.
+
+        Args:
+            values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
+
+        Returns:
+            A `batch_shape x num_draws x q x m`-dim Tensor where each element contains
+            a single sample from the ensemble, selected with `self.ensemble_indices`.
+        """
+        if not is_ensemble(self.model):
+            return values
+
+        ensemble_dim = -2
+        # `ensemble_indices` are fixed so that the acquisition function becomes
+        # deterministic for the same input and can be optimized with L-BFGS.
+        # ensemble indices have shape num_paths x 1 x m
+        self.ensemble_indices = self.ensemble_indices.to(device=values.device)
+        index = self.ensemble_indices
+        input_batch_shape = values.shape[:-3]
+        index = index.expand(*input_batch_shape, *index.shape)
+        # values is batch_shape x q x num_ensemble x m
+        values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
+        # remove the (subselected) ensemble dimension
+        return values_wo_ensemble.squeeze(ensemble_dim)
6 changes: 5 additions & 1 deletion botorch/acquisition/utils.py
@@ -575,7 +575,11 @@ def get_optimal_samples(
     else:
         sample_transform = None
 
-    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
+    paths = get_matheron_path_model(
+        model=model,
+        sample_shape=torch.Size([num_optima]),
+        ensemble_as_batch=True,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,
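A sketch of the path-model call that `get_optimal_samples` makes (names are from BoTorch's public API; the toy model is an assumption). With the new `ensemble_as_batch=True`, an ensemble model's ensemble dimension is exposed as a batch dimension of the returned path model, so `optimize_posterior_samples` can optimize each member's sample separately rather than mixing members into one output.

```python
import torch
from botorch.models import SingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

# Illustrative model; a non-ensemble GP is unaffected by ensemble_as_batch.
train_X = torch.rand(10, 3, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# Draw 4 Matheron-rule posterior sample paths, packaged as a deterministic
# model; get_optimal_samples now additionally passes ensemble_as_batch=True.
paths = get_matheron_path_model(model=model, sample_shape=torch.Size([4]))
X = torch.rand(5, 3, dtype=torch.float64)
samples = paths(X)  # sample_shape x n x m = 4 x 5 x 1
```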
27 changes: 25 additions & 2 deletions botorch/models/deterministic.py
@@ -64,7 +64,12 @@ class GenericDeterministicModel(DeterministicModel):
         >>> model = GenericDeterministicModel(f)
     """
 
-    def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
+    def __init__(
+        self,
+        f: Callable[[Tensor], Tensor],
+        num_outputs: int = 1,
+        batch_shape: torch.Size | None = None,
+    ) -> None:
         r"""
         Args:
             f: A callable mapping a `batch_shape x n x d`-dim input tensor `X`
@@ -75,6 +80,12 @@ def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
         super().__init__()
         self._f = f
         self._num_outputs = num_outputs
+        self._batch_shape = batch_shape
+
+    @property
+    def batch_shape(self) -> torch.Size | None:
+        r"""The batch shape of the model."""
+        return self._batch_shape
 
     def subset_output(self, idcs: list[int]) -> GenericDeterministicModel:
         r"""Subset the model along the output dimension.
@@ -100,7 +111,19 @@ def forward(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x m`-dimensional output tensor.
         """
-        return self._f(X)
+        Y = self._f(X)
+        batch_shape = Y.shape[:-2]
+        # allow for the old behavior of not specifying the batch_shape
+        if self.batch_shape is not None:
+            try:
+                torch.broadcast_shapes(self.batch_shape, batch_shape)
+            except RuntimeError:
+                raise ValueError(
+                    "GenericDeterministicModel was initialized with batch_shape="
+                    f"{self.batch_shape} but the output of f has a batch_shape="
+                    f"{batch_shape} that is not broadcastable with it."
+                )
+        return Y
 
 
 class AffineDeterministicModel(DeterministicModel):
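A sketch of the new `batch_shape` validation, assuming this PR's `batch_shape` argument is applied (the toy callables and shapes are illustrative): declaring the expected batch shape lets a mismatched `f` fail immediately with the `ValueError` above rather than surfacing downstream.

```python
import torch
from botorch.models.deterministic import GenericDeterministicModel

# A deterministic "ensemble" of 3 linear models over 2-dim inputs.
weights = torch.randn(3, 2, 1, dtype=torch.float64)  # num_ensemble x d x m
model = GenericDeterministicModel(
    f=lambda X: X @ weights,  # (n x d) @ (3 x d x m) broadcasts to 3 x n x m
    num_outputs=1,
    batch_shape=torch.Size([3]),
)
X = torch.rand(8, 2, dtype=torch.float64)
Y = model(X)  # 3 x 8 x 1; output batch shape (3,) matches the declared one

# A callable whose output batch shape cannot broadcast with the declared
# batch_shape now raises a ValueError at call time.
bad = GenericDeterministicModel(
    f=lambda X: torch.rand(4, X.shape[-2], 1, dtype=X.dtype),  # batch shape (4,)
    batch_shape=torch.Size([3]),
)
try:
    bad(X)
except ValueError as e:
    print(e)  # (3,) and (4,) are not broadcastable
```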
71 changes: 69 additions & 2 deletions botorch/sampling/pathwise/paths.py
@@ -6,7 +6,7 @@
 
 from __future__ import annotations
 
-from abc import ABC
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from typing import Any
 
@@ -24,6 +24,16 @@
 class SamplePath(ABC, TransformedModuleMixin, Module):
     r"""Abstract base class for Botorch sample paths."""
 
+    @abstractmethod
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        pass  # pragma: no cover
+
 
 class PathDict(SamplePath):
     r"""A dictionary of SamplePaths."""
@@ -84,6 +94,16 @@ def __getitem__(self, key: str) -> SamplePath:
     def __setitem__(self, key: str, val: SamplePath) -> None:
         self.paths[key] = val
 
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        for path in self.paths.values():
+            path.set_ensemble_as_batch(ensemble_as_batch)
+
 
 class PathList(SamplePath):
     r"""A list of SamplePaths."""
@@ -136,6 +156,16 @@ def __getitem__(self, key: int) -> SamplePath:
     def __setitem__(self, key: int, val: SamplePath) -> None:
         self.paths[key] = val
 
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        for path in self.paths:
+            path.set_ensemble_as_batch(ensemble_as_batch)
+
 
 class GeneralizedLinearPath(SamplePath):
     r"""A sample path in the form of a generalized linear model."""
@@ -147,6 +177,8 @@ def __init__(
         bias_module: Module | None = None,
         input_transform: TInputTransform | None = None,
         output_transform: TOutputTransform | None = None,
+        is_ensemble: bool = False,
+        ensemble_as_batch: bool = False,
     ):
         r"""Initializes a GeneralizedLinearPath instance.
 
@@ -157,10 +189,17 @@
 
         Args:
             feature_map: A map used to featurize the module's inputs.
-            weight: A tensor of weights used to combine input features.
+            weight: A tensor of weights used to combine input features. When generated
+                with `draw_kernel_feature_paths`, `weight` is a Tensor with the shape
+                `sample_shape x batch_shape x num_outputs`.
             bias_module: An optional module used to define additive offsets.
             input_transform: An optional input transform for the module.
             output_transform: An optional output transform for the module.
+            is_ensemble: Whether the associated model is an ensemble model or not.
+            ensemble_as_batch: Whether the ensemble dimension is added as a batch
+                dimension or not. If `True`, the ensemble dimension is treated as a
+                batch dimension, which allows for the joint optimization of all
+                members of the ensemble.
         """
         super().__init__()
         self.feature_map = feature_map
@@ -170,8 +209,36 @@ def __init__(
         self.bias_module = bias_module
         self.input_transform = input_transform
         self.output_transform = output_transform
+        self.is_ensemble = is_ensemble
+        self.ensemble_as_batch = ensemble_as_batch
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
+        """Evaluates the path.
+
+        Args:
+            x: The input tensor of shape `batch_shape x [num_ensemble x] q x d`,
+                where `num_ensemble` is the number of ensemble members and must be
+                included *only* if `is_ensemble=True` and `ensemble_as_batch=True`.
+            kwargs: Additional keyword arguments passed to the feature map.
+
+        Returns:
+            A tensor of shape `batch_shape x [num_ensemble x] q x m`, where `m` is
+                the number of outputs; `num_ensemble` is included whenever
+                `is_ensemble` is `True`, regardless of the value of
+                `ensemble_as_batch`.
+        """
+        if self.is_ensemble and not self.ensemble_as_batch:
+            # insert the ensemble dimension directly before the (q, d) dimensions,
+            # i.e. to the right of any other batch dimensions.
+            x = x.unsqueeze(-3)
         feat = self.feature_map(x, **kwargs)
         out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1)
         return out if self.bias_module is None else out + self.bias_module(x)
+
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        self.ensemble_as_batch = ensemble_as_batch
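A sketch of the two input conventions for `GeneralizedLinearPath.forward`. The random-feature map below is a stand-in assumption (real paths are built by `draw_kernel_feature_paths`): with `ensemble_as_batch=False` the path unsqueezes the ensemble dimension itself, while after `set_ensemble_as_batch(True)` the caller supplies it as a batch dimension.

```python
import torch
from botorch.sampling.pathwise.paths import GeneralizedLinearPath

num_ensemble, num_features, d, q = 3, 16, 2, 5
proj = torch.randn(d, num_features, dtype=torch.float64)

# Stand-in feature map: `... x q x d` inputs -> `... x q x num_features`.
def feature_map(x):
    return torch.cos(x @ proj)

# One weight vector per ensemble member: num_ensemble x num_features.
weight = torch.randn(num_ensemble, num_features, dtype=torch.float64)
path = GeneralizedLinearPath(
    feature_map=feature_map, weight=weight, is_ensemble=True
)

x = torch.rand(q, d, dtype=torch.float64)  # no ensemble dimension in the input
print(path(x).shape)  # torch.Size([3, 5]): the ensemble dim is unsqueezed internally

path.set_ensemble_as_batch(True)
x_batched = torch.rand(num_ensemble, q, d, dtype=torch.float64)
print(path(x_batched).shape)  # torch.Size([3, 5]): caller provides the ensemble dim
```

Treating the ensemble as a batch dimension is what lets `get_optimal_samples` optimize all ensemble members' sample paths jointly, as noted in the `ensemble_as_batch` docstring above.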