From caaef96b5398650c601006f455effe7df52fdaa0 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Thu, 17 Oct 2024 17:42:38 -0700
Subject: [PATCH] Update selection of botorch_acqf_class in BoTorchModel._instantiate_acquisition (#2909)

Summary:
Pull Request resolved: https://github.com/facebook/Ax/pull/2909

The previous logic relied on imperfect proxies for `is_moo`. This information is readily available on `TorchOptConfig`, so we now use it directly.

Also simplified the function signature and updated the error message for robust optimization.

Reviewed By: Balandat

Differential Revision: D64545104

fbshipit-source-id: 3c517ccdff750cf5f156e83d1e9b3f2b393d80a7
---
 ax/models/torch/botorch_modular/model.py | 18 +++++++-----------
 ax/models/torch/botorch_modular/utils.py | 21 ++++++++-------------
 ax/models/torch/tests/test_model.py      |  5 +++--
 ax/models/torch/tests/test_utils.py      | 22 ++++++++++++++--------
 4 files changed, 32 insertions(+), 34 deletions(-)

diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py
index 068cce14945..c4315a89cfa 100644
--- a/ax/models/torch/botorch_modular/model.py
+++ b/ax/models/torch/botorch_modular/model.py
@@ -19,7 +19,7 @@
 import torch
 from ax.core.search_space import SearchSpaceDigest
 from ax.core.types import TCandidateMetadata, TGenMetadata
-from ax.exceptions.core import UserInputError
+from ax.exceptions.core import UnsupportedError, UserInputError
 from ax.models.torch.botorch import (
     get_feature_importances_from_botorch_model,
     get_rounding_func,
@@ -616,17 +616,13 @@ def _instantiate_acquisition(
         """
         if not self._botorch_acqf_class:
             if torch_opt_config.risk_measure is not None:
-                # TODO[T131759261]: Implement selection of acqf for robust opt.
-                # This will depend on the properties of the robust search space and
-                # the risk measure being used.
-                raise NotImplementedError
+                raise UnsupportedError(
+                    "Automated selection of `botorch_acqf_class` is not supported "
+                    "for robust optimization with risk measures. Please specify "
+                    "`botorch_acqf_class` as part of `model_kwargs`."
+                )
             self._botorch_acqf_class = choose_botorch_acqf_class(
-                pending_observations=torch_opt_config.pending_observations,
-                outcome_constraints=torch_opt_config.outcome_constraints,
-                linear_constraints=torch_opt_config.linear_constraints,
-                fixed_features=torch_opt_config.fixed_features,
-                objective_thresholds=torch_opt_config.objective_thresholds,
-                objective_weights=torch_opt_config.objective_weights,
+                torch_opt_config=torch_opt_config
             )

         return self.acquisition_class(
diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py
index 207c5c700cd..2a1cd9292d0 100644
--- a/ax/models/torch/botorch_modular/utils.py
+++ b/ax/models/torch/botorch_modular/utils.py
@@ -15,6 +15,7 @@
 import torch
 from ax.core.search_space import SearchSpaceDigest
 from ax.exceptions.core import AxError, AxWarning, UnsupportedError
+from ax.models.torch_base import TorchOptConfig
 from ax.models.types import TConfig
 from ax.utils.common.constants import Keys
 from ax.utils.common.logger import get_logger
@@ -129,20 +130,14 @@ def choose_model_class(


 def choose_botorch_acqf_class(
-    pending_observations: list[Tensor] | None = None,
-    outcome_constraints: tuple[Tensor, Tensor] | None = None,
-    linear_constraints: tuple[Tensor, Tensor] | None = None,
-    fixed_features: dict[int, float] | None = None,
-    objective_thresholds: Tensor | None = None,
-    objective_weights: Tensor | None = None,
+    torch_opt_config: TorchOptConfig,
 ) -> type[AcquisitionFunction]:
-    """Chooses a BoTorch `AcquisitionFunction` class."""
-    if objective_thresholds is not None or (
-        # using objective_weights is a less-than-ideal fix given its ambiguity,
-        # the real fix would be to revisit the infomration passed down via
-        # the modelbridge (and be explicit about whether we scalarize or perform MOO)
-        objective_weights is not None and objective_weights.nonzero().numel() > 1
-    ):
+    """Chooses a BoTorch ``AcquisitionFunction`` class.
+
+    Current logic relies on the ``TorchOptConfig.is_moo`` field to determine
+    whether to use qLogNEHVI (for MOO) or qLogNEI (for SOO).
+    """
+    if torch_opt_config.is_moo:
         acqf_class = qLogNoisyExpectedHypervolumeImprovement
     else:
         acqf_class = qLogNoisyExpectedImprovement
diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py
index 16b1e57bb87..7280ae22524 100644
--- a/ax/models/torch/tests/test_model.py
+++ b/ax/models/torch/tests/test_model.py
@@ -175,6 +175,7 @@ def setUp(self) -> None:
             objective_weights=self.moo_objective_weights,
             objective_thresholds=self.moo_objective_thresholds,
             outcome_constraints=self.moo_outcome_constraints,
+            is_moo=True,
         )

     def test_init(self) -> None:
@@ -874,13 +875,13 @@ def test_MOO(self, _) -> None:
             torch_opt_config=self.moo_torch_opt_config,
         )
         mock_get_outcome_constraint_transforms.assert_called_once()
-        ckwargs = mock_get_outcome_constraint_transforms.call_args[1]
+        ckwargs = mock_get_outcome_constraint_transforms.call_args.kwargs
         oc = ckwargs["outcome_constraints"]
         self.assertTrue(torch.equal(oc[0], subset_outcome_constraints[0]))
         self.assertTrue(torch.equal(oc[1], subset_outcome_constraints[1]))

         # Check input constructor args
-        ckwargs = mock_input_constructor.call_args[1]
+        ckwargs = mock_input_constructor.call_args.kwargs
         expected_kwargs = {
             "constraints",
             "bounds",
diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py
index e23fac4eef1..4ab0f4677c4 100644
--- a/ax/models/torch/tests/test_utils.py
+++ b/ax/models/torch/tests/test_utils.py
@@ -26,6 +26,7 @@
     use_model_list,
 )
 from ax.models.torch.utils import _to_inequality_constraints
+from ax.models.torch_base import TorchOptConfig
 from ax.utils.common.constants import Keys
 from ax.utils.common.testutils import TestCase
 from ax.utils.common.typeutils import checked_cast, not_none
@@ -178,18 +179,23 @@ def test_choose_model_class(self) -> None:
         )

     def test_choose_botorch_acqf_class(self) -> None:
-        self.assertEqual(qLogNoisyExpectedImprovement, choose_botorch_acqf_class())
         self.assertEqual(
-            qLogNoisyExpectedHypervolumeImprovement,
-            choose_botorch_acqf_class(objective_thresholds=self.objective_thresholds),
+            qLogNoisyExpectedImprovement,
+            choose_botorch_acqf_class(
+                torch_opt_config=TorchOptConfig(
+                    objective_weights=torch.tensor([1.0, 0.0]),
+                    is_moo=False,
+                )
+            ),
         )
         self.assertEqual(
             qLogNoisyExpectedHypervolumeImprovement,
-            choose_botorch_acqf_class(objective_weights=torch.tensor([0.5, 0.5])),
-        )
-        self.assertEqual(
-            qLogNoisyExpectedImprovement,
-            choose_botorch_acqf_class(objective_weights=torch.tensor([1.0, 0.0])),
+            choose_botorch_acqf_class(
+                torch_opt_config=TorchOptConfig(
+                    objective_weights=torch.tensor([1.0, -1.0]),
+                    is_moo=True,
+                )
+            ),
         )

     def test_construct_acquisition_and_optimizer_options(self) -> None: