
Commit 842cbfb

SebastianAment authored and facebook-github-bot committed
Pathwise Thompson sampling for ensemble models (#2877)
Summary:
Pull Request resolved: #2877

This commit adds support for pathwise Thompson sampling for ensemble models, including fully Bayesian SAAS models.

Differential Revision: D75990595
1 parent 6970f9b commit 842cbfb
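As an illustrative sketch only (not part of this diff; the data and the SAAS fitting settings below are placeholder assumptions), the new support means an acquisition function like the following no longer raises NotImplementedError for fully Bayesian models:

    import torch
    from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
    from botorch.fit import fit_fully_bayesian_model_nuts
    from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
    from botorch.optim import optimize_acqf

    # toy data: 20 points in 3 dimensions with a simple additive objective
    train_X = torch.rand(20, 3, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)

    model = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
    fit_fully_bayesian_model_nuts(model, warmup_steps=64, num_samples=128, thinning=8)

    # before this commit, constructing the acquisition function raised for fully
    # Bayesian (ensemble) models; now the MCMC dimension is handled per pathwise draw
    acqf = PathwiseThompsonSampling(model=model)
    bounds = torch.tensor([[0.0] * 3, [1.0] * 3], dtype=torch.double)
    candidates, values = optimize_acqf(
        acqf, bounds=bounds, q=2, num_restarts=4, raw_samples=64
    )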

File tree

14 files changed: +395, -111 lines


botorch/acquisition/thompson_sampling.py

Lines changed: 109 additions & 22 deletions
@@ -6,10 +6,16 @@

 import torch
 from botorch.acquisition.analytic import AcquisitionFunction
-from botorch.acquisition.objective import PosteriorTransform
+from botorch.acquisition.objective import (
+    IdentityMCObjective,
+    MCAcquisitionObjective,
+    PosteriorTransform,
+)
+from botorch.exceptions.errors import UnsupportedError
+from botorch.models.deterministic import GenericDeterministicModel
 from botorch.models.model import Model
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.transforms import t_batch_mode_transform
+from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
 from torch import Tensor


@@ -32,7 +38,9 @@ class PathwiseThompsonSampling(AcquisitionFunction):
     def __init__(
         self,
         model: Model,
+        objective: MCAcquisitionObjective | None = None,
         posterior_transform: PosteriorTransform | None = None,
+        samples: GenericDeterministicModel | None = None,
     ) -> None:
         r"""Single-outcome TS.

@@ -41,46 +49,125 @@ def __init__(
             posterior_transform: A PosteriorTransform. If using a multi-output model,
                 a PosteriorTransform that transforms the multi-output posterior into a
                 single-output posterior is required.
+            samples: A GenericDeterministicModel that evaluates a set of posterior
+                sample paths.
         """
-        if model._is_fully_bayesian:
-            raise NotImplementedError(
-                "PathwiseThompsonSampling is not supported for fully Bayesian models",
-            )

         super().__init__(model=model)
-        self.batch_size: int | None = None
-
-    def redraw(self) -> None:
+        self.batch_size: int | None = None if samples is None else samples.batch_shape
+
+        # NOTE: This conditional block is copied from MCAcquisitionFunction; we should
+        # consider inheriting from it and e.g. getting the X_pending logic as well.
+        if objective is None and model.num_outputs != 1:
+            if posterior_transform is None:
+                raise UnsupportedError(
+                    "Must specify an objective or a posterior transform when using "
+                    "a multi-output model."
+                )
+            elif not posterior_transform.scalarize:
+                raise UnsupportedError(
+                    "If using a multi-output model without an objective, "
+                    "posterior_transform must scalarize the output."
+                )
+        if objective is None:
+            objective = IdentityMCObjective()
+        self.objective = objective
+        self.posterior_transform = posterior_transform
+        self.samples: GenericDeterministicModel | None = samples
+
+    def redraw(self, batch_size: int) -> None:
+        sample_shape = (batch_size,)
         self.samples = get_matheron_path_model(
-            model=self.model, sample_shape=torch.Size([self.batch_size])
+            model=self.model, sample_shape=torch.Size(sample_shape)
        )
+        if is_ensemble(self.model):
+            # The ensembling dimension is assumed to be part of the batch shape.
+            # Could add a dedicated property to keep track of the ensembling dimension,
+            # i.e. generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP.
+            model_batch_shape = self.model.batch_shape
+            if len(model_batch_shape) > 1:
+                raise NotImplementedError(
+                    "Ensemble models with more than one ensemble dimension are not "
+                    "yet supported."
+                )
+            num_ensemble = model_batch_shape[0]
+            self.ensemble_indices = torch.randint(
+                0,
+                num_ensemble,
+                (*sample_shape, 1, self.model.num_outputs),
+            )

     @t_batch_mode_transform()
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate the pathwise posterior sample draws on the candidate set X.

         Args:
-            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.

         Returns:
-            A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
-            evaluations on the posterior sample draws.
+            A `batch_shape [x m]`-dim tensor of evaluations on the posterior sample
+            draws, where `m` is the number of outputs of the model.
         """
-        batch_size = X.shape[-2]
-        q_dim = -2
+        objective_values = self._pathwise_forward(X)
+        # NOTE: could leverage batched L-BFGS computation instead of summing in the future.
+        # Sum over the q dimension (-1):
+        acqf_vals = objective_values.sum(-1)  # batch_shape
+        return acqf_vals

+    def _pathwise_forward(self, X: Tensor) -> Tensor:
+        batch_size = X.shape[-2]
         # batch_shape x q x 1 x d
         X = X.unsqueeze(-2)
-        if self.batch_size is None:
+        if self.samples is None:
             self.batch_size = batch_size
-            self.redraw()
-        elif self.batch_size != batch_size:
+            self.redraw(batch_size=batch_size)
+
+        if self.batch_size != batch_size:
             raise ValueError(
                 BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
             )
+        # batch_shape x q [x num_ensembles] x 1 x m
+        posterior_values = self.samples(X)
+        # batch_shape x q [x num_ensembles] x m
+        posterior_values = posterior_values.squeeze(-2)

-        # posterior_values.shape post-squeeze:
         # batch_shape x q x m
-        posterior_values = self.samples(X).squeeze(-2)
-        # sum over batch dim and squeeze num_objectives dim (-1)
-        return posterior_values.sum(q_dim).squeeze(-1)
+        posterior_values = self.select_from_ensemble_models(values=posterior_values)
+
+        if self.posterior_transform:
+            posterior_values = self.posterior_transform.evaluate(posterior_values)
+        # A caveat here is that we could still have an `m` dimension; ideally that
+        # would be packed into a batch dimension instead.
+        # The objective removes the `m` dimension:
+        objective_values = self.objective(posterior_values)  # batch_shape x q
+        return objective_values
+
+    def select_from_ensemble_models(self, values: Tensor) -> Tensor:
+        """Subselect the value associated with a single, randomly chosen member of the
+        ensemble for each element of the sample batch, so that the result no longer has
+        an ensemble dimension. NOTE: uses `self.model` and `is_ensemble` to determine
+        whether or not an ensembling dimension is present.
+
+        Args:
+            values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
+
+        Returns:
+            A `batch_shape x num_draws x q x m`-dim tensor where each element was chosen
+            independently at random from the ensemble dimension.
+        """
+        if not is_ensemble(self.model):
+            return values
+
+        ensemble_dim = -2
+        # `ensemble_indices` are fixed so that the acquisition function becomes
+        # deterministic for the same input and can be optimized with L-BFGS.
+        # ensemble indices have shape num_paths x 1 x m
+        self.ensemble_indices = self.ensemble_indices.to(device=values.device)
+        index = self.ensemble_indices
+        input_batch_shape = values.shape[:-3]
+        index = index.expand(*input_batch_shape, *index.shape)
+        # values is batch_shape x q x num_ensemble x m
+        values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
+        return values_wo_ensemble.squeeze(
+            ensemble_dim
+        )  # removing the ensemble dimension
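A plain-tensor sketch (shapes chosen for illustration, not taken from the diff) of the gather-based subselection that `select_from_ensemble_models` performs: one ensemble member is picked per pathwise draw, and the indices are drawn once so repeated evaluations of the acquisition function stay deterministic:

    import torch

    b, q, num_ensemble, m = 5, 3, 4, 1            # t-batch, candidates, MCMC draws, outputs
    values = torch.randn(b, q, num_ensemble, m)   # pathwise values for every ensemble member

    # drawn once (as in redraw) and then reused, mirroring self.ensemble_indices
    ensemble_indices = torch.randint(0, num_ensemble, (q, 1, m))

    index = ensemble_indices.expand(b, q, 1, m)            # broadcast the fixed indices over the t-batch
    selected = torch.gather(values, dim=-2, index=index)   # b x q x 1 x m
    selected = selected.squeeze(-2)                        # b x q x m, ensemble dimension removed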

botorch/acquisition/utils.py

Lines changed: 5 additions & 1 deletion
@@ -575,7 +575,11 @@ def get_optimal_samples(
     else:
         sample_transform = None

-    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
+    paths = get_matheron_path_model(
+        model=model,
+        sample_shape=torch.Size([num_optima]),
+        ensemble_as_batch=True,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,

botorch/models/deterministic.py

Lines changed: 25 additions & 2 deletions
@@ -64,7 +64,12 @@ class GenericDeterministicModel(DeterministicModel):
     >>> model = GenericDeterministicModel(f)
     """

-    def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
+    def __init__(
+        self,
+        f: Callable[[Tensor], Tensor],
+        num_outputs: int = 1,
+        batch_shape: torch.Size | None = None,
+    ) -> None:
         r"""
         Args:
             f: A callable mapping a `batch_shape x n x d`-dim input tensor `X`
@@ -75,6 +80,12 @@ def __init__(self, f: Callable[[Tensor], Tensor], num_outputs: int = 1) -> None:
         super().__init__()
         self._f = f
         self._num_outputs = num_outputs
+        self._batch_shape = batch_shape
+
+    @property
+    def batch_shape(self) -> torch.Size | None:
+        r"""The batch shape of the model."""
+        return self._batch_shape

     def subset_output(self, idcs: list[int]) -> GenericDeterministicModel:
         r"""Subset the model along the output dimension.
@@ -100,7 +111,19 @@ def forward(self, X: Tensor) -> Tensor:
         Returns:
             A `batch_shape x n x m`-dimensional output tensor.
         """
-        return self._f(X)
+        Y = self._f(X)
+        batch_shape = Y.shape[:-2]
+        # allowing for old behavior of not specifying the batch_shape
+        if self.batch_shape is not None:
+            try:
+                torch.broadcast_shapes(self.batch_shape, batch_shape)
+            except RuntimeError:
+                raise ValueError(
+                    "GenericDeterministicModel was initialized with "
+                    f"batch_shape={self.batch_shape} but the output of f has a "
+                    f"batch_shape of {batch_shape}, which is not broadcastable with it."
+                )
+        return Y


 class AffineDeterministicModel(DeterministicModel):
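A small sketch of the new broadcast check in GenericDeterministicModel (the callable and shapes below are illustrative assumptions, not from the diff):

    import torch
    from botorch.models.deterministic import GenericDeterministicModel

    # f maps `batch_shape x n x d` to `batch_shape x n x 1`; here the "paths" are faked
    f = lambda X: X.sum(dim=-1, keepdim=True)
    model = GenericDeterministicModel(f, num_outputs=1, batch_shape=torch.Size([8]))

    Y = model(torch.rand(8, 5, 3))   # output batch_shape (8,) broadcasts with (8,): OK
    # model(torch.rand(7, 5, 3))     # output batch_shape (7,) does not broadcast: ValueError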

botorch/sampling/pathwise/paths.py

Lines changed: 53 additions & 1 deletion
@@ -6,7 +6,7 @@

 from __future__ import annotations

-from abc import ABC
+from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from typing import Any

@@ -24,6 +24,16 @@
 class SamplePath(ABC, TransformedModuleMixin, Module):
     r"""Abstract base class for Botorch sample paths."""

+    @abstractmethod
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        pass  # pragma: no cover
+

 class PathDict(SamplePath):
     r"""A dictionary of SamplePaths."""
@@ -84,6 +94,16 @@ def __getitem__(self, key: str) -> SamplePath:
     def __setitem__(self, key: str, val: SamplePath) -> None:
         self.paths[key] = val

+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        for path in self.paths.values():
+            path.set_ensemble_as_batch(ensemble_as_batch)
+

 class PathList(SamplePath):
     r"""A list of SamplePaths."""
@@ -136,6 +156,16 @@ def __getitem__(self, key: int) -> SamplePath:
     def __setitem__(self, key: int, val: SamplePath) -> None:
         self.paths[key] = val

+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        for path in self.paths:
+            path.set_ensemble_as_batch(ensemble_as_batch)
+

 class GeneralizedLinearPath(SamplePath):
     r"""A sample path in the form of a generalized linear model."""
@@ -147,6 +177,8 @@ def __init__(
         bias_module: Module | None = None,
         input_transform: TInputTransform | None = None,
         output_transform: TOutputTransform | None = None,
+        is_ensemble: bool = False,
+        ensemble_as_batch: bool = False,
     ):
         r"""Initializes a GeneralizedLinearPath instance.

@@ -161,6 +193,11 @@ def __init__(
             bias_module: An optional module used to define additive offsets.
             input_transform: An optional input transform for the module.
             output_transform: An optional output transform for the module.
+            is_ensemble: Whether the associated model is an ensemble model or not.
+            ensemble_as_batch: Whether the ensemble dimension is added as a batch
+                dimension or not. If `True`, the ensemble dimension is treated as a
+                batch dimension, which allows for the joint optimization of all members
+                of the ensemble.
         """
         super().__init__()
         self.feature_map = feature_map
@@ -170,8 +207,23 @@ def __init__(
         self.bias_module = bias_module
         self.input_transform = input_transform
         self.output_transform = output_transform
+        self.is_ensemble = is_ensemble
+        self.ensemble_as_batch = ensemble_as_batch

     def forward(self, x: Tensor, **kwargs) -> Tensor:
+        if self.is_ensemble and not self.ensemble_as_batch:
+            # assuming that the ensembling dimension is added after (n, d), but
+            # before the other batch dimensions, starting from the left
+            x = x.unsqueeze(-3)
         feat = self.feature_map(x, **kwargs)
         out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1)
         return out if self.bias_module is None else out + self.bias_module(x)
+
+    def set_ensemble_as_batch(self, ensemble_as_batch: bool) -> None:
+        """Sets whether the ensemble dimension is considered as a batch dimension.
+
+        Args:
+            ensemble_as_batch: Whether the ensemble dimension is considered as a batch
+                dimension or not.
+        """
+        self.ensemble_as_batch = ensemble_as_batch
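A plain-tensor sketch of the broadcasting convention behind `ensemble_as_batch` (a stand-in feature map and illustrative shapes, not the actual `GeneralizedLinearPath` machinery):

    import torch

    num_ensemble, n, d, num_features = 4, 5, 3, 7
    weight = torch.randn(num_ensemble, num_features)  # one weight vector per ensemble member
    proj = torch.randn(d, num_features)               # stand-in for the real feature map
    feature_map = lambda x: x @ proj

    # ensemble_as_batch=False: the input has no ensemble dimension, so one is inserted
    x = torch.rand(n, d)
    feat = feature_map(x.unsqueeze(-3))                    # 1 x n x num_features
    out = (feat @ weight.unsqueeze(-1)).squeeze(-1)        # num_ensemble x n (broadcast)

    # ensemble_as_batch=True: the caller already batches the input over the ensemble
    x_batched = torch.rand(num_ensemble, n, d)
    feat_b = feature_map(x_batched)                        # num_ensemble x n x num_features
    out_b = (feat_b @ weight.unsqueeze(-1)).squeeze(-1)    # num_ensemble x n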

botorch/sampling/pathwise/posterior_samplers.py

Lines changed: 10 additions & 2 deletions
@@ -87,7 +87,7 @@ def __init__(


 def get_matheron_path_model(
-    model: GP, sample_shape: Size | None = None
+    model: GP, sample_shape: Size | None = None, ensemble_as_batch: bool = False
 ) -> GenericDeterministicModel:
     r"""Generates a deterministic model using a single Matheron path drawn
     from the model's posterior.
@@ -108,6 +108,9 @@
     """
     sample_shape = Size() if sample_shape is None else sample_shape
     path = draw_matheron_paths(model, sample_shape=sample_shape)
+    # for p in path.paths.values():
+    #     p.ensemble_as_batch = ensemble_as_batch
+    path.set_ensemble_as_batch(ensemble_as_batch)
     num_outputs = model.num_outputs
     if isinstance(model, ModelList) and len(model.models) != num_outputs:
         raise UnsupportedError("A model-list of multi-output models is not supported.")
@@ -137,7 +140,12 @@ def f(X: Tensor) -> Tensor:
         res = path(X.unsqueeze(-3)).transpose(-1, -2)
         return res

-    path_model = GenericDeterministicModel(f=f, num_outputs=num_outputs)
+    path_model = GenericDeterministicModel(
+        f=f,
+        num_outputs=num_outputs,
+        batch_shape=sample_shape + model.batch_shape,
+    )
+    # Do we need the len(sample_shape) > 0?
     path_model._is_ensemble = is_ensemble(model) or len(sample_shape) > 0
     return path_model
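A short sketch of the `batch_shape` now recorded on the returned path model (using a plain SingleTaskGP for brevity; the shapes in the comments are what the new bookkeeping is designed to produce):

    import torch
    from botorch.models import SingleTaskGP
    from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

    train_X = torch.rand(10, 2, dtype=torch.double)
    train_Y = train_X.sum(dim=-1, keepdim=True)
    model = SingleTaskGP(train_X, train_Y)

    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([4]))
    # paths.batch_shape == sample_shape + model.batch_shape == torch.Size([4])
    # paths._is_ensemble is True because len(sample_shape) > 0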
