[Bug] Analytic Acquisition Functions not working with SingleTaskVariationalGP for multi-output data #2530

Closed
SaiAakash opened this issue Sep 11, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@SaiAakash
Contributor

🐛 Bug

Analytic acquisition functions such as ExpectedImprovement and PosteriorStandardDeviation don't work with SingleTaskVariationalGP when it is trained on a multi-output dataset. I believe this is because the posterior method in ApproximateGPyTorchModel does not accept a posterior_transform argument, unlike the posterior method of SingleTaskGP.

To reproduce

** Code snippet to reproduce **

import math
import torch
from botorch.models import SingleTaskVariationalGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import Standardize, Normalize
from gpytorch.mlls import VariationalELBO
from botorch.acquisition import ExpectedImprovement
from botorch.optim import optimize_acqf
from botorch.acquisition.objective import ScalarizedPosteriorTransform

seed = 42
torch.random.manual_seed(seed)

n_train = 50
X = torch.linspace(0, 0.5, n_train)
y = torch.stack(
    [
        torch.sin(X * (2 * math.pi)) + torch.randn(X.size()) * 0.01,
        torch.cos(X * (2 * math.pi)) + torch.randn(X.size()) * 0.01,
    ],
    -1,
)

input_transform = Normalize(d=1)
outcome_transform = Standardize(m=2)
gp = SingleTaskVariationalGP(
    X.unsqueeze(-1),
    y,
    outcome_transform=outcome_transform,
    input_transform=input_transform,
)

mll = VariationalELBO(gp.likelihood, gp.model, num_data=X.shape[0])
fit_gpytorch_mll(mll)

EI = ExpectedImprovement(
    gp,
    best_f=y.max(),
    posterior_transform=ScalarizedPosteriorTransform(weights=torch.tensor([1.0, 0.5])),
)

candidates, acq = optimize_acqf(
    EI,
    bounds=torch.tensor([[0.0], [1.0]]),
    q=1,
    num_restarts=5,
    raw_samples=20,
)

** Stack trace/error message **

{
	"name": "AssertionError",
	"message": "Expected the output shape to match either the t-batch shape of X, or the `model.batch_shape` in the case of acquisition functions using batch models; but got output with shape torch.Size([20, 2]) for X with shape torch.Size([20, 1, 1]).",
	"stack": "---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[12], line 42
     34 fit_gpytorch_mll(mll)
     36 EI = ExpectedImprovement(
     37     gp,
     38     best_f=y.max(),
     39     posterior_transform=ScalarizedPosteriorTransform(weights=torch.tensor([1.0, 0.5])),
     40 )
---> 42 candidates, acq = optimize_acqf(
     43     EI,
     44     bounds=torch.tensor([[0.0], [1.0]]),
     45     q=1,
     46     num_restarts=5,
     47     raw_samples=20,
     48 )

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:543, in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, gen_candidates, sequential, ic_generator, timeout_sec, return_full_tree, retry_on_optimization_warning, **ic_gen_kwargs)
    520     gen_candidates = gen_candidates_scipy
    521 opt_acqf_inputs = OptimizeAcqfInputs(
    522     acq_function=acq_function,
    523     bounds=bounds,
   (...)
    541     ic_gen_kwargs=ic_gen_kwargs,
    542 )
--> 543 return _optimize_acqf(opt_acqf_inputs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:564, in _optimize_acqf(opt_inputs)
    561     return _optimize_acqf_sequential_q(opt_inputs=opt_inputs)
    563 # Batch optimization (including the case q=1)
--> 564 return _optimize_acqf_batch(opt_inputs=opt_inputs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/optimize.py:255, in _optimize_acqf_batch(opt_inputs)
    252     batch_initial_conditions = opt_inputs.batch_initial_conditions
    253 else:
    254     # pyre-ignore[28]: Unexpected keyword argument `acq_function` to anonymous call.
--> 255     batch_initial_conditions = opt_inputs.get_ic_generator()(
    256         acq_function=opt_inputs.acq_function,
    257         bounds=opt_inputs.bounds,
    258         q=opt_inputs.q,
    259         num_restarts=opt_inputs.num_restarts,
    260         raw_samples=opt_inputs.raw_samples,
    261         fixed_features=opt_inputs.fixed_features,
    262         options=options,
    263         inequality_constraints=opt_inputs.inequality_constraints,
    264         equality_constraints=opt_inputs.equality_constraints,
    265         **opt_inputs.ic_gen_kwargs,
    266     )
    268 batch_limit: int = options.get(
    269     \"batch_limit\",
    270     (
   (...)
    274     ),
    275 )
    277 def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/optim/initializers.py:418, in gen_batch_initial_conditions(acq_function, bounds, q, num_restarts, raw_samples, fixed_features, options, inequality_constraints, equality_constraints, generator, fixed_X_fantasies)
    416 while start_idx < X_rnd.shape[0]:
    417     end_idx = min(start_idx + batch_limit, X_rnd.shape[0])
--> 418     Y_rnd_curr = acq_function(
    419         X_rnd[start_idx:end_idx].to(device=device)
    420     ).cpu()
    421     Y_rnd_list.append(Y_rnd_curr)
    422     start_idx += batch_limit

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/utils/transforms.py:299, in t_batch_mode_transform.<locals>.decorator.<locals>.decorated(acqf, X, *args, **kwargs)
    291     output = (
    292         output.mean(dim=-1) if not acqf._log else logmeanexp(output, dim=-1)
    293     )
    294 if assert_output_shape and not _verify_output_shape(
    295     acqf=acqf,
    296     X=X,
    297     output=output,
    298 ):
--> 299     raise AssertionError(
    300         \"Expected the output shape to match either the t-batch shape of \"
    301         \"X, or the `model.batch_shape` in the case of acquisition \"
    302         \"functions using batch models; but got output with shape \"
    303         f\"{output.shape} for X with shape {X.shape}.\"
    304     )
    305 return output

AssertionError: Expected the output shape to match either the t-batch shape of X, or the `model.batch_shape` in the case of acquisition functions using batch models; but got output with shape torch.Size([20, 2]) for X with shape torch.Size([20, 1, 1])."
}
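
The shapes in the error message can be read as follows: X has shape `20 x 1 x 1` (t-batch of 20, q=1, d=1), so an analytic acquisition function is expected to return one value per t-batch element, i.e. shape `[20]`. The trailing dimension of size 2 in `[20, 2]` suggests the two model outputs were never scalarized. Below is a minimal sketch of that shape rule in plain Python. The helper names `t_batch_shape` and `output_shape_ok` are hypothetical and this is a simplified stand-in, not BoTorch's actual check.

```python
def t_batch_shape(x_shape):
    """Return the t-batch shape of an input of shape `batch_shape x q x d`."""
    return tuple(x_shape[:-2])


def output_shape_ok(output_shape, x_shape, model_batch_shape=()):
    """Simplified, hypothetical version of the rule stated in the error message.

    The output must match either the t-batch shape of X or the model's
    batch shape.
    """
    output_shape = tuple(output_shape)
    return output_shape in (t_batch_shape(x_shape), tuple(model_batch_shape))


x_shape = (20, 1, 1)  # 20 raw samples, q=1, d=1, as in the trace

print(output_shape_ok((20, 2), x_shape))  # False: one value per model output
print(output_shape_ok((20,), x_shape))    # True: one value per t-batch element
```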

Expected Behavior

optimize_acqf should return a single candidate point along with its acquisition function value.

System information

  • BoTorch version: 0.11.3
  • GPyTorch version: 1.12
  • Torch version: 2.4.1
  • OS: macOS Sonoma 14.5

Additional context

I was able to generate candidate points by adding a posterior_transform in the posterior method (used ScalarizedPosteriorTransform to combine the outputs and form a single output posterior).
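
For context, here is a rough sketch of what a linear scalarization does to a two-output Gaussian posterior at a single point. It uses the standard result for a weighted sum of jointly Gaussian variables (mean `w^T mu`, variance `w^T Sigma w`) and is written in plain Python rather than with BoTorch's own code. The function name `scalarize` and the numbers for `mean` and `cov` are made up for illustration; only the weights `[1.0, 0.5]` come from the snippet above.

```python
def scalarize(mean, cov, weights):
    """Collapse a multi-output Gaussian into a single-output one.

    mean:    list of per-output means (length m)
    cov:     m x m covariance matrix across outputs (list of lists)
    weights: list of scalarization weights (length m)
    Returns the mean and variance of sum_i weights[i] * y_i.
    """
    m = len(weights)
    s_mean = sum(w * mu for w, mu in zip(weights, mean))
    s_var = sum(
        weights[i] * cov[i][j] * weights[j] for i in range(m) for j in range(m)
    )
    return s_mean, s_var


# Illustrative values only, not taken from the fitted model.
mean = [0.2, -0.4]
cov = [[0.04, 0.01], [0.01, 0.09]]
weights = [1.0, 0.5]

s_mean, s_var = scalarize(mean, cov, weights)
print(s_mean, s_var)
```

After this step there is one mean and one variance per input point, which is the shape an analytic acquisition function such as ExpectedImprovement expects.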

SaiAakash added the bug label Sep 11, 2024
@Balandat
Contributor

Thanks for raising this. Yes, it does appear that the posterior_transform argument is missing from the posterior() definition of that model. The fact that we're silently ignoring that argument by letting the **kwargs here swallow it is not great (cc @esantorella, who has worked on reducing this antipattern in a few other places before).

> I was able to generate candidate points by adding a posterior_transform in the posterior method (used ScalarizedPosteriorTransform to combine the outputs and form a single output posterior).

Great to hear. Would you be willing to put up a PR with your fix?
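
The `**kwargs` antipattern mentioned in this comment can be illustrated with a toy example. The two functions below are hypothetical stand-ins, not the BoTorch methods: the first accepts arbitrary keywords and silently drops them, the second declares `posterior_transform` explicitly, so the argument is applied and misspelled keywords fail loudly.

```python
def posterior_lenient(X, **kwargs):
    """Hypothetical stand-in: unknown keywords land in **kwargs and are dropped."""
    return list(X)


def posterior_strict(X, posterior_transform=None):
    """Hypothetical stand-in: the transform is an explicit, checked argument."""
    out = list(X)
    if posterior_transform is not None:
        out = posterior_transform(out)
    return out


def total(values):
    """Toy transform that collapses several outputs into one."""
    return [sum(values)]


# The transform is silently ignored: both outputs come back untouched.
print(posterior_lenient([1.0, 2.0], posterior_transform=total))  # [1.0, 2.0]

# The transform is applied: a single combined output.
print(posterior_strict([1.0, 2.0], posterior_transform=total))  # [3.0]

# A misspelled keyword is rejected instead of being swallowed.
try:
    posterior_strict([1.0, 2.0], posterior_transfrom=total)
except TypeError as err:
    print("TypeError:", err)
```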

@SaiAakash
Contributor Author

PR #2531 fixes this. Thanks, @Balandat!
