Skip to content

Commit

Permalink
raise exception if X_pending is set on underlying AF in prior-guided …
Browse files Browse the repository at this point in the history
…AF (#2505)

Summary:
Pull Request resolved: #2505

See title. And discussion in #2493.

Reviewed By: saitcakmak

Differential Revision: D62143198
  • Loading branch information
sdaulton authored and facebook-github-bot committed Sep 3, 2024
1 parent c47c01c commit 258ac76
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
9 changes: 9 additions & 0 deletions botorch/acquisition/prior_guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction
from botorch.exceptions.errors import BotorchError
from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform
from torch import Tensor

Expand Down Expand Up @@ -60,8 +61,16 @@ def __init__(
[Hvarfner2022]_.
X_pending: `n x d` Tensor with `n` `d`-dim design points that have
been submitted for evaluation but have not yet been evaluated.
Note: X_pending should be provided as an argument to or set on
`PriorGuidedAcquisitionFunction`, but not set on the underlying
acquisition function.
"""
super().__init__(model=acq_function.model)
if getattr(acq_function, "X_pending", None) is not None:
raise BotorchError(
"X_pending is set on acq_function, but should be set on "
"`PriorGuidedAcquisitionFunction`."
)
self.acq_func = acq_function
self.prior_module = prior_module
self._log = log
Expand Down
16 changes: 16 additions & 0 deletions test/acquisition/test_prior_guided.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from botorch.acquisition.analytic import ExpectedImprovement
from botorch.acquisition.monte_carlo import qExpectedImprovement
from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction
from botorch.exceptions.errors import BotorchError
from botorch.models import SingleTaskGP
from botorch.utils.testing import BotorchTestCase
from botorch.utils.transforms import match_batch_shape
Expand Down Expand Up @@ -122,3 +123,18 @@ def test_prior_guided_mc_acquisition_function(self):
expected_val = ei._sample_reduction(ei._q_reduction(weighted_val))

self.assertTrue(torch.equal(val, expected_val))

def test_X_pending_error(self) -> None:
X_pending = torch.rand(2, 3, dtype=torch.double, device=self.device)
model = SingleTaskGP(train_X=self.train_X, train_Y=self.train_Y)
ei = qExpectedImprovement(model=model, best_f=0.0)
ei.set_X_pending(X_pending)
msg = (
"X_pending is set on acq_function, but should be set on "
"`PriorGuidedAcquisitionFunction`."
)
with self.assertRaisesRegex(BotorchError, msg):
PriorGuidedAcquisitionFunction(
acq_function=ei,
prior_module=self.prior,
)

0 comments on commit 258ac76

Please sign in to comment.