Skip to content

Commit b034299

Browse files
SebastianAmentfacebook-github-bot
authored andcommitted
Pathwise Thomspon sampling for ensemble models (#2877)
Summary: Pull Request resolved: #2877 This commit adds support for pathwise Thompson sampling for ensemble models, including fully Bayesian SAAS models. Differential Revision: D75990595
1 parent 13444be commit b034299

File tree

7 files changed

+192
-81
lines changed

7 files changed

+192
-81
lines changed

botorch/acquisition/thompson_sampling.py

Lines changed: 63 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from botorch.acquisition.objective import PosteriorTransform
1010
from botorch.models.model import Model
1111
from botorch.sampling.pathwise.posterior_samplers import get_matheron_path_model
12-
from botorch.utils.transforms import t_batch_mode_transform
12+
from botorch.utils.transforms import is_ensemble, t_batch_mode_transform
1313
from torch import Tensor
1414

1515

@@ -42,45 +42,91 @@ def __init__(
4242
a PosteriorTransform that transforms the multi-output posterior into a
4343
single-output posterior is required.
4444
"""
45-
if model._is_fully_bayesian:
46-
raise NotImplementedError(
47-
"PathwiseThompsonSampling is not supported for fully Bayesian models",
48-
)
4945

5046
super().__init__(model=model)
5147
self.batch_size: int | None = None
5248

53-
def redraw(self) -> None:
49+
def redraw(self, batch_size: int) -> None:
50+
sample_shape = (batch_size,)
5451
self.samples = get_matheron_path_model(
55-
model=self.model, sample_shape=torch.Size([self.batch_size])
52+
model=self.model, sample_shape=torch.Size(sample_shape)
5653
)
54+
if is_ensemble(self.model):
55+
# the ensembling dimension is assumed to be part of the batch shape
56+
# could add a dedicated proporty to keep track of the ensembling dimension
57+
# i.e. generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP
58+
model_batch_shape = self.model.batch_shape
59+
if len(model_batch_shape) > 1:
60+
raise NotImplementedError(
61+
"Ensemble models with more than one ensemble dimension are not "
62+
"yet supported."
63+
)
64+
num_ensemble = model_batch_shape[0]
65+
self.ensemble_indices = torch.randint(
66+
0,
67+
num_ensemble,
68+
(*sample_shape, 1, self.model.num_outputs),
69+
)
5770

5871
@t_batch_mode_transform()
5972
def forward(self, X: Tensor) -> Tensor:
6073
r"""Evaluate the pathwise posterior sample draws on the candidate set X.
6174
6275
Args:
63-
X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
76+
X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
6477
6578
Returns:
66-
A `(b1 x ... bk) x [num_models for fully bayesian]`-dim tensor of
67-
evaluations on the posterior sample draws.
79+
A `batch_shape [x m]`-dim tensor of evaluations on the posterior sample
80+
draws, where `m` is the number of outputs of the model.
6881
"""
6982
batch_size = X.shape[-2]
7083
q_dim = -2
71-
7284
# batch_shape x q x 1 x d
7385
X = X.unsqueeze(-2)
7486
if self.batch_size is None:
7587
self.batch_size = batch_size
76-
self.redraw()
88+
self.redraw(batch_size=batch_size)
7789
elif self.batch_size != batch_size:
7890
raise ValueError(
7991
BATCH_SIZE_CHANGE_ERROR.format(self.batch_size, batch_size)
8092
)
81-
82-
# posterior_values.shape post-squeeze:
93+
# batch_shape x q [x num_ensembles] x 1 x m
94+
posterior_values = self.samples(X)
95+
# batch_shape x q [x num_ensembles] x m
96+
posterior_values = posterior_values.squeeze(-2)
8397
# batch_shape x q x m
84-
posterior_values = self.samples(X).squeeze(-2)
85-
# sum over batch dim and squeeze num_objectives dim (-1)
86-
return posterior_values.sum(q_dim).squeeze(-1)
98+
posterior_values = self.select_from_ensemble_models(values=posterior_values)
99+
# NOTE: can leverage batched L-BFGS computation instead of summing in the future
100+
# sum over batch dim and squeeze num_objectives dim (-1): batch_shape [x m]
101+
acqf_vals = posterior_values.sum(q_dim).squeeze(-1)
102+
return acqf_vals
103+
104+
def select_from_ensemble_models(self, values: Tensor):
105+
"""Subselecting a value associated with a single sample in the ensemble for each
106+
element of samples that is not associated with an ensemble dimension. NOTE: uses
107+
`self.model` and `is_ensemble` to determine whether or not an ensembling
108+
dimension is present.
109+
110+
Args:
111+
values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
112+
113+
Returns:
114+
A`batch_shape x num_draws x q x m`-dim where each element was chosen
115+
independently randomly from the ensemble dimension.
116+
"""
117+
if not is_ensemble(self.model):
118+
return values
119+
120+
ensemble_dim = -2
121+
# `ensemble_indices` are fixed so that the acquisition function becomes
122+
# deterministic for the same input and can be optimized with LBFGS.
123+
# ensemble indices have shape num_paths x 1 x m
124+
self.ensemble_indices = self.ensemble_indices.to(device=values.device)
125+
index = self.ensemble_indices
126+
input_batch_shape = values.shape[:-3]
127+
index = index.expand(*input_batch_shape, *index.shape)
128+
# samples is batch_shape x q x num_ensemble x m
129+
values_wo_ensemble = torch.gather(values, dim=ensemble_dim, index=index)
130+
return values_wo_ensemble.squeeze(
131+
ensemble_dim
132+
) # removing the ensemble dimension

botorch/sampling/pathwise/paths.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def __init__(
147147
bias_module: Module | None = None,
148148
input_transform: TInputTransform | None = None,
149149
output_transform: TOutputTransform | None = None,
150+
is_ensemble: bool = False,
150151
):
151152
r"""Initializes a GeneralizedLinearPath instance.
152153
@@ -161,6 +162,7 @@ def __init__(
161162
bias_module: An optional module used to define additive offsets.
162163
input_transform: An optional input transform for the module.
163164
output_transform: An optional output transform for the module.
165+
is_ensemble: Whether the associated model is an ensemble model or not.
164166
"""
165167
super().__init__()
166168
self.feature_map = feature_map
@@ -170,8 +172,13 @@ def __init__(
170172
self.bias_module = bias_module
171173
self.input_transform = input_transform
172174
self.output_transform = output_transform
175+
self.is_ensemble = is_ensemble
173176

174177
def forward(self, x: Tensor, **kwargs) -> Tensor:
178+
if self.is_ensemble:
179+
# assuming that the ensembling dimension is added after (n, d), but
180+
# before the other batch dimensions, starting from the left.
181+
x = x.unsqueeze(-3)
175182
feat = self.feature_map(x, **kwargs)
176183
out = (feat @ self.weight.unsqueeze(-1)).squeeze(-1)
177184
return out if self.bias_module is None else out + self.bias_module(x)

botorch/sampling/pathwise/prior_samplers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from botorch.utils.dispatcher import Dispatcher
2626
from botorch.utils.sampling import draw_sobol_normal_samples
27+
from botorch.utils.transforms import is_ensemble
2728
from gpytorch.kernels import Kernel
2829
from gpytorch.models import ApproximateGP, ExactGP, GP
2930
from gpytorch.variational import _VariationalStrategy
@@ -61,6 +62,7 @@ def _draw_kernel_feature_paths_fallback(
6162
input_transform: TInputTransform | None = None,
6263
output_transform: TOutputTransform | None = None,
6364
weight_generator: Callable[[Size], Tensor] | None = None,
65+
is_ensemble: bool = False,
6466
) -> GeneralizedLinearPath:
6567
# Generate a kernel feature map
6668
feature_map = map_generator(
@@ -89,6 +91,7 @@ def _draw_kernel_feature_paths_fallback(
8991
bias_module=mean_module,
9092
input_transform=input_transform,
9193
output_transform=output_transform,
94+
is_ensemble=is_ensemble,
9295
)
9396

9497

@@ -103,6 +106,7 @@ def _draw_kernel_feature_paths_ExactGP(
103106
covar_module=model.covar_module,
104107
input_transform=get_input_transform(model),
105108
output_transform=get_output_transform(model),
109+
is_ensemble=is_ensemble(model),
106110
**kwargs,
107111
)
108112

@@ -150,5 +154,6 @@ def _draw_kernel_feature_paths_ApproximateGP_fallback(
150154
num_inputs=num_inputs,
151155
mean_module=model.mean_module,
152156
covar_module=model.covar_module,
157+
is_ensemble=is_ensemble(model),
153158
**kwargs,
154159
)

botorch/sampling/pathwise/update_strategies.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any
1414

1515
import torch
16+
1617
from botorch.models.approximate_gp import ApproximateGPyTorchModel
1718
from botorch.models.transforms.input import InputTransform
1819
from botorch.sampling.pathwise.features import KernelEvaluationMap
@@ -24,6 +25,7 @@
2425
TInputTransform,
2526
)
2627
from botorch.utils.dispatcher import Dispatcher
28+
from botorch.utils.transforms import is_ensemble
2729
from botorch.utils.types import DEFAULT
2830
from gpytorch.kernels.kernel import Kernel
2931
from gpytorch.likelihoods import _GaussianLikelihoodBase, Likelihood
@@ -79,6 +81,7 @@ def _gaussian_update_exact(
7981
noise_covariance: Tensor | LinearOperator | None = None,
8082
scale_tril: Tensor | LinearOperator | None = None,
8183
input_transform: TInputTransform | None = None,
84+
is_ensemble: bool = False,
8285
) -> GeneralizedLinearPath:
8386
# Prepare Cholesky factor of `Cov(y, y)` and noise sample values as needed
8487
if isinstance(noise_covariance, (NoneType, ZeroLinearOperator)):
@@ -103,7 +106,9 @@ def _gaussian_update_exact(
103106
points=points,
104107
input_transform=input_transform,
105108
)
106-
return GeneralizedLinearPath(feature_map=feature_map, weight=weight.squeeze(-1))
109+
return GeneralizedLinearPath(
110+
feature_map=feature_map, weight=weight.squeeze(-1), is_ensemble=is_ensemble
111+
)
107112

108113

109114
@GaussianUpdate.register(ExactGP, _GaussianLikelihoodBase)
@@ -134,6 +139,7 @@ def _gaussian_update_ExactGP(
134139
noise_covariance=noise_covariance,
135140
scale_tril=scale_tril,
136141
input_transform=get_input_transform(model),
142+
is_ensemble=is_ensemble(model),
137143
)
138144

139145

@@ -194,4 +200,5 @@ def _gaussian_update_ApproximateGP_VariationalStrategy(
194200
sample_values=sample_values,
195201
scale_tril=L,
196202
input_transform=input_transform,
203+
is_ensemble=is_ensemble(model),
197204
)

botorch/utils/test_helpers.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,6 @@
3939
from torch.nn.functional import pad
4040

4141

42-
def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
43-
mcmc_samples = {
44-
"lengthscale": 1 + torch.rand(num_samples, 1, dim, **tkwargs),
45-
"outputscale": 1 + torch.rand(num_samples, **tkwargs),
46-
"mean": torch.randn(num_samples, **tkwargs),
47-
}
48-
if infer_noise:
49-
mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
50-
mcmc_samples["lengthscale"] = mcmc_samples["lengthscale"]
51-
52-
return mcmc_samples
53-
54-
5542
def get_model(
5643
train_X: Tensor,
5744
train_Y: Tensor,
@@ -93,8 +80,8 @@ def get_fully_bayesian_model(
9380
train_X: Tensor,
9481
train_Y: Tensor,
9582
num_models: int,
96-
standardize_model: bool,
97-
infer_noise: bool,
83+
standardize_model: bool = False,
84+
infer_noise: bool = True,
9885
**tkwargs: Any,
9986
) -> SaasFullyBayesianSingleTaskGP:
10087
num_objectives = train_Y.shape[-1]
@@ -122,6 +109,20 @@ def get_fully_bayesian_model(
122109
return model
123110

124111

112+
def _get_mcmc_samples(
113+
num_samples: int, dim: int, infer_noise: bool, **tkwargs
114+
) -> dict[str, Tensor]:
115+
mcmc_samples = {
116+
"lengthscale": 1 + torch.rand(num_samples, 1, dim, **tkwargs),
117+
"outputscale": 1 + torch.rand(num_samples, **tkwargs),
118+
"mean": torch.randn(num_samples, **tkwargs),
119+
}
120+
if infer_noise:
121+
mcmc_samples["noise"] = torch.rand(num_samples, 1, **tkwargs)
122+
123+
return mcmc_samples
124+
125+
125126
def get_fully_bayesian_model_list(
126127
train_X: Tensor,
127128
train_Y: Tensor,

0 commit comments

Comments
 (0)