diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 7581ab9b178..06f27b36f1c 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -19,7 +19,7 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -41,6 +41,10 @@ from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.input_constructors import get_acqf_input_constructor from botorch.acquisition.knowledge_gradient import qKnowledgeGradient +from botorch.acquisition.multi_objective.base import ( + MultiObjectiveAnalyticAcquisitionFunction, + MultiObjectiveMCAcquisitionFunction, +) from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform from botorch.acquisition.risk_measures import RiskMeasureMCObjective from botorch.models.model import Model, ModelDict @@ -101,6 +105,19 @@ def __init__( self.surrogates = surrogates self.options = options or {} + if torch_opt_config.is_moo and not issubclass( + botorch_acqf_class, + ( + MultiObjectiveAnalyticAcquisitionFunction, + MultiObjectiveMCAcquisitionFunction, + ), + ): + raise UserInputError( + "Acquisition requires a `MultiObjectiveAnalyticAcquisitionFunction` " + "or a `MultiObjectiveMCAcquisitionFunction` class when there are " + f"multiple objectives. Received {botorch_acqf_class=}." + ) + # Compute pending and observed points for each surrogate Xs_pending_and_observed = { name: _get_X_pending_and_observed( @@ -215,12 +232,11 @@ def __init__( outcome_constraints = torch_opt_config.outcome_constraints objective_thresholds = torch_opt_config.objective_thresholds subset_idcs = None - # If objective weights suggest multiple objectives but objective - # thresholds are not specified, infer them using the model that - # has already been subset to avoid re-subsetting it within + # If MOO and some objective thresholds are not specified, infer them using + # the model that has already been subset to avoid re-subsetting it within # `inter_objective_thresholds`. if ( - objective_weights.nonzero().numel() > 1 + torch_opt_config.is_moo and ( self._objective_thresholds is None or self._objective_thresholds[torch_opt_config.objective_weights != 0] diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index b9cdd29d31f..1e121c84092 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -18,7 +18,7 @@ import numpy as np import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -43,6 +43,9 @@ ) from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement +from botorch.acquisition.multi_objective.base import ( + MultiObjectiveAnalyticAcquisitionFunction, +) from botorch.acquisition.multi_objective.monte_carlo import ( qNoisyExpectedHypervolumeImprovement, ) @@ -89,22 +92,30 @@ def evaluate(self, X: Tensor, **kwargs: Any) -> Tensor: return X.sum(dim=-1) +class DummyMultiObjectiveAcquisitionFunction( + DummyAcquisitionFunction, MultiObjectiveAnalyticAcquisitionFunction +): + # Dummy acquisition function for testing multi-objective setup. + ... + + class AcquisitionTest(TestCase): def setUp(self) -> None: super().setUp() qNEI_input_constructor = get_acqf_input_constructor(qNoisyExpectedImprovement) + # Adding wrapping here to be able to count calls and inspect arguments. self.mock_input_constructor = mock.MagicMock( qNEI_input_constructor, side_effect=qNEI_input_constructor ) - # Adding wrapping here to be able to count calls and inspect arguments. - _register_acqf_input_constructor( - acqf_cls=DummyAcquisitionFunction, - input_constructor=self.mock_input_constructor, - ) - _register_acqf_input_constructor( - acqf_cls=DummyOneShotAcquisitionFunction, - input_constructor=self.mock_input_constructor, - ) + for acqf_class in ( + DummyAcquisitionFunction, + DummyOneShotAcquisitionFunction, + DummyMultiObjectiveAcquisitionFunction, + ): + _register_acqf_input_constructor( + acqf_cls=acqf_class, + input_constructor=self.mock_input_constructor, + ) tkwargs: dict[str, Any] = {"dtype": torch.double} self.botorch_model_class = SingleTaskGP self.surrogate = Surrogate(botorch_model_class=self.botorch_model_class) @@ -738,7 +749,7 @@ def test_init_moo( with_outcome_constraints: bool = True, ) -> None: acqf_class = ( - DummyAcquisitionFunction + DummyMultiObjectiveAcquisitionFunction if with_no_X_observed else qNoisyExpectedHypervolumeImprovement ) @@ -776,7 +787,19 @@ def test_init_moo( objective_weights=moo_objective_weights, outcome_constraints=outcome_constraints, objective_thresholds=moo_objective_thresholds, + is_moo=True, ) + with self.assertRaisesRegex( + UserInputError, "when there are multiple objectives" + ): + Acquisition( + surrogates={"surrogate": self.surrogate}, + botorch_acqf_class=DummyAcquisitionFunction, + search_space_digest=self.search_space_digest, + torch_opt_config=torch_opt_config, + options=self.options, + ) + acquisition = Acquisition( surrogates={"surrogate": self.surrogate}, botorch_acqf_class=acqf_class,