6
6
7
7
import torch
8
8
from botorch .acquisition .analytic import AcquisitionFunction
9
- from botorch .acquisition .objective import PosteriorTransform
9
+ from botorch .acquisition .objective import (
10
+ IdentityMCObjective ,
11
+ MCAcquisitionObjective ,
12
+ PosteriorTransform ,
13
+ )
14
+ from botorch .exceptions .errors import UnsupportedError
15
+ from botorch .models .deterministic import GenericDeterministicModel
10
16
from botorch .models .model import Model
11
17
from botorch .sampling .pathwise .posterior_samplers import get_matheron_path_model
12
- from botorch .utils .transforms import t_batch_mode_transform
18
+ from botorch .utils .transforms import is_ensemble , t_batch_mode_transform
13
19
from torch import Tensor
14
20
15
21
@@ -32,7 +38,9 @@ class PathwiseThompsonSampling(AcquisitionFunction):
32
38
def __init__ (
33
39
self ,
34
40
model : Model ,
41
+ objective : MCAcquisitionObjective | None = None ,
35
42
posterior_transform : PosteriorTransform | None = None ,
43
+ samples : GenericDeterministicModel | None = None ,
36
44
) -> None :
37
45
r"""Single-outcome TS.
38
46
@@ -41,46 +49,125 @@ def __init__(
41
49
posterior_transform: A PosteriorTransform. If using a multi-output model,
42
50
a PosteriorTransform that transforms the multi-output posterior into a
43
51
single-output posterior is required.
52
+ samples: A GenericDeterministicModel that evaluates a set of posterior
53
+ sample paths.
44
54
"""
45
- if model ._is_fully_bayesian :
46
- raise NotImplementedError (
47
- "PathwiseThompsonSampling is not supported for fully Bayesian models" ,
48
- )
49
55
50
56
super ().__init__ (model = model )
51
- self .batch_size : int | None = None
52
-
53
- def redraw (self ) -> None :
57
+ self .batch_size : int | None = None if samples is None else samples .batch_shape
58
+
59
+ # NOTE: This conditional block is copied from MCAcquisitionFunction, we should
60
+ # consider inherting from it and e.g. getting the X_pending logic as well.
61
+ if objective is None and model .num_outputs != 1 :
62
+ if posterior_transform is None :
63
+ raise UnsupportedError (
64
+ "Must specify an objective or a posterior transform when using "
65
+ "a multi-output model."
66
+ )
67
+ elif not posterior_transform .scalarize :
68
+ raise UnsupportedError (
69
+ "If using a multi-output model without an objective, "
70
+ "posterior_transform must scalarize the output."
71
+ )
72
+ if objective is None :
73
+ objective = IdentityMCObjective ()
74
+ self .objective = objective
75
+ self .posterior_transform = posterior_transform
76
+ self .samples : GenericDeterministicModel | None = samples
77
+
78
+ def redraw (self , batch_size : int ) -> None :
79
+ sample_shape = (batch_size ,)
54
80
self .samples = get_matheron_path_model (
55
- model = self .model , sample_shape = torch .Size ([ self . batch_size ] )
81
+ model = self .model , sample_shape = torch .Size (sample_shape )
56
82
)
83
+ if is_ensemble (self .model ):
84
+ # the ensembling dimension is assumed to be part of the batch shape
85
+ # could add a dedicated proporty to keep track of the ensembling dimension
86
+ # i.e. generalizing num_mcmc_samples in AbstractFullyBayesianSingleTaskGP
87
+ model_batch_shape = self .model .batch_shape
88
+ if len (model_batch_shape ) > 1 :
89
+ raise NotImplementedError (
90
+ "Ensemble models with more than one ensemble dimension are not "
91
+ "yet supported."
92
+ )
93
+ num_ensemble = model_batch_shape [0 ]
94
+ self .ensemble_indices = torch .randint (
95
+ 0 ,
96
+ num_ensemble ,
97
+ (* sample_shape , 1 , self .model .num_outputs ),
98
+ )
57
99
58
100
@t_batch_mode_transform ()
59
101
def forward (self , X : Tensor ) -> Tensor :
60
102
r"""Evaluate the pathwise posterior sample draws on the candidate set X.
61
103
62
104
Args:
63
- X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.
105
+ X: A `batch_shape x q x d`-dim batched tensor of `d`-dim design points.
64
106
65
107
Returns:
66
- A `(b1 x ... bk) x [num_models for fully bayesian ]`-dim tensor of
67
- evaluations on the posterior sample draws .
108
+ A `batch_shape [x m ]`-dim tensor of evaluations on the posterior sample
109
+ draws, where `m` is the number of outputs of the model .
68
110
"""
69
- batch_size = X .shape [- 2 ]
70
- q_dim = - 2
111
+ objective_values = self ._pathwise_forward (X )
112
+ # NOTE: can leverage batched L-BFGS computation instead of summing in the future
113
+ # sum over batch dim and squeeze num_objectives dim (-1):
114
+ acqf_vals = objective_values .sum (- 1 ) # batch_shape
115
+ return acqf_vals
71
116
117
+ def _pathwise_forward (self , X : Tensor ) -> Tensor :
118
+ batch_size = X .shape [- 2 ]
72
119
# batch_shape x q x 1 x d
73
120
X = X .unsqueeze (- 2 )
74
- if self .batch_size is None :
121
+ if self .samples is None :
75
122
self .batch_size = batch_size
76
- self .redraw ()
77
- elif self .batch_size != batch_size :
123
+ self .redraw (batch_size = batch_size )
124
+
125
+ if self .batch_size != batch_size :
78
126
raise ValueError (
79
127
BATCH_SIZE_CHANGE_ERROR .format (self .batch_size , batch_size )
80
128
)
129
+ # batch_shape x q [x num_ensembles] x 1 x m
130
+ posterior_values = self .samples (X )
131
+ # batch_shape x q [x num_ensembles] x m
132
+ posterior_values = posterior_values .squeeze (- 2 )
81
133
82
- # posterior_values.shape post-squeeze:
83
134
# batch_shape x q x m
84
- posterior_values = self .samples (X ).squeeze (- 2 )
85
- # sum over batch dim and squeeze num_objectives dim (-1)
86
- return posterior_values .sum (q_dim ).squeeze (- 1 )
135
+ posterior_values = self .select_from_ensemble_models (values = posterior_values )
136
+
137
+ if self .posterior_transform :
138
+ posterior_values = self .posterior_transform .evaluate (posterior_values )
139
+ # problem with this currently is that we could still have an `m` dimension,
140
+ # ideally that would be packed into a batch dimension instead
141
+ # objective removes the `m` dimension:
142
+ objective_values = self .objective (posterior_values ) # batch_shape x q
143
+ return objective_values
144
+
145
+ def select_from_ensemble_models (self , values : Tensor ):
146
+ """Subselecting a value associated with a single sample in the ensemble for each
147
+ element of samples that is not associated with an ensemble dimension. NOTE: uses
148
+ `self.model` and `is_ensemble` to determine whether or not an ensembling
149
+ dimension is present.
150
+
151
+ Args:
152
+ values: A `batch_shape x num_draws x q [x num_ensemble] x m`-dim Tensor.
153
+
154
+ Returns:
155
+ A`batch_shape x num_draws x q x m`-dim where each element was chosen
156
+ independently randomly from the ensemble dimension.
157
+ """
158
+ if not is_ensemble (self .model ):
159
+ return values
160
+
161
+ ensemble_dim = - 2
162
+ # `ensemble_indices` are fixed so that the acquisition function becomes
163
+ # deterministic for the same input and can be optimized with LBFGS.
164
+ # ensemble indices have shape num_paths x 1 x m
165
+ self .ensemble_indices = self .ensemble_indices .to (device = values .device )
166
+ index = self .ensemble_indices
167
+ input_batch_shape = values .shape [:- 3 ]
168
+ index = index .expand (* input_batch_shape , * index .shape )
169
+ # samples is batch_shape x q x num_ensemble x m
170
+ values_wo_ensemble = torch .gather (values , dim = ensemble_dim , index = index )
171
+ return values_wo_ensemble .squeeze (
172
+ ensemble_dim
173
+ ) # removing the ensemble dimension
0 commit comments