
Commit cf42bf3

SebastianAment authored and facebook-github-bot committed

Pathwise Thompson sampling for ensemble models

Summary: This commit adds support for pathwise Thompson sampling for ensemble models, including fully Bayesian SAAS models.

Differential Revision: D75990595
1 parent 13444be commit cf42bf3

File tree

6 files changed: +111, -26 lines


botorch/acquisition/thompson_sampling.py

Lines changed: 60 additions & 17 deletions
@@ -9,7 +9,7 @@
 from botorch.acquisition.objective import PosteriorTransform
 from botorch.models.model import Model
 from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
-from botorch.utils.transforms import t_batch_mode_transform
+from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
 from torch import Tensor
 
 
@@ -42,45 +42,88 @@ def __init__(
             a PosteriorTransform that transforms the multi-output posterior into a
             single-output posterior is required.
         """
-        if model._is_fully_bayesian:
-            raise NotImplementedError(
-                "PathwiseThompsonSampling is not supported for fully Bayesian models",
-            )
 
         super().__init__(model=model)
         self.batch_size: int | None = None
 
-    def redraw(self) -> None:
+    def redraw(self, batch_size: int) -> None:
+        sample_shape = (batch_size,)
         self.samples = get_matheron_path_model(
-            model=self.model, sample_shape=torch.Size([self.batch_size])
+            model=self.model, sample_shape=torch.Size(sample_shape)
        )
+        if is_ensemble(self.model):
+            # the ensembling dimension is assumed to be part of the batch shape;
+            # a dedicated property could track the ensembling dimension, i.e.
+            # generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP
+            model_batch_shape = self.model.batch_shape
+            if len(model_batch_shape) > 1:
+                raise NotImplementedError(
+                    "Ensemble models with more than one ensemble dimension are not "
+                    "yet supported."
+                )
+            num_ensemble = model_batch_shape[0]
+            self.ensemble_indices = torch.randint(
+                0, num_ensemble, (*sample_shape, 1, self.model.num_outputs)
+            )
 
     @t_batch_mode_transform()
     def forward(self, X: Tensor) -> Tensor:
         r"""Evaluate the pathwise posterior sample draws on the candidate set X.
 
         Args:
-            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
+            X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
 
         Returns:
-            A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
-            evaluations on the posterior sample draws.
+            A `batch_shape [x m]`-dim tensor of evaluations on the posterior sample
+            draws, where `m` is the number of outputs of the model.
         """
         batch_size = X.shape[-2]
         q_dim = -2
-
         # batch_shape x q x 1 x d
         X = X.unsqueeze(-2)
         if self.batch_size is None:
             self.batch_size = batch_size
-            self.redraw()
+            self.redraw(batch_size=batch_size)
         elif self.batch_size != batch_size:
             raise ValueError(
                 BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
             )
-
-        # posterior_values.shape post-squeeze:
+        # batch_shape x q [x num_ensembles] x 1 x m
+        posterior_values = self.samples(X)
+        # batch_shape x q [x num_ensembles] x m
+        posterior_values = posterior_values.squeeze(-2)
         # batch_shape x q x m
-        posterior_values = self.samples(X).squeeze(-2)
-        # sum over batch dim and squeeze num_objectives dim (-1)
-        return posterior_values.sum(q_dim).squeeze(-1)
+        posterior_values = self.select_from_ensemble_models(values=posterior_values)
+        # NOTE: can leverage batched L-BFGS computation instead of summing in the future
+        # sum over batch dim and squeeze num_objectives dim (-1): batch_shape [x m]
+        acqf_vals = posterior_values.sum(q_dim).squeeze(-1)
+        return acqf_vals
+
+    def select_from_ensemble_models(self, values: Tensor):
+        """Subselect a value associated with a single sample in the ensemble for each
+        element of `values` that is not associated with an ensemble dimension. NOTE:
+        uses `self.model` and `is_ensemble` to determine whether an ensembling
+        dimension is present.
+
+        Args:
+            values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
+
+        Returns:
+            A `batch_shape x num_draws x q x m`-dim Tensor where each element was
+            chosen independently at random from the ensemble dimension.
+        """
+        if not is_ensemble(self.model):
+            return values
+
+        ensemble_dim = -2
+        # `ensemble_indices` are fixed so that the acquisition function becomes
+        # deterministic for the same input and can be optimized with L-BFGS.
+        # ensemble indices have shape `num_paths x 1 x m`
+        index = self.ensemble_indices
+        input_batch_shape = values.shape[:-3]
+        index = index.expand(*input_batch_shape, *index.shape)
+        # `values` has shape `batch_shape x q x num_ensemble x m`
+        values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
+        # remove the (now singleton) ensemble dimension
+        return values_wo_ensemble.squeeze(ensemble_dim)
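The heart of the change is `select_from_ensemble_models`: each posterior path draw is paired with a single ensemble member, chosen once in `redraw` via `ensemble_indices` and then held fixed, so repeated evaluations are deterministic and L-BFGS can optimize the acquisition surface. A minimal standalone sketch of that gather, with hypothetical shapes (the sizes below are illustrative, not from the commit):

import torch

# Hypothetical sizes: a (5,)-shaped input batch, q = 3 paths, an ensemble of
# 4 MCMC models, and m = 2 outputs.
batch_shape, q, num_ensemble, m = (5,), 3, 4, 2
values = torch.randn(*batch_shape, q, num_ensemble, m)

# Drawn once at redraw time and then held fixed: shape q x 1 x m.
ensemble_indices = torch.randint(0, num_ensemble, (q, 1, m))

# Broadcast the indices over the input batch shape, gather along the
# ensemble dimension (-2), and drop the now-singleton dimension.
index = ensemble_indices.expand(*batch_shape, *ensemble_indices.shape)
selected = torch.gather(values, dim=-2, index=index).squeeze(-2)
assert selected.shape == (*batch_shape, q, m)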

botorch/sampling/pathwise/paths.py

Lines changed: 7 additions & 0 deletions
@@ -147,6 +147,7 @@ def __init__(
         bias_module: Module | None = None,
         input_transform: TInputTransform | None = None,
         output_transform: TOutputTransform | None = None,
+        is_ensemble: bool = False,
     ):
         r"""Initializes a GeneralizedLinearPath instance.
 
@@ -161,6 +162,7 @@ def __init__(
             bias_module: An optional module used to define additive offsets.
             input_transform: An optional input transform for the module.
             output_transform: An optional output transform for the module.
+            is_ensemble: Whether the associated model is an ensemble model.
         """
         super().__init__()
         self.feature_map = feature_map
@@ -170,8 +172,13 @@ def __init__(
         self.bias_module = bias_module
         self.input_transform = input_transform
         self.output_transform = output_transform
+        self.is_ensemble = is_ensemble
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
+        if self.is_ensemble:
+            # insert a singleton for the ensembling dimension between the batch
+            # dimensions and the (n, d) data dimensions
+            x = x.unsqueeze(-3)
         feat = self.feature_map(x, **kwargs)
         out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1)
         return out if self.bias_module is None else out + self.bias_module(x)
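For intuition on the new `unsqueeze(-3)`: an ensemble path's parameters carry a leading `num_ensemble` batch dimension, and inserting a singleton between the input's batch and data dimensions lets a plain `n x d` test matrix broadcast against every ensemble member at once. A rough sketch with made-up shapes and a stand-in linear feature map (not the commit's actual `feature_map`):

import torch

num_ensemble, n, d, num_features = 4, 8, 3, 16
x = torch.rand(n, d)

# Stand-ins for per-member path parameters (leading num_ensemble batch dim).
proj = torch.randn(num_ensemble, d, num_features)
weight = torch.randn(num_ensemble, num_features)

# Mirrors GeneralizedLinearPath.forward with is_ensemble=True:
x = x.unsqueeze(-3)  # 1 x n x d
feat = x @ proj  # broadcasts to num_ensemble x n x num_features
out = (feat @ weight.unsqueeze(-1)).squeeze(-1)  # num_ensemble x n
assert out.shape == (num_ensemble, n)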

botorch/sampling/pathwise/prior_samplers.py

Lines changed: 5 additions & 0 deletions
@@ -24,6 +24,7 @@
 )
 from botorch.utils.dispatcher import Dispatcher
 from botorch.utils.sampling import draw_sobol_normal_samples
+from botorch.utils.transforms import is_ensemble
 from gpytorch.kernels import Kernel
 from gpytorch.models import ApproximateGP, ExactGP, GP
 from gpytorch.variational import _VariationalStrategy
@@ -61,6 +62,7 @@ def _draw_kernel_feature_paths_fallback(
     input_transform: TInputTransform | None = None,
     output_transform: TOutputTransform | None = None,
     weight_generator: Callable[[Size], Tensor] | None = None,
+    is_ensemble: bool = False,
 ) -> GeneralizedLinearPath:
     # Generate a kernel feature map
     feature_map = map_generator(
@@ -89,6 +91,7 @@ def _draw_kernel_feature_paths_fallback(
         bias_module=mean_module,
         input_transform=input_transform,
         output_transform=output_transform,
+        is_ensemble=is_ensemble,
     )
 
 
@@ -103,6 +106,7 @@ def _draw_kernel_feature_paths_ExactGP(
         covar_module=model.covar_module,
         input_transform=get_input_transform(model),
         output_transform=get_output_transform(model),
+        is_ensemble=is_ensemble(model),
         **kwargs,
     )
 
@@ -150,5 +154,6 @@ def _draw_kernel_feature_paths_ApproximateGP_fallback(
         num_inputs=num_inputs,
         mean_module=model.mean_module,
         covar_module=model.covar_module,
+        is_ensemble=is_ensemble(model),
         **kwargs,
     )

botorch/sampling/pathwise/update_strategies.py

Lines changed: 8 additions & 1 deletion
@@ -13,6 +13,7 @@
 from typing import Any
 
 import torch
+
 from botorch.models.approximate_gp import ApproximateGPyTorchModel
 from botorch.models.transforms.input import InputTransform
 from botorch.sampling.pathwise.features import KernelEvaluationMap
@@ -24,6 +25,7 @@
     TInputTransform,
 )
 from botorch.utils.dispatcher import Dispatcher
+from botorch.utils.transforms import is_ensemble
 from botorch.utils.types import DEFAULT
 from gpytorch.kernels.kernel import Kernel
 from gpytorch.likelihoods import _GaussianLikelihoodBase, Likelihood
@@ -79,6 +81,7 @@ def _gaussian_update_exact(
     noise_covariance: Tensor | LinearOperator | None = None,
     scale_tril: Tensor | LinearOperator | None = None,
     input_transform: TInputTransform | None = None,
+    is_ensemble: bool = False,
 ) -> GeneralizedLinearPath:
     # Prepare Cholesky factor of `Cov(y, y)` and noise sample values as needed
     if isinstance(noise_covariance, (NoneType, ZeroLinearOperator)):
@@ -103,7 +106,9 @@ def _gaussian_update_exact(
         points=points,
         input_transform=input_transform,
     )
-    return GeneralizedLinearPath(feature_map=feature_map, weight=weight.squeeze(-1))
+    return GeneralizedLinearPath(
+        feature_map=feature_map, weight=weight.squeeze(-1), is_ensemble=is_ensemble
+    )
 
 
 @GaussianUpdate.register(ExactGP, _GaussianLikelihoodBase)
@@ -134,6 +139,7 @@ def _gaussian_update_ExactGP(
         noise_covariance=noise_covariance,
         scale_tril=scale_tril,
         input_transform=get_input_transform(model),
+        is_ensemble=is_ensemble(model),
     )
 
 
@@ -194,4 +200,5 @@ def _gaussian_update_ApproximateGP_VariationalStrategy(
         sample_values=sample_values,
         scale_tril=L,
         input_transform=input_transform,
+        is_ensemble=is_ensemble(model),
     )
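These update-strategy changes thread the flag into the data-dependent half of the path (Matheron's rule: posterior path = prior path + kernel-basis update), which is what lets `get_matheron_path_model` operate on ensemble models end to end. A hedged sketch of drawing such paths from a fully Bayesian model; the dataset and NUTS budget below are invented for illustration:

import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model

# Tiny synthetic dataset and a deliberately small MCMC budget.
train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = torch.sin(6 * train_X).sum(dim=-1, keepdim=True)
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=48, thinning=16, disable_progbar=True
)

# Draw 4 pathwise posterior samples; the paths carry the ensemble batch dim.
path_model = get_matheron_path_model(model=model, sample_shape=torch.Size([4]))
path_values = path_model.posterior(torch.rand(5, 2, dtype=torch.float64)).mean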

botorch/utils/transforms.py

Lines changed: 3 additions & 0 deletions
@@ -293,6 +293,7 @@ def decorated(
                 f"Expected X to be `batch_shape x q={expected_q} x d`, but"
                 f" got X with shape {X.shape}."
             )
+        X_original_shape = X.shape
         # add t-batch dim
         X = X if X.dim() > 2 else X.unsqueeze(0)
         output = method(acqf, X, *args, **kwargs)
@@ -306,6 +307,8 @@ def decorated(
                 "X, or the `model.batch_shape` in the case of acquisition "
                 "functions using batch models; but got output with shape "
                 f"{output.shape} for X with shape {X.shape}."
+                f" The original X shape was {X_original_shape} before the "
+                "t_batch_mode_transform decorator modified it."
             )
         return output
 
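For context on the expanded error message: `t_batch_mode_transform` promotes a 2-dim `q x d` input to a `1 x q x d` t-batch before calling the wrapped method, so the shape previously reported could differ from what the caller actually passed. A toy illustration of that promotion (arbitrary sizes):

import torch

X = torch.rand(4, 3)  # q = 4, d = 3; no t-batch dimension yet
X_original_shape = X.shape
X = X if X.dim() > 2 else X.unsqueeze(0)  # now 1 x 4 x 3, as in the decorator
print(f"original {tuple(X_original_shape)} -> t-batched {tuple(X.shape)}")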

test/acquisition/test_thompson_sampling.py

Lines changed: 28 additions & 8 deletions
@@ -6,6 +6,9 @@
 
 from itertools import product
 
+from unittest import mock
+from unittest.mock import PropertyMock
+
 import torch
 from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
 from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
@@ -30,7 +33,7 @@ def get_fully_bayesian_model(
     train_Y,
     num_models,
     **tkwargs,
-):
+) -> SaasFullyBayesianSingleTaskGP:
     model = SaasFullyBayesianSingleTaskGP(
         train_X=train_X,
         train_Y=train_Y,
@@ -59,7 +62,7 @@ def _test_thompson_sampling_base(self, model: Model):
 
         acq_pass1 = acq(test_X)
         self.assertAllClose(acq_pass1, acq(test_X))
-        acq.redraw()
+        acq.redraw(batch_size=acq.batch_size)
         acq_pass2 = acq(test_X)
         self.assertFalse(torch.allclose(acq_pass1, acq_pass2))
 
@@ -109,10 +112,27 @@ def test_thompson_sampling_fully_bayesian(self):
         tkwargs = {"device": self.device, "dtype": torch.float64}
         train_X = torch.rand(4, input_dim, **tkwargs)
         train_Y = 10 * torch.rand(4, num_objectives, **tkwargs)
-
         fb_model = get_fully_bayesian_model(train_X, train_Y, num_models=3, **tkwargs)
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            "PathwiseThompsonSampling is not supported for fully Bayesian models",
-        ):
-            PathwiseThompsonSampling(model=fb_model)
+        acqf = PathwiseThompsonSampling(model=fb_model)
+        acqf_vals = acqf(train_X)
+
+        acqf_vals_2 = acqf(train_X)
+
+        self.assertAllClose(acqf_vals, acqf_vals_2)
+
+        batch_shape = (2, 5)
+        test_X = torch.randn(*batch_shape, *train_X.shape)
+        batched_output = acqf(test_X)
+        self.assertEqual(batched_output.shape, batch_shape)
+        batched_output_2 = acqf(test_X)
+        self.assertAllClose(batched_output, batched_output_2)
+
+        with mock.patch.object(
+            type(acqf.model), "batch_shape", new_callable=PropertyMock
+        ) as mock_batch_shape:
+            mock_batch_shape.return_value = (2, 3)
+            with self.assertRaisesRegex(
+                NotImplementedError,
+                "Ensemble models with more than one ensemble dimension",
+            ):
+                acqf.redraw(batch_size=2)
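Taken together, the updated tests imply that the following end-to-end flow now works for fully Bayesian models; this is a hedged sketch (the bounds, MCMC budget, and optimizer settings are invented, and `optimize_acqf` is BoTorch's standard optimizer, not something added by this commit):

import torch
from botorch.acquisition.thompson_sampling import PathwiseThompsonSampling
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
from botorch.optim import optimize_acqf

# Fit a small SAAS ensemble on synthetic data.
train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = torch.sin(6 * train_X).sum(dim=-1, keepdim=True)
model = SaasFullyBayesianSingleTaskGP(train_X=train_X, train_Y=train_Y)
fit_fully_bayesian_model_nuts(
    model, warmup_steps=32, num_samples=48, thinning=16, disable_progbar=True
)

# Previously this constructor raised NotImplementedError for such models.
acqf = PathwiseThompsonSampling(model=model)
candidate, value = optimize_acqf(
    acq_function=acqf,
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64),
    q=1,
    num_restarts=4,
    raw_samples=32,
)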
