From 258ac766b19856b8c3cb22ecf0c66afc71195afc Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Tue, 3 Sep 2024 12:33:42 -0700 Subject: [PATCH] raise exception if X_pending is set on underlying AF in prior-guided AF (#2505) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2505 See title. And discussion in https://github.com/pytorch/botorch/issues/2493. Reviewed By: saitcakmak Differential Revision: D62143198 --- botorch/acquisition/prior_guided.py | 9 +++++++++ test/acquisition/test_prior_guided.py | 16 ++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/botorch/acquisition/prior_guided.py b/botorch/acquisition/prior_guided.py index fd5fe97d11..d7fa9d81c0 100644 --- a/botorch/acquisition/prior_guided.py +++ b/botorch/acquisition/prior_guided.py @@ -21,6 +21,7 @@ from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.monte_carlo import SampleReducingMCAcquisitionFunction +from botorch.exceptions.errors import BotorchError from botorch.utils.transforms import concatenate_pending_points, t_batch_mode_transform from torch import Tensor @@ -60,8 +61,16 @@ def __init__( [Hvarfner2022]_. X_pending: `n x d` Tensor with `n` `d`-dim design points that have been submitted for evaluation but have not yet been evaluated. + Note: X_pending should be provided as an argument to or set on + `PriorGuidedAcquisitionFunction`, but not set on the underlying + acquisition function. """ super().__init__(model=acq_function.model) + if getattr(acq_function, "X_pending", None) is not None: + raise BotorchError( + "X_pending is set on acq_function, but should be set on " + "`PriorGuidedAcquisitionFunction`." + ) self.acq_func = acq_function self.prior_module = prior_module self._log = log diff --git a/test/acquisition/test_prior_guided.py b/test/acquisition/test_prior_guided.py index c2ea6ced12..079d2d123a 100644 --- a/test/acquisition/test_prior_guided.py +++ b/test/acquisition/test_prior_guided.py @@ -10,6 +10,7 @@ from botorch.acquisition.analytic import ExpectedImprovement from botorch.acquisition.monte_carlo import qExpectedImprovement from botorch.acquisition.prior_guided import PriorGuidedAcquisitionFunction +from botorch.exceptions.errors import BotorchError from botorch.models import SingleTaskGP from botorch.utils.testing import BotorchTestCase from botorch.utils.transforms import match_batch_shape @@ -122,3 +123,18 @@ def test_prior_guided_mc_acquisition_function(self): expected_val = ei._sample_reduction(ei._q_reduction(weighted_val)) self.assertTrue(torch.equal(val, expected_val)) + + def test_X_pending_error(self) -> None: + X_pending = torch.rand(2, 3, dtype=torch.double, device=self.device) + model = SingleTaskGP(train_X=self.train_X, train_Y=self.train_Y) + ei = qExpectedImprovement(model=model, best_f=0.0) + ei.set_X_pending(X_pending) + msg = ( + "X_pending is set on acq_function, but should be set on " + "`PriorGuidedAcquisitionFunction`." + ) + with self.assertRaisesRegex(BotorchError, msg): + PriorGuidedAcquisitionFunction( + acq_function=ei, + prior_module=self.prior, + )