diff --git a/botorch/acquisition/input_constructors.py b/botorch/acquisition/input_constructors.py
index 06c069fd71..0714f79fec 100644
--- a/botorch/acquisition/input_constructors.py
+++ b/botorch/acquisition/input_constructors.py
@@ -1240,30 +1240,35 @@ def construct_inputs_qKG(
     objective: Optional[MCAcquisitionObjective] = None,
     posterior_transform: Optional[PosteriorTransform] = None,
     num_fantasies: int = 64,
+    with_current_value: bool = False,
     **optimize_objective_kwargs: TOptimizeObjectiveKwargs,
 ) -> dict[str, Any]:
     r"""Construct kwargs for `qKnowledgeGradient` constructor."""

-    X = _get_dataset_field(training_data, "X", first_only=True)
-    _bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)
-
-    _, current_value = optimize_objective(
-        model=model,
-        bounds=_bounds.t(),
-        q=1,
-        objective=objective,
-        posterior_transform=posterior_transform,
-        **optimize_objective_kwargs,
-    )
-
-    return {
+    inputs_qkg = {
         "model": model,
         "objective": objective,
         "posterior_transform": posterior_transform,
         "num_fantasies": num_fantasies,
-        "current_value": current_value.detach().cpu().max(),
     }

+    if with_current_value:
+
+        X = _get_dataset_field(training_data, "X", first_only=True)
+        _bounds = torch.as_tensor(bounds, dtype=X.dtype, device=X.device)
+
+        _, current_value = optimize_objective(
+            model=model,
+            bounds=_bounds.t(),
+            q=1,
+            objective=objective,
+            posterior_transform=posterior_transform,
+            **optimize_objective_kwargs,
+        )
+        inputs_qkg["current_value"] = current_value.detach().cpu().max()
+
+    return inputs_qkg
+

 @acqf_input_constructor(qMultiFidelityKnowledgeGradient)
 def construct_inputs_qMFKG(
diff --git a/test/acquisition/test_input_constructors.py b/test/acquisition/test_input_constructors.py
index f07bf002ef..3d2abf13df 100644
--- a/test/acquisition/test_input_constructors.py
+++ b/test/acquisition/test_input_constructors.py
@@ -1257,24 +1257,39 @@ def test_construct_inputs_qLogNParEGO(self) -> None:

 class TestKGandESAcquisitionFunctionInputConstructors(InputConstructorBaseTestCase):
     def test_construct_inputs_kg(self) -> None:
-        current_value = torch.tensor(1.23)
-        with mock.patch(
-            target="botorch.acquisition.input_constructors.optimize_objective",
-            return_value=(None, current_value),
-        ):
-            from botorch.acquisition import input_constructors
+        func = get_acqf_input_constructor(qKnowledgeGradient)

-            func = input_constructors.get_acqf_input_constructor(qKnowledgeGradient)
+        with self.subTest("test_with_current_value"):
+
+            current_value = torch.tensor(1.23)
+
+            with mock.patch(
+                target="botorch.acquisition.input_constructors.optimize_objective",
+                return_value=(None, current_value),
+            ):
+
+                kwargs = func(
+                    model=mock.Mock(),
+                    training_data=self.blockX_blockY,
+                    objective=LinearMCObjective(torch.rand(2)),
+                    bounds=self.bounds,
+                    num_fantasies=33,
+                    with_current_value=True,
+                )
+
+            self.assertEqual(kwargs["num_fantasies"], 33)
+            self.assertEqual(kwargs["current_value"], current_value)
+
+        with self.subTest("test_without_current_value"):
             kwargs = func(
                 model=mock.Mock(),
                 training_data=self.blockX_blockY,
                 objective=LinearMCObjective(torch.rand(2)),
                 bounds=self.bounds,
                 num_fantasies=33,
+                with_current_value=False,
             )
-
-            self.assertEqual(kwargs["num_fantasies"], 33)
-            self.assertEqual(kwargs["current_value"], current_value)
+            self.assertNotIn("current_value", kwargs)

     def test_construct_inputs_mes(self) -> None:
         func = get_acqf_input_constructor(qMaxValueEntropy)
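For reference, a minimal usage sketch of the new flag (not part of the diff). It relies only on standard BoTorch components (`SingleTaskGP`, `SupervisedDataset`, `get_acqf_input_constructor`, `qKnowledgeGradient`); the model, dataset, feature/outcome names, and bounds below are illustrative placeholders. With the default `with_current_value=False` the constructor skips the inner `optimize_objective` call and omits `"current_value"` from the returned kwargs; passing `True` restores the previous behavior, at the cost of running a small inner optimization over `bounds`.

```python
# Sketch of constructing qKnowledgeGradient inputs with and without the new
# with_current_value flag. Model/data setup is illustrative, not from the PR.
import torch

from botorch.acquisition import qKnowledgeGradient
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.models import SingleTaskGP
from botorch.utils.datasets import SupervisedDataset

# Toy training data and model (placeholders for a real problem).
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
dataset = SupervisedDataset(
    X=train_X,
    Y=train_Y,
    feature_names=["x0", "x1"],
    outcome_names=["y"],
)
bounds = [(0.0, 1.0), (0.0, 1.0)]

constructor = get_acqf_input_constructor(qKnowledgeGradient)

# Default: the inner optimize_objective call is skipped, so the returned
# kwargs contain no "current_value" entry.
kwargs = constructor(model=model, training_data=dataset, bounds=bounds)
assert "current_value" not in kwargs

# Opt in: optimize_objective is run over the bounds and its best (detached)
# value is attached as "current_value", matching the old behavior.
kwargs = constructor(
    model=model,
    training_data=dataset,
    bounds=bounds,
    with_current_value=True,
)
acqf = qKnowledgeGradient(**kwargs)
```

Presumably the motivation for making this opt-in is that computing `current_value` requires an extra `optimize_objective` run, while for plain `qKnowledgeGradient` the baseline only shifts the acquisition value by a constant and does not change the maximizer.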