diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index c1065b0ee90..a1aa2943151 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -30,7 +30,6 @@ from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel -from ax.utils.common.constants import Keys from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -74,11 +73,9 @@ def test_botorch_modular(self) -> None: self.assertEqual(gpei.model.botorch_acqf_class, qExpectedImprovement) self.assertEqual(gpei.model.acquisition_class, Acquisition) self.assertEqual(gpei.model.acquisition_options, {"best_f": 0.0}) - self.assertIsInstance(gpei.model.surrogates[Keys.AUTOSET_SURROGATE], Surrogate) + self.assertIsInstance(gpei.model.surrogate, Surrogate) # SingleTaskGP should be picked. 
- self.assertIsInstance( - gpei.model.surrogates[Keys.AUTOSET_SURROGATE].model, SingleTaskGP - ) + self.assertIsInstance(gpei.model.surrogate.model, SingleTaskGP) gr = gpei.gen(n=1) self.assertIsNotNone(gr.best_arm_predictions) @@ -96,14 +93,10 @@ def test_SAASBO(self) -> None: self.assertIsInstance(saasbo, TorchModelBridge) self.assertEqual(saasbo._model_key, "SAASBO") self.assertIsInstance(saasbo.model, BoTorchModel) - surrogate_specs = saasbo.model.surrogate_specs + surrogate_spec = saasbo.model.surrogate_spec self.assertEqual( - surrogate_specs, - { - "SAASBO_Surrogate": SurrogateSpec( - botorch_model_class=SaasFullyBayesianSingleTaskGP - ) - }, + surrogate_spec, + SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP), ) self.assertEqual( saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index cae2f7bd67e..dd0bbafe13e 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -265,6 +265,7 @@ def __init__( "instead. If you run into a use case that is not supported by MBM, " "please raise this with an issue at https://github.com/facebook/Ax", DeprecationWarning, + stacklevel=2, ) self.model_constructor = model_constructor self.model_predictor = model_predictor diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index c928c00730c..38e8a358fe7 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -336,8 +336,7 @@ def optimize( ) return candidates, acqf_values, arm_weights - # 2. Handle search spaces with discrete features. - # 2a. Handle the fully discrete search space. + # 2. Handle fully discrete search spaces. if optimizer in ( "optimize_acqf_discrete", "optimize_acqf_discrete_local_search", @@ -384,7 +383,7 @@ def optimize( ) return candidates, acqf_values, arm_weights - # 2b. 
Handle mixed search spaces that have discrete and continuous features. + # 3. Handle mixed search spaces that have discrete and continuous features. # Only sequential optimization is supported for `optimize_acqf_mixed`. candidates, acqf_values = optimize_acqf_mixed( acq_function=self.acqf, diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index c27a58a32df..c48165377a3 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -7,13 +7,11 @@ # pyre-strict import dataclasses +import warnings from collections import OrderedDict -from collections.abc import Callable, Mapping, Sequence -from copy import deepcopy +from collections.abc import Mapping, Sequence from dataclasses import dataclass, field -from functools import wraps -from itertools import chain -from typing import Any, TypeVar +from typing import Any import numpy as np import torch @@ -30,7 +28,6 @@ check_outcome_dataset_match, choose_botorch_acqf_class, construct_acquisition_and_optimizer_options, - get_subset_datasets, ) from ax.models.torch.utils import _to_inequality_constraints from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig @@ -50,26 +47,6 @@ from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood from torch import Tensor -T = TypeVar("T") - - -def single_surrogate_only(f: Callable[..., T]) -> Callable[..., T]: - """ - For use as a decorator on functions only implemented for BotorchModels with a - single Surrogate. - """ - - @wraps(f) - def impl(self: "BoTorchModel", *args: list[Any], **kwargs: dict[str, Any]) -> T: - if len(self._surrogates) != 1: - raise NotImplementedError( - f"{f.__name__} not implemented for multi-surrogate case. Found " - f"{self.surrogates=}." 
- ) - return f(self, *args, **kwargs) - - return impl - @dataclass(frozen=True) class SurrogateSpec: @@ -110,41 +87,39 @@ class BoTorchModel(TorchModel, Base): construction, incomplete, and should be treated as alpha versions only.** - Modular `Model` class for combining BoTorch subcomponents - in Ax. Specified via `Surrogate` and `Acquisition`, which wrap - BoTorch `Model` and `AcquisitionFunction`, respectively, for + Modular ``Model`` class for combining BoTorch subcomponents + in Ax. Specified via ``Surrogate`` and ``Acquisition``, which wrap + BoTorch ``Model`` and ``AcquisitionFunction``, respectively, for convenient use in Ax. Args: - acquisition_class: Type of `Acquisition` to be used in + acquisition_class: Type of ``Acquisition`` to be used in this model, auto-selected based on experiment and data if not specified. acquisition_options: Optional dict of kwargs, passed to - the constructor of BoTorch `AcquisitionFunction`. - botorch_acqf_class: Type of `AcquisitionFunction` to be + the constructor of BoTorch ``AcquisitionFunction``. + botorch_acqf_class: Type of ``AcquisitionFunction`` to be used in this model, auto-selected based on experiment and data if not specified. - surrogate_specs: Optional Mapping of names onto SurrogateSpecs, which specify - how to initialize specific Surrogates to model specific outcomes. If None - is provided a single Surrogate will be created and set up automatically - based on the data provided. - surrogate: In liu of SurrogateSpecs, an instance of `Surrogate` may be - provided to be used as the sole Surrogate for all outcomes + surrogate_spec: An optional ``SurrogateSpec`` object specifying how to + construct the ``Surrogate`` and the underlying BoTorch ``Model``. + surrogate_specs: DEPRECATED. Please use ``surrogate_spec`` instead. + surrogate: In lieu of ``SurrogateSpec``, an instance of ``Surrogate`` may + be provided. In most cases, ``surrogate_spec`` should be used instead. 
refit_on_cv: Whether to reoptimize model parameters during call to - `BoTorchmodel.cross_validate`. + ``BoTorchmodel.cross_validate``. warm_start_refit: Whether to load parameters from either the provided - state dict or the state dict of the current BoTorch `Model` during + state dict or the state dict of the current BoTorch ``Model`` during refitting. If False, model parameters will be reoptimized from scratch on refit. NOTE: This setting is ignored during - `cross_validate` if the corresponding `refit_on_...` is False. + ``cross_validate`` if ``refit_on_cv`` is False. """ acquisition_class: type[Acquisition] acquisition_options: dict[str, Any] - surrogate_specs: dict[str, SurrogateSpec] - _surrogates: dict[str, Surrogate] - _output_order: list[int] | None = None + surrogate_spec: SurrogateSpec | None + _surrogate: Surrogate | None _botorch_acqf_class: type[AcquisitionFunction] | None _search_space_digest: SearchSpaceDigest | None = None @@ -152,53 +127,38 @@ class BoTorchModel(TorchModel, Base): def __init__( self, + surrogate_spec: SurrogateSpec | None = None, surrogate_specs: Mapping[str, SurrogateSpec] | None = None, surrogate: Surrogate | None = None, acquisition_class: type[Acquisition] | None = None, acquisition_options: dict[str, Any] | None = None, botorch_acqf_class: type[AcquisitionFunction] | None = None, - # TODO: [T168715924] Revisit these "refit" arguments. refit_on_cv: bool = False, warm_start_refit: bool = True, ) -> None: - # Ensure only surrogate_specs or surrogate is provided - if surrogate_specs and surrogate: + # Check that only one surrogate related option is provided. + if bool(surrogate_spec) + bool(surrogate_specs) + bool(surrogate) > 1: raise UserInputError( - "Only one of `surrogate_specs` and `surrogate` arguments is expected." + "Only one of `surrogate_spec`, `surrogate_specs`, and `surrogate` " + "can be specified. Please use `surrogate_spec`." 
) - - # Ensure each outcome is only modeled by one Surrogate in the SurrogateSpecs if surrogate_specs is not None: - outcomes_by_surrogate_label = { - label: spec.outcomes for label, spec in surrogate_specs.items() - } - all_outcomes = list( - chain.from_iterable(outcomes_by_surrogate_label.values()) - ) - if len(all_outcomes) != len(set(all_outcomes)): - raise UserInputError( - "Each outcome may be modeled by only one Surrogate, found " - f"{outcomes_by_surrogate_label}" + if len(surrogate_specs) > 1: + raise DeprecationWarning( + "Support for multiple `Surrogate`s has been deprecated. " + "Please use the `surrogate_spec` input in the future to " + "specify a single `Surrogate`." ) - - # Ensure user does not use reserved Surrogate labels - if ( - surrogate_specs is not None - and len( - {Keys.ONLY_SURROGATE, Keys.AUTOSET_SURROGATE} - surrogate_specs.keys() + warnings.warn( + "The `surrogate_specs` argument is deprecated in favor of " + "`surrogate_spec`, which accepts a single `SurrogateSpec` object. " + "Please use `surrogate_spec` in the future.", + DeprecationWarning, + stacklevel=2, ) - < 2 - ): - raise UserInputError( - f"SurrogateSpecs may not be labeled {Keys.ONLY_SURROGATE} or " - f"{Keys.AUTOSET_SURROGATE}, these are reserved." 
- ) - - self.surrogate_specs = dict((surrogate_specs or {}).items()) - if surrogate is not None: - self._surrogates = {Keys.ONLY_SURROGATE: surrogate} - else: - self._surrogates = {} + surrogate_spec = next(iter(surrogate_specs.values())) + self.surrogate_spec = surrogate_spec + self._surrogate = surrogate self.acquisition_class = acquisition_class or Acquisition self.acquisition_options = acquisition_options or {} @@ -208,18 +168,13 @@ def __init__( self.warm_start_refit = warm_start_refit @property - def surrogates(self) -> dict[str, Surrogate]: - """Surrogates by label""" - return self._surrogates - - @property - @single_surrogate_only def surrogate(self) -> Surrogate: - """Surrogate, if there is only one.""" - return next(iter(self.surrogates.values())) + """Returns the ``Surrogate``, if it has been constructed.""" + if self._surrogate is None: + raise ValueError("Surrogate has not yet been constructed.") + return self._surrogate @property - @single_surrogate_only def Xs(self) -> list[Tensor]: """A list of tensors, each of shape ``batch_shape x n_i x d``, where `n_i` is the number of training inputs for the i-th model. @@ -243,8 +198,7 @@ def fit( datasets: Sequence[SupervisedDataset], search_space_digest: SearchSpaceDigest, candidate_metadata: list[list[TCandidateMetadata]] | None = None, - # state dict by surrogate label - state_dicts: Mapping[str, OrderedDict[str, Tensor]] | None = None, + state_dict: OrderedDict[str, Tensor] | None = None, refit: bool = True, **additional_model_inputs: Any, ) -> None: @@ -257,9 +211,8 @@ def fit( metadata on the features in the datasets. candidate_metadata: Model-produced metadata for candidates, in the order corresponding to the Xs. - state_dicts: Optional state dict to load by model label as passed in via - surrogate_specs. If using a single, pre-instantiated model use - `Keys.ONLY_SURROGATE. + state_dict: An optional model statedict for the underlying ``Surrogate``. + Primarily used in ``BoTorchModel.cross_validate``. 
refit: Whether to re-optimize model parameters. additional_model_inputs: Additional kwargs to pass to the model input constructor in ``Surrogate.fit``. @@ -272,127 +225,51 @@ def fit( # Store search space info for later use (e.g. during generation) self._search_space_digest = search_space_digest - # Step 0. If the user passed in a preconstructed surrogate we won't have a - # SurrogateSpec and must assume we're fitting all metrics - if Keys.ONLY_SURROGATE in self._surrogates.keys(): - surrogate = self._surrogates[Keys.ONLY_SURROGATE] - surrogate.model_options.update(additional_model_inputs) - surrogate.fit( - datasets=datasets, - search_space_digest=search_space_digest, - candidate_metadata=candidate_metadata, - state_dict=( - state_dicts.get(Keys.ONLY_SURROGATE) if state_dicts else None - ), - refit=refit, - ) - self._output_order = list(range(len(outcome_names))) - return - - # Step 1. Initialize a Surrogate for every SurrogateSpec - self._surrogates = { - label: Surrogate( - # if None, Surrogate will autoset class per outcome at construct time - botorch_model_class=spec.botorch_model_class, - model_options=spec.botorch_model_kwargs, - mll_class=spec.mll_class, - mll_options=spec.mll_kwargs, - covar_module_class=spec.covar_module_class, - covar_module_options=spec.covar_module_kwargs, - likelihood_class=spec.likelihood_class, - likelihood_options=spec.likelihood_kwargs, - input_transform_classes=spec.input_transform_classes, - input_transform_options=spec.input_transform_options, - outcome_transform_classes=spec.outcome_transform_classes, - outcome_transform_options=spec.outcome_transform_options, - allow_batched_models=spec.allow_batched_models, - ) - for label, spec in self.surrogate_specs.items() - } - - # Step 1.5. 
If any outcomes are not explicitly assigned to a Surrogate, create - # a new Surrogate for all these outcomes (which will autoset its botorch model - # class per outcome) UNLESS there is only one SurrogateSpec with no outcomes - # assigned to it, in which case that will be used for all outcomes. - assigned_outcome_names = { - item - for sublist in [spec.outcomes for spec in self.surrogate_specs.values()] - for item in sublist - } - unassigned_outcome_names = [ - name for name in outcome_names if name not in assigned_outcome_names - ] - if len(unassigned_outcome_names) > 0 and len(self.surrogates) != 1: - self._surrogates[Keys.AUTOSET_SURROGATE] = Surrogate() - - # Step 2. Fit each Surrogate iteratively using its assigned outcomes - for label, surrogate in self.surrogates.items(): - if label == Keys.AUTOSET_SURROGATE or len(self.surrogates) == 1: - subset_outcome_names = unassigned_outcome_names + # If a surrogate has not been constructed, construct it. + if self._surrogate is None: + if (spec := self.surrogate_spec) is not None: + self._surrogate = Surrogate( + botorch_model_class=spec.botorch_model_class, + model_options=spec.botorch_model_kwargs, + mll_class=spec.mll_class, + mll_options=spec.mll_kwargs, + covar_module_class=spec.covar_module_class, + covar_module_options=spec.covar_module_kwargs, + likelihood_class=spec.likelihood_class, + likelihood_options=spec.likelihood_kwargs, + input_transform_classes=spec.input_transform_classes, + input_transform_options=spec.input_transform_options, + outcome_transform_classes=spec.outcome_transform_classes, + outcome_transform_options=spec.outcome_transform_options, + allow_batched_models=spec.allow_batched_models, + ) else: - subset_outcome_names = self.surrogate_specs[label].outcomes - subset_datasets = get_subset_datasets( - datasets=datasets, subset_outcome_names=subset_outcome_names - ) - surrogate.model_options.update(additional_model_inputs) - surrogate.fit( - datasets=subset_datasets, - 
search_space_digest=search_space_digest, - candidate_metadata=candidate_metadata, - state_dict=(state_dicts or {}).get(label), - refit=refit, - ) + self._surrogate = Surrogate() - # Step 3. Output order of outcomes must match input order, but now outcomes are - # grouped according to surrogate. Compute the permutation from surrogate order - # to input ordering. - surrogate_order = [] - for surrogate in self.surrogates.values(): - surrogate_order.extend(surrogate.outcomes) - self._output_order = list( - np.argsort([outcome_names.index(name) for name in surrogate_order]) + # Fit the surrogate. + self.surrogate.model_options.update(additional_model_inputs) + self.surrogate.fit( + datasets=datasets, + search_space_digest=search_space_digest, + candidate_metadata=candidate_metadata, + state_dict=state_dict, + refit=refit, ) - def predict(self, X: Tensor) -> tuple[Tensor, Tensor]: + def predict( + self, X: Tensor, use_posterior_predictive: bool = False + ) -> tuple[Tensor, Tensor]: """Predicts, potentially from multiple surrogates. - If predictions are from multiple surrogates, will stitch outputs together - in same order as input datasets, using self.output_order. - Args: X: (n x d) Tensor of input locations. + use_posterior_predictive: A boolean indicating if the predictions + should be from the posterior predictive (i.e. including + observation noise). Returns: Tuple of tensors: (n x m) mean, (n x m x m) covariance. 
""" - if len(self.surrogates) == 1: - return self.surrogate.predict(X=X) - fs, covs = [], [] - for surrogate in self.surrogates.values(): - f, cov = surrogate.predict(X=X) - fs.append(f) - covs.append(cov) - f = torch.cat(fs, dim=-1) - cov = torch.zeros( - f.shape[0], f.shape[1], f.shape[1], dtype=X.dtype, device=X.device - ) - i = 0 - for cov_i in covs: - d = cov_i.shape[-1] - cov[:, i : (i + d), i : (i + d)] = cov_i - i += d - # Permute from surrogate order to input ordering - f = f[:, self.output_order] - cov = cov[:, :, self.output_order][:, self.output_order, :] - return f, cov - - def predict_from_surrogate( - self, - surrogate_label: str, - X: Tensor, - use_posterior_predictive: bool = False, - ) -> tuple[Tensor, Tensor]: - """Predict from the Surrogate with the given label.""" - return self.surrogates[surrogate_label].predict( + return self.surrogate.predict( X=X, use_posterior_predictive=use_posterior_predictive ) @@ -464,7 +341,6 @@ def _get_gen_metadata_from_acqf( return gen_metadata @copy_doc(TorchModel.best_point) - @single_surrogate_only def best_point( self, search_space_digest: SearchSpaceDigest, @@ -502,46 +378,21 @@ def cross_validate( use_posterior_predictive: bool = False, **additional_model_inputs: Any, ) -> tuple[Tensor, Tensor]: - # Will fail if metric_names exist across multiple models - metric_names = sum((ds.outcome_names for ds in datasets), []) - surrogate_labels = ( - [ - label - for label, surrogate in self.surrogates.items() - if any(metric in surrogate.outcomes for metric in metric_names) - ] - if len(self.surrogates) > 1 - else [*self.surrogates.keys()] - ) - if len(surrogate_labels) != 1: - raise UserInputError( - "May not cross validate multiple Surrogates at once. Please input " - f"datasets that exist on one Surrogate. 
{metric_names} spans " - f"{surrogate_labels}" - ) - surrogate_label = surrogate_labels[0] - - current_surrogates = self.surrogates + current_surrogate = self.surrogate # If we should be refitting but not warm-starting the refit, set - # `state_dicts` to None to avoid loading it. - state_dicts = ( + # `state_dict` to None to avoid loading it. + state_dict = ( None if self.refit_on_cv and not self.warm_start_refit - else { - label: deepcopy(checked_cast(OrderedDict, surrogate.model.state_dict())) - for label, surrogate in current_surrogates.items() - } + else current_surrogate.model.state_dict() ) - # Temporarily set `_surrogates` to cloned surrogates to set - # the training data on cloned surrogates to train set and + # Temporarily set `_surrogate` to cloned surrogate to set + # the training data on cloned surrogate to train set and # use it to predict the test point. - surrogate_clones = { - label: surrogate.clone_reset() - for label, surrogate in self.surrogates.items() - } - self._surrogates = surrogate_clones - # Remove the robust_digest since we do not want to use perturbations here. + self._surrogate = current_surrogate.clone_reset() + + # Remove the `robust_digest` since we do not want to use perturbations here. search_space_digest = dataclasses.replace( search_space_digest, robust_digest=None, @@ -551,12 +402,13 @@ def cross_validate( self.fit( datasets=datasets, search_space_digest=search_space_digest, - state_dicts=state_dicts, + # pyre-fixme [6]: state_dict() has a generic dict[str, Any] return type + # but it is actually an OrderedDict[str, Tensor]. 
+ state_dict=state_dict, refit=self.refit_on_cv, **additional_model_inputs, ) - X_test_prediction = self.predict_from_surrogate( - surrogate_label=surrogate_label, + X_test_prediction = self.predict( X=X_test, use_posterior_predictive=use_posterior_predictive, ) @@ -564,7 +416,7 @@ def cross_validate( # Reset the surrogates back to this model's surrogate, make # sure the cloned surrogate doesn't stay around if fit or # predict fail. - self._surrogates = current_surrogates + self._surrogate = current_surrogate return X_test_prediction @property @@ -572,35 +424,14 @@ def dtype(self) -> torch.dtype: """Torch data type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ - dtypes = { - label: surrogate.dtype for label, surrogate in self.surrogates.items() - } - - dtypes_list = list(dtypes.values()) - if dtypes_list.count(dtypes_list[0]) != len(dtypes_list): - raise NotImplementedError( - f"Expected all Surrogates to have same dtype, found {dtypes}" - ) - - return dtypes_list[0] + return self.surrogate.dtype @property def device(self) -> torch.device: """Torch device type of the tensors in the training data used in the model, of which this ``Acquisition`` is a subcomponent. """ - - devices = { - label: surrogate.device for label, surrogate in self.surrogates.items() - } - - devices_list = list(devices.values()) - if devices_list.count(devices_list[0]) != len(devices_list): - raise NotImplementedError( - f"Expected all Surrogates to have same device, found {devices}" - ) - - return devices_list[0] + return self.surrogate.device def _instantiate_acquisition( self, @@ -633,14 +464,12 @@ def _instantiate_acquisition( options=acq_options, ) - @single_surrogate_only # pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters. def feature_importances(self) -> np.ndarray: """Compute feature importances from the model. - Caveat: This assumes the following: - 1. 
There is a single surrogate model (potentially a `ModelList`). - 2. We can get model lengthscales from `covar_module.base_kernel.lengthscale` + This assumes that we can get model lengthscales from either + ``covar_module.base_kernel.lengthscale`` or ``covar_module.lengthscale``. Returns: The feature importances as a numpy array of size len(metrics) x 1 x dim @@ -659,19 +488,3 @@ def search_space_digest(self) -> SearchSpaceDigest: @search_space_digest.setter def search_space_digest(self, value: SearchSpaceDigest) -> None: raise RuntimeError("Setting search_space_digest manually is disallowed.") - - @property - def outcomes_by_surrogate_label(self) -> dict[str, list[str]]: - """Returns a dictionary mapping from surrogate label to a list of outcomes.""" - outcomes_by_surrogate_label = {} - for k, v in self.surrogates.items(): - outcomes_by_surrogate_label[k] = v.outcomes - return outcomes_by_surrogate_label - - @property - def output_order(self) -> list[int]: - if self._output_order is None: - raise RuntimeError( - "`output_order` is not initialized. Must `fit` the model first." 
- ) - return self._output_order diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index 7280ae22524..fc1a836a56f 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -202,8 +202,8 @@ def test_init(self) -> None: self.assertTrue(mdl2.refit_on_cv) self.assertFalse(mdl2.warm_start_refit) - def test_surrogates_property(self) -> None: - self.assertEqual(self.surrogate, list(self.model.surrogates.values())[0]) + def test_surrogate_property(self) -> None: + self.assertIs(self.surrogate, self.model.surrogate) def test_Xs_property(self) -> None: self.model.fit( @@ -217,10 +217,6 @@ def test_Xs_property(self) -> None: self.model.Xs[0].equal(torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]])) ) - with self.assertRaisesRegex(NotImplementedError, "Xs not implemented"): - self.model._surrogates = {"foo": Surrogate(), "bar": Surrogate()} - self.model.Xs - def test_dtype(self) -> None: self.model.fit( datasets=self.block_design_training_data, @@ -292,7 +288,7 @@ def test__construct__converts_non_block(self, _: Mock) -> None: def test__construct(self) -> None: """Test autoset.""" - self.model._surrogates = {} + self.model._surrogate = None with mock.patch( f"{SURROGATE_PATH}.choose_model_class", wraps=choose_model_class ) as mock_choose_model_class: @@ -310,7 +306,7 @@ def test__construct(self) -> None: @mock.patch(f"{SURROGATE_PATH}.Surrogate._construct_model") def test_fit(self, mock_fit: Mock) -> None: # If surrogate is not yet set, initialize it with dispatcher functions. - self.model._surrogates = {} + self.model._surrogate = None with self.assertRaisesRegex(RuntimeError, "is not initialized. 
Must `fit`"): self.model.search_space_digest # can't access before fit @@ -325,7 +321,6 @@ def test_fit(self, mock_fit: Mock) -> None: self.assertIsInstance(self.model.search_space_digest, SearchSpaceDigest) self.assertEqual(self.model.search_space_digest, self.mf_search_space_digest) - self.assertEqual(self.model.output_order, [0]) # Since we want to refit on updates but not warm start refit, we clear the # state dict. @@ -337,53 +332,31 @@ def test_fit(self, mock_fit: Mock) -> None: refit=True, ) - labels_to_outcomes = self.model.outcomes_by_surrogate_label - self.assertIsInstance(labels_to_outcomes, dict) - self.assertEqual(len(labels_to_outcomes), 1) - self.assertEqual(next(iter(labels_to_outcomes.values())), self.metric_names) - @mock.patch(f"{SURROGATE_PATH}.Surrogate.predict") def test_predict(self, mock_predict: Mock) -> None: self.model.predict(X=self.X_test) - mock_predict.assert_called_with(X=self.X_test) + mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=False) + self.model.predict(X=self.X_test, use_posterior_predictive=True) + mock_predict.assert_called_with(X=self.X_test, use_posterior_predictive=True) - @fast_botorch_optimize - def test_multiple_surrogates(self) -> None: + def test_with_surrogate_specs_input(self) -> None: + spec1 = SurrogateSpec( + botorch_model_class=SingleTaskGP, + outcomes=["y1", "y3"], + ) surrogate_specs = { - "Vanilla": SurrogateSpec( - botorch_model_class=SingleTaskGP, - outcomes=["y1", "y3"], - ), + "Vanilla": spec1, "Bayesian": SurrogateSpec( botorch_model_class=SaasFullyBayesianSingleTaskGP, outcomes=["y2"], ), } - model = BoTorchModel(surrogate_specs=surrogate_specs) - model.fit( - datasets=self.moo_training_data, - search_space_digest=self.search_space_digest, - candidate_metadata=self.candidate_metadata, - ) - self.assertEqual(list(model.surrogates.keys()), list(surrogate_specs.keys())) - self.assertEqual(model.output_order, [0, 2, 1]) - f1, cov1 = model.surrogates["Vanilla"].predict(self.X_test) 
- f2, cov2 = model.surrogates["Bayesian"].predict(self.X_test) - n = self.X_test.shape[0] - true_f = torch.zeros(n, 3, dtype=self.X_test.dtype, device=self.X_test.device) - true_f[:, 0] = f1[:, 0] - true_f[:, 2] = f1[:, 1] - true_f[:, 1] = f2[:, 0] - true_cov = torch.zeros( - n, 3, 3, dtype=self.X_test.dtype, device=self.X_test.device - ) - for i in range(2): - true_cov[i, 0, 0] = cov1[i, 0, 0] - true_cov[i, 2, 2] = cov1[i, 1, 1] - true_cov[i, 1, 1] = cov2[i, 0, 0] - f_joint, cov_joint = model.predict(self.X_test) - self.assertTrue(torch.allclose(true_f, f_joint)) - self.assertTrue(torch.allclose(true_cov, cov_joint)) + with self.assertRaisesRegex(DeprecationWarning, "Support for multiple"): + BoTorchModel(surrogate_specs=surrogate_specs) + + with self.assertWarnsRegex(DeprecationWarning, "surrogate_specs"): + model = BoTorchModel(surrogate_specs={"s": spec1}) + self.assertIs(model.surrogate_spec, spec1) @mock.patch(f"{MODEL_PATH}.BoTorchModel.fit") def test_cross_validate(self, mock_fit: Mock) -> None: @@ -393,7 +366,7 @@ def test_cross_validate(self, mock_fit: Mock) -> None: candidate_metadata=self.candidate_metadata, ) - old_surrogate = self.model.surrogates[Keys.ONLY_SURROGATE] + old_surrogate = self.model.surrogate old_surrogate._model = mock.MagicMock() old_surrogate._model.state_dict.return_value = OrderedDict({"key": "val"}) @@ -418,41 +391,30 @@ def test_cross_validate(self, mock_fit: Mock) -> None: mock_predict.assert_called_once() # Check correct X_test. - self.assertTrue( - torch.equal( - mock_predict.call_args_list[-1][1].get("X"), - self.X_test, - ), - ) + kwargs = mock_predict.call_args.kwargs + self.assertTrue(torch.equal(kwargs["X"], self.X_test)) self.assertIs( - mock_predict.call_args_list[-1][1]["use_posterior_predictive"], - use_posterior_predictive, + kwargs["use_posterior_predictive"], use_posterior_predictive ) # Check that surrogate is reset back to `old_surrogate` at the # end of cross-validation. 
- self.assertTrue(self.model.surrogates[Keys.ONLY_SURROGATE] is old_surrogate) + self.assertTrue(self.model.surrogate is old_surrogate) expected_state_dict = ( None if refit_on_cv and not warm_start_refit - else self.model.surrogates[Keys.ONLY_SURROGATE].model.state_dict() + else self.model.surrogate.model.state_dict() ) # Check correct `refit` and `state_dict` values. - self.assertEqual(mock_fit.call_args_list[-1][1].get("refit"), refit_on_cv) + kwargs = mock_fit.call_args.kwargs + self.assertEqual(kwargs["refit"], refit_on_cv) if expected_state_dict is None: - self.assertIsNone( - mock_fit.call_args_list[-1][1].get("state_dict"), - expected_state_dict, - ) + self.assertIsNone(kwargs["state_dict"], expected_state_dict) else: self.assertEqual( - mock_fit.call_args_list[-1][1] - .get("state_dicts") - .get(Keys.ONLY_SURROGATE) - .keys(), - expected_state_dict.keys(), + kwargs["state_dict"].keys(), expected_state_dict.keys() ) @fast_botorch_optimize @@ -493,7 +455,7 @@ def _test_gen( acquisition_class=Acquisition, acquisition_options=self.acquisition_options, ) - model.surrogates[Keys.ONLY_SURROGATE].fit( + model.surrogate.fit( datasets=self.block_design_training_data, search_space_digest=search_space_digest, ) @@ -613,7 +575,7 @@ def test_feature_importances(self) -> None: acquisition_class=Acquisition, acquisition_options=self.acquisition_options, ) - model.surrogates[Keys.ONLY_SURROGATE].fit( + model.surrogate.fit( datasets=self.block_design_training_data, search_space_digest=SearchSpaceDigest(feature_names=[], bounds=[]), ) @@ -672,17 +634,10 @@ def test_feature_importances(self) -> None: ValueError, "BoTorch `Model` has not yet been constructed" ): model.feature_importances() - # Test unsupported surrogate - model._surrogates = {"vanilla": None, "saas": None} - with self.assertRaisesRegex( - NotImplementedError, - "feature_importances not implemented for multi-surrogate case", - ): - model.feature_importances() @fast_botorch_optimize def test_best_point(self) -> 
None: - self.model._surrogates = {} + self.model._surrogate = None self.model.fit( datasets=self.block_design_training_data, search_space_digest=self.mf_search_space_digest, @@ -723,7 +678,7 @@ def test_evaluate_acquisition_function( acquisition_class=Acquisition, acquisition_options=self.acquisition_options, ) - model.surrogates[Keys.ONLY_SURROGATE].fit( + model.surrogate.fit( datasets=self.block_design_training_data, search_space_digest=SearchSpaceDigest( feature_names=[], @@ -750,11 +705,9 @@ def test_surrogate_model_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: model = BoTorchModel( - surrogate_specs={ - "name": SurrogateSpec( - botorch_model_kwargs={"some_option": "some_value"} - ) - } + surrogate_spec=SurrogateSpec( + botorch_model_kwargs={"some_option": "some_value"} + ) ) model.fit( datasets=self.non_block_design_training_data, @@ -783,9 +736,7 @@ def test_surrogate_model_options_propagation( def test_surrogate_options_propagation( self, _m1: Mock, _m2: Mock, mock_init: Mock ) -> None: - model = BoTorchModel( - surrogate_specs={"name": SurrogateSpec(allow_batched_models=False)} - ) + model = BoTorchModel(surrogate_spec=SurrogateSpec(allow_batched_models=False)) model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, @@ -819,11 +770,8 @@ def test_model_list_choice(self, _) -> None: # , mock_extract_training_data): search_space_digest=self.mf_search_space_digest, candidate_metadata=self.candidate_metadata, ) - self.assertEqual(model.output_order, [0, 1]) # A model list should be chosen, since Xs are not all the same. - model_list = checked_cast( - ModelList, model.surrogates[Keys.AUTOSET_SURROGATE].model - ) + model_list = checked_cast(ModelList, model.surrogate.model) for submodel in model_list.models: # There are fidelity features and nonempty Yvars, so # MFGP should be chosen. 
@@ -853,9 +801,7 @@ def test_MOO(self, _) -> None: search_space_digest=self.search_space_digest, candidate_metadata=self.candidate_metadata, ) - self.assertIsInstance( - model.surrogates[Keys.AUTOSET_SURROGATE].model, SingleTaskGP - ) + self.assertIsInstance(model.surrogate.model, SingleTaskGP) subset_outcome_constraints = ( # model is subset since last output is not used self.moo_outcome_constraints[0][:, :2], diff --git a/ax/models/torch/tests/test_sebo.py b/ax/models/torch/tests/test_sebo.py index 81d68861af8..94af813276d 100644 --- a/ax/models/torch/tests/test_sebo.py +++ b/ax/models/torch/tests/test_sebo.py @@ -52,7 +52,7 @@ def setUp(self) -> None: super().setUp() tkwargs: dict[str, Any] = {"dtype": torch.double} self.botorch_model_class = SingleTaskGP - self.surrogates = Surrogate(botorch_model_class=self.botorch_model_class) + self.surrogate = Surrogate(botorch_model_class=self.botorch_model_class) self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs) self.target_point = torch.tensor([1.0, 1.0, 1.0], **tkwargs) self.Y = torch.tensor([[3.0], [4.0]], **tkwargs) @@ -67,7 +67,7 @@ def setUp(self) -> None: bounds=[(0.0, 10.0), (0.0, 10.0), (0.0, 10.0)], target_values={2: 1.0}, ) - self.surrogates.fit( + self.surrogate.fit( datasets=self.training_data, search_space_digest=self.search_space_digest, ) @@ -119,7 +119,7 @@ def get_acquisition_function( ) -> SEBOAcquisition: return SEBOAcquisition( botorch_acqf_class=qNoisyExpectedHypervolumeImprovement, - surrogate=self.surrogates, + surrogate=self.surrogate, search_space_digest=self.search_space_digest, torch_opt_config=dataclasses.replace( torch_opt_config or self.torch_opt_config, diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index ef96c10765a..cc16ccc890d 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -56,7 +56,6 @@ from ax.models.winsorization_config import WinsorizationConfig from ax.storage.botorch_modular_registry 
import CLASS_TO_REGISTRY from ax.storage.transform_registry import TRANSFORM_REGISTRY -from ax.utils.common.constants import Keys from ax.utils.common.serialization import serialize_init_args from ax.utils.common.typeutils_torch import torch_type_to_str from botorch.models.transforms.input import ChainedInputTransform, InputTransform @@ -530,14 +529,8 @@ def botorch_model_to_dict(model: BoTorchModel) -> dict[str, Any]: "__type": model.__class__.__name__, "acquisition_class": model.acquisition_class, "acquisition_options": model.acquisition_options or {}, - "surrogate": ( - model._surrogates[Keys.ONLY_SURROGATE] - if Keys.ONLY_SURROGATE in model._surrogates - else None - ), - "surrogate_specs": ( - model.surrogate_specs if len(model.surrogate_specs) > 0 else None - ), + "surrogate": (model._surrogate if model.surrogate_spec is None else None), + "surrogate_spec": model.surrogate_spec, "botorch_acqf_class": model._botorch_acqf_class, "refit_on_cv": model.refit_on_cv, "warm_start_refit": model.warm_start_refit, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 338c3a5295c..b0868eb3d18 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -58,6 +58,7 @@ get_batch_trial, get_botorch_model, get_botorch_model_with_default_acquisition_class, + get_botorch_model_with_surrogate_spec, get_botorch_model_with_surrogate_specs, get_branin_data, get_branin_experiment, @@ -141,6 +142,7 @@ ("BenchmarkResult", get_benchmark_result), ("BoTorchModel", get_botorch_model), ("BoTorchModel", get_botorch_model_with_default_acquisition_class), + ("BoTorchModel", get_botorch_model_with_surrogate_spec), ("BoTorchModel", get_botorch_model_with_surrogate_specs), ("BraninMetric", get_branin_metric), ("ChainedInputTransform", get_chained_input_transform), @@ -799,3 +801,49 @@ def test_SobolQMCNormalSampler(self) -> None: self.assertIsInstance(sampler_loaded, 
SobolQMCNormalSampler) self.assertEqual(sampler.sample_shape, sampler_loaded.sample_shape) self.assertEqual(sampler.seed, sampler_loaded.seed) + + def test_mbm_backwards_compatibility(self) -> None: + # This is json of get_botorch_model_with_surrogate_specs() before D64875988. + object_json = { + "__type": "BoTorchModel", + "acquisition_class": { + "__type": "Type[Acquisition]", + "index": "Acquisition", + "class": ( + "<class 'ax.models.torch.botorch_modular.acquisition.Acquisition'>" + ), + }, + "acquisition_options": {}, + "surrogate": None, + "surrogate_specs": { + "name": { + "__type": "SurrogateSpec", + "botorch_model_class": None, + "botorch_model_kwargs": {"some_option": "some_value"}, + "mll_class": { + "__type": "Type[MarginalLogLikelihood]", + "index": "ExactMarginalLogLikelihood", + "class": ( + "<class 'gpytorch.mlls.marginal_log_likelihood." + "MarginalLogLikelihood'>" + ), + }, + "mll_kwargs": {}, + "covar_module_class": None, + "covar_module_kwargs": None, + "likelihood_class": None, + "likelihood_kwargs": None, + "input_transform_classes": None, + "input_transform_options": None, + "outcome_transform_classes": None, + "outcome_transform_options": None, + "allow_batched_models": True, + "outcomes": [], + } + }, + "botorch_acqf_class": None, + "refit_on_cv": False, + "warm_start_refit": True, + } + expected_object = get_botorch_model_with_surrogate_spec() + self.assertEqual(object_from_json(object_json), expected_object) diff --git a/ax/utils/sensitivity/sobol_measures.py b/ax/utils/sensitivity/sobol_measures.py index 6872b285b77..c64d9f2da1c 100644 --- a/ax/utils/sensitivity/sobol_measures.py +++ b/ax/utils/sensitivity/sobol_measures.py @@ -8,13 +8,10 @@ import itertools from collections.abc import Callable from copy import deepcopy - from typing import Any import numpy as np - import torch - from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch import BotorchModel from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel @@ -941,12 +938,12 @@ def _get_torch_model( """ if not isinstance(model_bridge, TorchModelBridge): raise
NotImplementedError( - f"{type(model_bridge) = }, but only TorchModelBridge is supported." + f"{type(model_bridge)=}, but only TorchModelBridge is supported." ) model = model_bridge.model # should be of type TorchModel if not (isinstance(model, BotorchModel) or isinstance(model, ModularBoTorchModel)): raise NotImplementedError( - f"{type(model_bridge.model) = }, but only " + f"{type(model_bridge.model)=}, but only " "Union[BotorchModel, ModularBoTorchModel] is supported." ) return model @@ -971,19 +968,18 @@ def _get_model_per_metric( ) return [gp_model.models[i] for i in model_idx] else: # isinstance(model, ModularBoTorchModel): + surrogate = model.surrogate + outcomes = surrogate.outcomes model_list = [] for m in metrics: # for each metric, find a corresponding surrogate - for label, outcomes in model.outcomes_by_surrogate_label.items(): - if m in outcomes: - i = outcomes.index(m) - metric_model = model.surrogates[label].model - # since model is a ModularBoTorchModel, metric_model will be a - # `botorch.models.model.Model` object, which have the `num_outputs` - # property and `subset_outputs` method. - if metric_model.num_outputs > 1: # subset to relevant output - metric_model = metric_model.subset_output([i]) - model_list.append(metric_model) - continue # found surrogate for `m`, so we can move on to next `m`. + i = outcomes.index(m) + metric_model = surrogate.model + # since model is a ModularBoTorchModel, metric_model will be a + # `botorch.models.model.Model` object, which has the `num_outputs` + # property and `subset_output` method.
+ if metric_model.num_outputs > 1: # subset to relevant output + metric_model = metric_model.subset_output([i]) + model_list.append(metric_model) return model_list diff --git a/ax/utils/sensitivity/tests/test_sensitivity.py b/ax/utils/sensitivity/tests/test_sensitivity.py index 0364cd7a723..320a1fe659a 100644 --- a/ax/utils/sensitivity/tests/test_sensitivity.py +++ b/ax/utils/sensitivity/tests/test_sensitivity.py @@ -65,9 +65,7 @@ class SensitivityAnalysisTest(TestCase): def setUp(self) -> None: super().setUp() self.model = get_modelbridge().model.model - self.saas_model = ( - get_modelbridge(saasbo=True).model.surrogates["SAASBO_Surrogate"].model - ) + self.saas_model = get_modelbridge(saasbo=True).model.surrogate.model def test_DgsmGpMean(self) -> None: bounds = torch.tensor([(0.0, 1.0) for _ in range(2)]).t() diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 4c57e1d26cb..79cf5c13996 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -2217,6 +2217,12 @@ def get_botorch_model_with_surrogate_specs() -> BoTorchModel: ) +def get_botorch_model_with_surrogate_spec() -> BoTorchModel: + return BoTorchModel( + surrogate_spec=SurrogateSpec(botorch_model_kwargs={"some_option": "some_value"}) + ) + + def get_surrogate() -> Surrogate: return Surrogate( botorch_model_class=get_model_type(),