From b1f8a7535cdf8f9358d48d2f1a00379237e454f8 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 24 Oct 2024 08:47:29 -0700 Subject: [PATCH] Remove multi-surrogate support from Acquisition classes (#2949) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2949 Multiple surrogates is not supported by any Acquisition classes and is being deprecated. This diff replaces the `surrogates: dict[str, Surrogate]` input to `Acquisition` classes with a simple `surrogate: Surrogate` input and cleans up the relevant code. A following diff will remove multiple-surrogate support from the `BoTorchModel` class. Reviewed By: Balandat Differential Revision: D64875386 fbshipit-source-id: 19f26f3c69f1e3c91d1204b84b4a645c70f557d1 --- .../torch/botorch_modular/acquisition.py | 54 +++++-------------- ax/models/torch/botorch_modular/model.py | 2 +- ax/models/torch/botorch_modular/sebo.py | 9 +--- ax/models/torch/botorch_modular/surrogate.py | 4 +- ax/models/torch/tests/test_acquisition.py | 39 +++++--------- ax/models/torch/tests/test_sebo.py | 21 ++------ 6 files changed, 32 insertions(+), 97 deletions(-) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 59e0a42f44f..c928c00730c 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -17,7 +17,7 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError +from ax.exceptions.core import SearchSpaceExhausted from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -65,9 +65,8 @@ class Acquisition(Base): of `BoTorchModel` and is not meant to be used outside of it. Args: - surrogates: Dict of name => Surrogate model pairs, with which this acquisition + surrogate: The Surrogate model, with which this acquisition function will be used. - NOTE: Only a single surrogate is currently supported! search_space_digest: A SearchSpaceDigest object containing metadata about the search space (e.g. bounds, parameter types). torch_opt_config: A TorchOptConfig object containing optimization @@ -78,30 +77,24 @@ class Acquisition(Base): Function` in BoTorch. """ - surrogates: dict[str, Surrogate] + surrogate: Surrogate acqf: AcquisitionFunction options: dict[str, Any] def __init__( self, - surrogates: dict[str, Surrogate], + surrogate: Surrogate, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, botorch_acqf_class: type[AcquisitionFunction], options: dict[str, Any] | None = None, ) -> None: - if len(surrogates) > 1: - raise UnsupportedError( - "The `Acquisition` class currently only supports a single surrogate." - ) - - self.surrogates = surrogates + self.surrogate = surrogate self.options = options or {} - primary_surrogate = next(iter(self.surrogates.values())) # Extract pending and observed points. X_pending, X_observed = _get_X_pending_and_observed( - Xs=primary_surrogate.Xs, + Xs=surrogate.Xs, objective_weights=torch_opt_config.objective_weights, bounds=search_space_digest.bounds, pending_observations=torch_opt_config.pending_observations, @@ -120,7 +113,7 @@ def __init__( # Subset model only to the outcomes we need for the optimization. if self.options.pop(Keys.SUBSET_MODEL, True): subset_model_results = subset_model( - model=primary_surrogate.model, + model=surrogate.model, objective_weights=torch_opt_config.objective_weights, outcome_constraints=torch_opt_config.outcome_constraints, objective_thresholds=torch_opt_config.objective_thresholds, @@ -131,7 +124,7 @@ def __init__( objective_thresholds = subset_model_results.objective_thresholds subset_idcs = subset_model_results.indices else: - model = primary_surrogate.model + model = surrogate.model objective_weights = torch_opt_config.objective_weights outcome_constraints = torch_opt_config.outcome_constraints objective_thresholds = torch_opt_config.objective_thresholds @@ -203,13 +196,11 @@ def __init__( # If there is a single dataset, this will be the dataset itself. # If there are multiple datasets, this will be a dict mapping the outcome names # to the corresponding datasets. - training_data = primary_surrogate.training_data + training_data = surrogate.training_data if len(training_data) == 1: training_data = training_data[0] else: - training_data = dict( - zip(none_throws(primary_surrogate._outcomes), training_data) - ) + training_data = dict(zip(none_throws(surrogate._outcomes), training_data)) acqf_inputs = input_constructor( training_data=training_data, @@ -230,35 +221,14 @@ def dtype(self) -> torch.dtype | None: """Torch data type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ - dtypes = { - label: surrogate.dtype for label, surrogate in self.surrogates.items() - } - - dtypes_list = list(dtypes.values()) - if dtypes_list.count(dtypes_list[0]) != len(dtypes_list): - raise ValueError( - f"Expected all Surrogates to have same dtype, found {dtypes}" - ) - - return dtypes_list[0] + return self.surrogate.dtype @property def device(self) -> torch.device | None: """Torch device type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ - - devices = { - label: surrogate.device for label, surrogate in self.surrogates.items() - } - - devices_list = list(devices.values()) - if devices_list.count(devices_list[0]) != len(devices_list): - raise ValueError( - f"Expected all Surrogates to have same device, found {devices}" - ) - - return devices_list[0] + return self.surrogate.device @property def objective_thresholds(self) -> Tensor | None: diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 73c35922288..c27a58a32df 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -626,7 +626,7 @@ def _instantiate_acquisition( ) return self.acquisition_class( - surrogates=self.surrogates, + surrogate=self.surrogate, botorch_acqf_class=self.botorch_acqf_class, search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, diff --git a/ax/models/torch/botorch_modular/sebo.py b/ax/models/torch/botorch_modular/sebo.py index 0ad69be9afc..713042fd7de 100644 --- a/ax/models/torch/botorch_modular/sebo.py +++ b/ax/models/torch/botorch_modular/sebo.py @@ -21,7 +21,6 @@ from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch_base import TorchOptConfig -from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import not_none from botorch.acquisition.acquisition import AcquisitionFunction @@ -58,16 +57,12 @@ class SEBOAcquisition(Acquisition): def __init__( self, - surrogates: dict[str, Surrogate], + surrogate: Surrogate, search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, botorch_acqf_class: type[AcquisitionFunction], options: dict[str, Any] | None = None, ) -> None: - if len(surrogates) > 1: - raise ValueError("SEBO does not support support multiple surrogates.") - surrogate = surrogates[Keys.ONLY_SURROGATE] - tkwargs: dict[str, Any] = {"dtype": surrogate.dtype, "device": surrogate.device} options = {} if options is None else options self.penalty_name: str = options.pop("penalty", "L0_norm") @@ -123,7 +118,7 @@ def __init__( if self.penalty_name == "L0_norm": self.deterministic_model._f.a.fill_(1e-6) super().__init__( - surrogates={"sebo": surrogate_f}, + surrogate=surrogate_f, search_space_digest=search_space_digest, torch_opt_config=torch_opt_config_sebo, botorch_acqf_class=qLogNoisyExpectedHypervolumeImprovement, diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index b6b3d595f9e..0fab257ffca 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -692,8 +692,8 @@ def best_out_of_sample_point( # Avoiding circular import between `Surrogate` and `Acquisition`. from ax.models.torch.botorch_modular.acquisition import Acquisition - acqf = Acquisition( # TODO: For multi-fidelity, might need diff. class. - surrogates={"self": self}, + acqf = Acquisition( + surrogate=self, botorch_acqf_class=acqf_class, search_space_digest=search_space_digest, torch_opt_config=torch_opt_config, diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index b9b97bff456..0c3ce9e70d9 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -18,7 +18,7 @@ import numpy as np import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError +from ax.exceptions.core import SearchSpaceExhausted from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -178,7 +178,7 @@ def get_acquisition_function( botorch_acqf_class=( DummyOneShotAcquisitionFunction if one_shot else self.botorch_acqf_class ), - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, torch_opt_config=dataclasses.replace( self.torch_opt_config, fixed_features=fixed_features or {} @@ -194,7 +194,7 @@ def test_init_raises_when_missing_acqf_cls(self) -> None: with self.assertRaisesRegex(TypeError, ".* missing .* 'botorch_acqf_class'"): # pyre-ignore[20]: Argument `botorch_acqf_class` expected. Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, torch_opt_config=self.torch_opt_config, ) @@ -210,7 +210,7 @@ def test_init( mock_get_X: Mock, ) -> None: acquisition = Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, torch_opt_config=self.torch_opt_config, botorch_acqf_class=self.botorch_acqf_class, @@ -234,7 +234,7 @@ def test_init( # Call `subset_model` only when needed mock_subset_model.assert_called_with( - model=acquisition.surrogates["surrogate"].model, + model=acquisition.surrogate.model, objective_weights=self.objective_weights, outcome_constraints=self.outcome_constraints, objective_thresholds=self.objective_thresholds, @@ -265,7 +265,7 @@ def test_init_with_subset_model_false( return_value=self.constraints, ) as mock_get_outcome_constraint_transforms: acquisition = Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, torch_opt_config=self.torch_opt_config, botorch_acqf_class=self.botorch_acqf_class, @@ -275,14 +275,14 @@ def test_init_with_subset_model_false( # Check `get_botorch_objective_and_transform` kwargs mock_get_objective_and_transform.assert_called_once() _, ckwargs = mock_get_objective_and_transform.call_args - self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) + self.assertIs(ckwargs["model"], acquisition.surrogate.model) self.assertIs(ckwargs["objective_weights"], self.objective_weights) self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints) self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1])) # Check final `acqf` creation self.mock_input_constructor.assert_called_once() _, ckwargs = self.mock_input_constructor.call_args - self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model) + self.assertIs(ckwargs["model"], acquisition.surrogate.model) self.assertIs(ckwargs["objective"], botorch_objective) self.assertTrue(torch.equal(ckwargs["X_pending"], self.pending_observations[0])) for k, v in self.options.items(): @@ -422,7 +422,7 @@ def test_optimize_discrete(self) -> None: inequality_constraints=None, ) - expected_choices = torch.tensor([elt for elt in all_possible_choices]) + expected_choices = torch.tensor(all_possible_choices) expected_avoid = torch.cat([self.X, self.pending_observations[0]], dim=-2) kwargs = mock_optimize_acqf_discrete.call_args.kwargs @@ -702,7 +702,7 @@ def test_init_moo( is_moo=True, ) acquisition = Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, botorch_acqf_class=acqf_class, search_space_digest=self.search_space_digest, torch_opt_config=torch_opt_config, @@ -736,7 +736,7 @@ def test_init_moo( ) ) acquisition = Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, botorch_acqf_class=acqf_class, torch_opt_config=dataclasses.replace( @@ -757,7 +757,7 @@ def test_init_moo( self.assertTrue(np.isnan(acquisition.objective_thresholds[2].item())) # With partial thresholds. acquisition = Acquisition( - surrogates={"surrogate": self.surrogate}, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, botorch_acqf_class=acqf_class, torch_opt_config=dataclasses.replace( @@ -784,18 +784,3 @@ def test_init_moo( def test_init_no_X_observed(self) -> None: self.test_init_moo(with_no_X_observed=True, with_outcome_constraints=False) - - def test_init_multiple_surrogates(self) -> None: - with self.assertRaisesRegex( - UnsupportedError, "currently only supports a single surrogate" - ): - Acquisition( - surrogates={ - "surrogate_1": self.surrogate, - "surrogate_2": self.surrogate, - }, - search_space_digest=self.search_space_digest, - torch_opt_config=self.torch_opt_config, - botorch_acqf_class=self.botorch_acqf_class, - options=self.options, - ) diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py index 2e3638ad041..81d68861af8 100644 --- a/ax/models/torch/tests/test_sebo.py +++ b/ax/models/torch/tests/test_sebo.py @@ -119,7 +119,7 @@ def get_acquisition_function( ) -> SEBOAcquisition: return SEBOAcquisition( botorch_acqf_class=qNoisyExpectedHypervolumeImprovement, - surrogates={Keys.ONLY_SURROGATE: self.surrogates}, + surrogate=self.surrogates, search_space_digest=self.search_space_digest, torch_opt_config=dataclasses.replace( torch_opt_config or self.torch_opt_config, @@ -133,7 +133,7 @@ def test_init(self) -> None: options={"target_point": self.target_point}, ) # Check that determinstic metric is added to surrogate - surrogate = acquisition1.surrogates["sebo"] + surrogate = acquisition1.surrogate model_list = not_none(surrogate._model) self.assertIsInstance(model_list, ModelList) self.assertIsInstance(model_list.models[0], SingleTaskGP) @@ -167,7 +167,7 @@ def test_init(self) -> None: options={"penalty": "L1_norm", "target_point": self.target_point}, ) self.assertEqual(acquisition2.penalty_name, "L1_norm") - surrogate = acquisition2.surrogates["sebo"] + surrogate = acquisition2.surrogate model_list = not_none(surrogate._model) self.assertIsInstance(model_list.models[1]._f, functools.partial) self.assertIs(model_list.models[1]._f.func, L1_norm_func) @@ -181,21 +181,6 @@ def test_init(self) -> None: options={"penalty": "L2_norm", "target_point": self.target_point}, ) - # assert error raise if multiple surrogates are given - with self.assertRaisesRegex( - ValueError, "SEBO does not support support multiple surrogates." - ): - SEBOAcquisition( - botorch_acqf_class=qNoisyExpectedHypervolumeImprovement, - surrogates={ - Keys.ONLY_SURROGATE: self.surrogates, - "sebo2": self.surrogates, - }, - search_space_digest=self.search_space_digest, - torch_opt_config=self.torch_opt_config, - options=self.options, - ) - # assert error raise if target point is not given with self.assertRaisesRegex(ValueError, "please provide target point."): self.get_acquisition_function(options={"penalty": "L1_norm"})