From 9c194f5b72922823bc57a484d4144e6fb90601eb Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 17 Oct 2024 14:26:36 -0700 Subject: [PATCH] Validate that an MO acqf is used for MOO in MBM/Acquistion Summary: This diff adds a validation that botorch_acqf_class is an MO acqf when `TorchOptConfig.is_moo is True`. This should eliminate bugs like https://github.com/facebook/Ax/issues/2519, which can happen since the downstream code will otherwise assume SOO. Note that this only solves MBM side of the bug. Legacy code will still have the buggy behavior. Differential Revision: D64563992 --- .../torch/botorch_modular/acquisition.py | 26 ++++++++--- ax/models/torch/tests/test_acquisition.py | 45 ++++++++++++++----- 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 7581ab9b178..06f27b36f1c 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -19,7 +19,7 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -41,6 +41,10 @@ from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.input_constructors import get_acqf_input_constructor from botorch.acquisition.knowledge_gradient import qKnowledgeGradient +from botorch.acquisition.multi_objective.base import ( + MultiObjectiveAnalyticAcquisitionFunction, + MultiObjectiveMCAcquisitionFunction, +) from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform from botorch.acquisition.risk_measures import RiskMeasureMCObjective from botorch.models.model import Model, ModelDict @@ -101,6 +105,19 @@ def __init__( self.surrogates = surrogates self.options = options or {} + if torch_opt_config.is_moo and not issubclass( + botorch_acqf_class, + ( + MultiObjectiveAnalyticAcquisitionFunction, + MultiObjectiveMCAcquisitionFunction, + ), + ): + raise UserInputError( + "Acquisition requires a `MultiObjectiveAnalyticAcquisitionFunction` " + "or a `MultiObjectiveMCAcquisitionFunction` class when there are " + f"multiple objectives. Received {botorch_acqf_class=}." + ) + # Compute pending and observed points for each surrogate Xs_pending_and_observed = { name: _get_X_pending_and_observed( @@ -215,12 +232,11 @@ def __init__( outcome_constraints = torch_opt_config.outcome_constraints objective_thresholds = torch_opt_config.objective_thresholds subset_idcs = None - # If objective weights suggest multiple objectives but objective - # thresholds are not specified, infer them using the model that - # has already been subset to avoid re-subsetting it within + # If MOO and some objective thresholds are not specified, infer them using + # the model that has already been subset to avoid re-subsetting it within # `inter_objective_thresholds`. if ( - objective_weights.nonzero().numel() > 1 + torch_opt_config.is_moo and ( self._objective_thresholds is None or self._objective_thresholds[torch_opt_config.objective_weights != 0] diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index b9cdd29d31f..1e121c84092 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -18,7 +18,7 @@ import numpy as np import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -43,6 +43,9 @@ ) from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement +from botorch.acquisition.multi_objective.base import ( + MultiObjectiveAnalyticAcquisitionFunction, +) from botorch.acquisition.multi_objective.monte_carlo import ( qNoisyExpectedHypervolumeImprovement, ) @@ -89,22 +92,30 @@ def evaluate(self, X: Tensor, **kwargs: Any) -> Tensor: return X.sum(dim=-1) +class DummyMultiObjectiveAcquisitionFunction( + DummyAcquisitionFunction, MultiObjectiveAnalyticAcquisitionFunction +): + # Dummy acquisition function for testing multi-objective setup. + ... + + class AcquisitionTest(TestCase): def setUp(self) -> None: super().setUp() qNEI_input_constructor = get_acqf_input_constructor(qNoisyExpectedImprovement) + # Adding wrapping here to be able to count calls and inspect arguments. self.mock_input_constructor = mock.MagicMock( qNEI_input_constructor, side_effect=qNEI_input_constructor ) - # Adding wrapping here to be able to count calls and inspect arguments. - _register_acqf_input_constructor( - acqf_cls=DummyAcquisitionFunction, - input_constructor=self.mock_input_constructor, - ) - _register_acqf_input_constructor( - acqf_cls=DummyOneShotAcquisitionFunction, - input_constructor=self.mock_input_constructor, - ) + for acqf_class in ( + DummyAcquisitionFunction, + DummyOneShotAcquisitionFunction, + DummyMultiObjectiveAcquisitionFunction, + ): + _register_acqf_input_constructor( + acqf_cls=acqf_class, + input_constructor=self.mock_input_constructor, + ) tkwargs: dict[str, Any] = {"dtype": torch.double} self.botorch_model_class = SingleTaskGP self.surrogate = Surrogate(botorch_model_class=self.botorch_model_class) @@ -738,7 +749,7 @@ def test_init_moo( with_outcome_constraints: bool = True, ) -> None: acqf_class = ( - DummyAcquisitionFunction + DummyMultiObjectiveAcquisitionFunction if with_no_X_observed else qNoisyExpectedHypervolumeImprovement ) @@ -776,7 +787,19 @@ def test_init_moo( objective_weights=moo_objective_weights, outcome_constraints=outcome_constraints, objective_thresholds=moo_objective_thresholds, + is_moo=True, ) + with self.assertRaisesRegex( + UserInputError, "when there are multiple objectives" + ): + Acquisition( + surrogates={"surrogate": self.surrogate}, + botorch_acqf_class=DummyAcquisitionFunction, + search_space_digest=self.search_space_digest, + torch_opt_config=torch_opt_config, + options=self.options, + ) + acquisition = Acquisition( surrogates={"surrogate": self.surrogate}, botorch_acqf_class=acqf_class,