Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

raise exception if X_pending is set on underlying AF in prior-guided AF #2505

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions botorch/acquisition/prior_guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.exceptions.errors import BotorchError
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor

Expand Down Expand Up @@ -60,8 +61,16 @@ def __init__(
[Hvarfner2022]_.
X_pending: `n x d` Tensor with `n` `d`-dim design points that have
been submitted for evaluation but have not yet been evaluated.
Note: X_pending should be provided as an argument to or set on
`PriorGuidedAcquisitionFunction`, but not set on the underlying
acquisition function.
"""
super().__init__(model=acq_function.model)
if getattr(acq_function, "X_pending", None) is not None:
raise BotorchError(
"X_pending is set on acq_function, but should be set on "
"`PriorGuidedAcquisitionFunction`."
)
self.acq_func = acq_function
self.prior_module = prior_module
self._log = log
Expand Down
16 changes: 16 additions & 0 deletions test/acquisition/test_prior_guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction
from botorch.exceptions.errors import BotorchError
from botorch.models import SingleTaskGP
from botorch.utils.testing import BotorchTestCase
from botorch.utils.transforms import match_batch_shape
Expand Down Expand Up @@ -122,3 +123,18 @@ def test_prior_guided_mc_acquisition_function(self):
expected_val = ei._sample_reduction(ei._q_reduction(weighted_val))

self.assertTrue(torch.equal(val, expected_val))

def test_X_pending_error(self) -> None:
X_pending = torch.rand(2, 3, dtype=torch.double, device=self.device)
model = SingleTaskGP(train_X=self.train_X, train_Y=self.train_Y)
ei = qExpectedImprovement(model=model, best_f=0.0)
ei.set_X_pending(X_pending)
msg = (
"X_pending is set on acq_function, but should be set on "
"`PriorGuidedAcquisitionFunction`."
)
with self.assertRaisesRegex(BotorchError, msg):
PriorGuidedAcquisitionFunction(
acq_function=ei,
prior_module=self.prior,
)
Loading