Closed
Description
🐛 Bug
Analytic acquisition functions like `ExpectedImprovement` and `PosteriorStandardDeviation` don't work with `SingleTaskVariationalGP` when it is trained on a multi-output dataset. I believe this is because the `posterior` method of `ApproximateGPyTorchModel` does not accept a `posterior_transform` argument, unlike the `posterior` method of `SingleTaskGP`.
To reproduce
**Code snippet to reproduce**
import math
import torch
from botorch.models import SingleTaskVariationalGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import Standardize, Normalize
from gpytorch.mlls import VariationalELBO
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf
from botorch.acquisition.objective import ScalarizedPosteriorTransform
# Fix the RNG seed so the noise added to the training targets is reproducible.
seed = 42
torch.random.manual_seed(seed)
# Training inputs: 50 evenly spaced points on [0, 0.5], shape (50,).
n_train = 50
X = torch.linspace(0, 0.5, n_train)
# Two noisy outputs (sine and cosine of 2*pi*X) stacked along the last
# dimension, giving targets of shape (50, 2).
y = torch.stack(
[
torch.sin(X * (2 * math.pi)) + torch.randn(X.size()) * 0.01,
torch.cos(X * (2 * math.pi)) + torch.randn(X.size()) * 0.01,
],
-1,
)
# Transforms sized for this problem: d=1 for the input transform and m=2 for
# the outcome transform, matching the single input column and two outputs.
input_transform = Normalize(d=1)
outcome_transform = Standardize(m=2)
# The model receives X as a (50, 1) tensor and the (50, 2) targets.
gp = SingleTaskVariationalGP(
X.unsqueeze(-1),
y,
outcome_transform=outcome_transform,
input_transform=input_transform,
)
# Variational ELBO over the wrapped approximate GP; num_data is the number of
# training points.
mll = VariationalELBO(gp.likelihood, gp.model, num_data=X.shape[0])
fit_gpytorch_mll(mll)
# Analytic EI on the two-output model, with a scalarizing posterior transform
# (weights 1.0 and 0.5) supplied to combine the outputs.
# NOTE(review): best_f is the max over every element of y (both outputs), not
# the max of the weighted combination of outputs -- confirm this is intended.
EI = ExpectedImprovement(
gp,
best_f=y.max(),
posterior_transform=ScalarizedPosteriorTransform(weights=torch.tensor([1.0, 0.5])),
)
# This is the call that raises the AssertionError shown in the stack trace.
candidates, acq = optimize_acqf(
EI,
bounds=torch.tensor([[0.0], [1.0]]),
q=1,
num_restarts=5,
raw_samples=20,
)
**Stack trace/error message**
{
"name": "AssertionError",
"message": "Expected the output shape to match either the t-batch shape of X, or the `model.batch_shape` in the case of acquisition functions using batch models; but got output with shape torch.Size([20, 2]) for X with shape torch.Size([20, 1, 1]).",
"stack": "---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[12], line 42
34 fit_gpytorch_mll(mll)
36 EI = ExpectedImprovement(
37 gp,
38 best_f=y.max(),
39 posterior_transform=ScalarizedPosteriorTransform(weights=torch.tensor([1.0, 0.5])),
40 )
---> 42 candidates, acq = optimize_acqf(
43 EI,
44 bounds=torch.tensor([[0.0], [1.0]]),
45 q=1,
46 num_restarts=5,
47 raw_samples=20,
48 )
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:543, in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, gen_candidates, sequential, ic_generator, timeout_sec, return_full_tree, retry_on_optimization_warning, **ic_gen_kwargs)
520 gen_candidates = gen_candidates_scipy
521 opt_acqf_inputs = OptimizeAcqfInputs(
522 acq_function=acq_function,
523 bounds=bounds,
(...)
541 ic_gen_kwargs=ic_gen_kwargs,
542 )
--> 543 return _optimize_acqf(opt_acqf_inputs)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:564, in _optimize_acqf(opt_inputs)
561 return _optimize_acqf_sequential_q(opt_inputs=opt_inputs)
563 # Batch optimization (including the case q=1)
--> 564 return _optimize_acqf_batch(opt_inputs=opt_inputs)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:255, in _optimize_acqf_batch(opt_inputs)
252 batch_initial_conditions = opt_inputs.batch_initial_conditions
253 else:
254 # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
--> 255 batch_initial_conditions = opt_inputs.get_ic_generator()(
256 acq_function=opt_inputs.acq_function,
257 bounds=opt_inputs.bounds,
258 q=opt_inputs.q,
259 num_restarts=opt_inputs.num_restarts,
260 raw_samples=opt_inputs.raw_samples,
261 fixed_features=opt_inputs.fixed_features,
262 options=options,
263 inequality_constraints=opt_inputs.inequality_constraints,
264 equality_constraints=opt_inputs.equality_constraints,
265 **opt_inputs.ic_gen_kwargs,
266 )
268 batch_limit: int = options.get(
269 \"batch_limit\",
270 (
(...)
274 ),
275 )
277 def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/initializers.py:418, in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, fixed_features, options, inequality_constraints, equality_constraints, generator, fixed_X_fantasies)
416 while start_idx < X_rnd.shape[0]:
417 end_idx = min(start_idx + batch_limit, X_rnd.shape[0])
--> 418 Y_rnd_curr = acq_function(
419 X_rnd[start_idx:end_idx].to(device=device)
420 ).cpu()
421 Y_rnd_list.append(Y_rnd_curr)
422 start_idx += batch_limit
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
-> 1553 return self._call_impl(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don't have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None
File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/utils/transforms.py:299, in t_batch_mode_transform.<locals>.decorator.<locals>.decorated(acqf, X, *args, **kwargs)
291 output = (
292 output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
293 )
294 if assert_output_shape and not _verify_output_shape(
295 acqf=acqf,
296 X=X,
297 output=output,
298 ):
--> 299 raise AssertionError(
300 \"Expected the output shape to match either the t-batch shape of \"
301 \"X, or the `model.batch_shape` in the case of acquisition \"
302 \"functions using batch models; but got output with shape \"
303 f\"{output.shape} for X with shape {X.shape}.\"
304 )
305 return output
AssertionError: Expected the output shape to match either the t-batch shape of X, or the `model.batch_shape` in the case of acquisition functions using batch models; but got output with shape torch.Size([20, 2]) for X with shape torch.Size([20, 1, 1])."
}
Expected Behavior
`optimize_acqf` should have returned a single candidate point along with its acquisition function value.
System information
Please complete the following information:
- BoTorch version: 0.11.3
- GPyTorch version: 1.12
- Torch version: 2.4.1
- OS: macOS Sonoma 14.5
Additional context
I was able to generate candidate points by adding a `posterior_transform` argument to the `posterior` method (I used `ScalarizedPosteriorTransform` to combine the outputs into a single-output posterior).