From e227828df92cedf1ab6cbcd4e618eeabdfe615de Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Fri, 18 Oct 2024 11:22:45 -0700 Subject: [PATCH] Remove support(?) for multiple surrogates from Acquisition Summary: The current implementation of `Acquisition.__init__` technically supports multiple surrogates. It combines the models from each surrogate into a `ModelDict` before passing them down to acquisition input constructors. `ModelDict` was added some time ago with the goal of using it to support failure-aware BO, but those methods never got implemented and there is no current use case that supports it. In its current form, the support for multiple surrogates is completely superficial and does not bring much value. It requires complex handling of arguments like `X_pending` and `X_observed`, which itself comes with TODOs to fix or improve it. This diff removes support for multiple surrogates from `Acquisition` and cleans up some of the complicated argument handling that was necessitated by it. If we decide to support multiple surrogates again at a later date, we can do so with a better-thought-out design and implement it more cleanly. 
Differential Revision: D64610244 --- .../torch/botorch_modular/acquisition.py | 176 +++++------------- ax/models/torch/tests/test_acquisition.py | 18 +- 2 files changed, 60 insertions(+), 134 deletions(-) diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 7581ab9b178..9dc4622863b 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -8,7 +8,6 @@ from __future__ import annotations -import functools import operator import warnings from collections.abc import Callable @@ -19,14 +18,11 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UnsupportedError from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate -from ax.models.torch.botorch_modular.utils import ( - _tensor_difference, - get_post_processing_func, -) +from ax.models.torch.botorch_modular.utils import get_post_processing_func from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds from ax.models.torch.utils import ( _get_X_pending_and_observed, @@ -43,7 +39,7 @@ from botorch.acquisition.knowledge_gradient import qKnowledgeGradient from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform from botorch.acquisition.risk_measures import RiskMeasureMCObjective -from botorch.models.model import Model, ModelDict +from botorch.models.model import Model from botorch.optim.optimize import ( optimize_acqf, optimize_acqf_discrete, @@ -51,6 +47,7 @@ optimize_acqf_mixed, ) from botorch.utils.constraints import get_outcome_constraint_transforms +from pyre_extensions import none_throws from torch import Tensor @@ -72,6 +69,7 @@ class 
Acquisition(Base): Args: surrogates: Dict of name => Surrogate model pairs, with which this acquisition function will be used. + NOTE: Only a single surrogate is currently supported! search_space_digest: A SearchSpaceDigest object containing metadata about the search space (e.g. bounds, parameter types). torch_opt_config: A TorchOptConfig object containing optimization @@ -90,89 +88,32 @@ class Acquisition(Base): def __init__( self, - # If using multiple Surrogates, must label primary Surrogate (typically the - # regression Surrogate) Keys.PRIMARY_SURROGATE surrogates: dict[str, Surrogate], search_space_digest: SearchSpaceDigest, torch_opt_config: TorchOptConfig, botorch_acqf_class: type[AcquisitionFunction], options: dict[str, Any] | None = None, ) -> None: - self.surrogates = surrogates - self.options = options or {} - - # Compute pending and observed points for each surrogate - Xs_pending_and_observed = { - name: _get_X_pending_and_observed( - Xs=surrogate.Xs, - objective_weights=torch_opt_config.objective_weights, - bounds=search_space_digest.bounds, - pending_observations=torch_opt_config.pending_observations, - outcome_constraints=torch_opt_config.outcome_constraints, - linear_constraints=torch_opt_config.linear_constraints, - fixed_features=torch_opt_config.fixed_features, + if len(surrogates) > 1: + raise UnsupportedError( + "The `Acquisition` class currently only supports a single surrogate." 
) - for name, surrogate in self.surrogates.items() - } - Xs_pending_list = [ - Xs_pending - for Xs_pending, _ in Xs_pending_and_observed.values() - if Xs_pending is not None - ] - unique_Xs_pending = ( - torch.unique( - input=torch.cat( - tensors=Xs_pending_list, - dim=0, - ), - dim=0, - ) - if len(Xs_pending_list) > 0 - else None - ) + self.surrogates = surrogates + self.options = options or {} + primary_surrogate = next(iter(self.surrogates.values())) - # This tensor may have some Xs that are also in pending (because they are - # observed for some models but not others) - Xs_observed_maybe_pending_list = [ - Xs_observed - for _, Xs_observed in Xs_pending_and_observed.values() - if Xs_observed is not None - ] - unique_Xs_observed_maybe_pending = ( - torch.unique( - input=torch.cat( - tensors=Xs_observed_maybe_pending_list, - dim=0, - ), - dim=0, - ) - if len(Xs_observed_maybe_pending_list) > 0 - else None + # Extract pending and observed points. + X_pending, X_observed = _get_X_pending_and_observed( + Xs=primary_surrogate.Xs, + objective_weights=torch_opt_config.objective_weights, + bounds=search_space_digest.bounds, + pending_observations=torch_opt_config.pending_observations, + outcome_constraints=torch_opt_config.outcome_constraints, + linear_constraints=torch_opt_config.linear_constraints, + fixed_features=torch_opt_config.fixed_features, ) - # If a point is pending on any model do not count it as observed. - # Do this by stacking pending on top of observed, filtering repeats, then - # removing pending points. - # TODO[sdaulton] Is this a sound approach? Should we be doing something more - # sophisticated here? 
- if unique_Xs_pending is None: - unique_Xs_observed = unique_Xs_observed_maybe_pending - elif unique_Xs_observed_maybe_pending is None: - unique_Xs_observed = None - else: - unique_Xs_observed = _tensor_difference( - A=unique_Xs_pending, B=unique_Xs_observed_maybe_pending - ) - - if torch.numel(unique_Xs_observed_maybe_pending) != torch.numel( - unique_Xs_observed - ): - logger.warning( - "Encountered Xs pending for some Surrogates but observed for " - "others. Considering these points to be pending." - ) - # Store objective thresholds for all outcomes (including non-objectives). self._objective_thresholds: Tensor | None = ( torch_opt_config.objective_thresholds @@ -180,22 +121,6 @@ def __init__( self._full_objective_weights: Tensor = torch_opt_config.objective_weights full_outcome_constraints = torch_opt_config.outcome_constraints - # TODO[mpolson64] Handle more elegantly in the future. Since right now we - # only use one objective and posterior_transform this should be fine. - primary_surrogate = ( - self.surrogates[Keys.PRIMARY_SURROGATE] - if len(self.surrogates) > 1 - else next(iter(self.surrogates.values())) - ) - - primary_Xs_pending, primary_Xs_observed = Xs_pending_and_observed[ - ( - Keys.PRIMARY_SURROGATE - if len(self.surrogates) > 1 - else next(iter(Xs_pending_and_observed.keys())) - ) - ] - # Subset model only to the outcomes we need for the optimization. if self.options.pop(Keys.SUBSET_MODEL, True): subset_model_results = subset_model( @@ -215,22 +140,21 @@ def __init__( outcome_constraints = torch_opt_config.outcome_constraints objective_thresholds = torch_opt_config.objective_thresholds subset_idcs = None - # If objective weights suggest multiple objectives but objective - # thresholds are not specified, infer them using the model that - # has already been subset to avoid re-subsetting it within - # `inter_objective_thresholds`. 
+ + # If MOO and some objective thresholds are not specified, infer them using + # the model that has already been subset to avoid re-subsetting it within + # `infer_objective_thresholds`. if ( - objective_weights.nonzero().numel() > 1 + torch_opt_config.is_moo and ( self._objective_thresholds is None or self._objective_thresholds[torch_opt_config.objective_weights != 0] .isnan() .any() ) - and primary_Xs_observed is not None + and X_observed is not None ): if torch_opt_config.risk_measure is not None: - # TODO[T131759263]: modify the heuristic to support risk measures. raise NotImplementedError( "Objective thresholds must be provided when using risk measures." ) @@ -238,7 +162,7 @@ def __init__( model=model, objective_weights=self._full_objective_weights, outcome_constraints=full_outcome_constraints, - X_observed=primary_Xs_observed, + X_observed=X_observed, subset_idcs=subset_idcs, objective_thresholds=self._objective_thresholds, ) @@ -253,36 +177,24 @@ def __init__( objective_weights=objective_weights, objective_thresholds=objective_thresholds, outcome_constraints=outcome_constraints, - X_observed=primary_Xs_observed, + X_observed=X_observed, risk_measure=torch_opt_config.risk_measure, ) - acqf_model_kwarg = ( - { - "model_dict": ModelDict( - **{ - name: surrogate.model - for name, surrogate in self.surrogates.items() - } - ) - } - if len(self.surrogates) > 1 - else {"model": model} - ) target_fidelities = { k: v for k, v in search_space_digest.target_values.items() if k in search_space_digest.fidelity_features } input_constructor_kwargs = { - "X_baseline": unique_Xs_observed, - "X_pending": unique_Xs_pending, + "model": model, + "X_baseline": X_observed, + "X_pending": X_pending, "objective_thresholds": objective_thresholds, "constraints": get_outcome_constraint_transforms( outcome_constraints=outcome_constraints ), "objective": objective, "posterior_transform": posterior_transform, - **acqf_model_kwarg, **self.options, } @@ -290,20 +202,18 @@ def __init__( 
input_constructor_kwargs["target_fidelities"] = target_fidelities input_constructor = get_acqf_input_constructor(botorch_acqf_class) - # Handle multi-dataset surrogates - TODO: Improve this - # If there is only one SupervisedDataset return it alone - if ( - len(self.surrogates) == 1 - and len(next(iter(self.surrogates.values())).training_data) == 1 - ): - training_data = next(iter(self.surrogates.values())).training_data[0] + + # Extract the training data from the surrogate. + # If there is a single dataset, this will be the dataset itself. + # If there are multiple datasets, this will be a dict mapping the outcome names + # to the corresponding datasets. + training_data = primary_surrogate.training_data + if len(training_data) == 1: + training_data = training_data[0] else: - tdicts = ( - dict(zip(not_none(surrogate._outcomes), surrogate.training_data)) - for surrogate in self.surrogates.values() + training_data = dict( + zip(none_throws(primary_surrogate._outcomes), training_data) ) - # outcome_name => Dataset - training_data = functools.reduce(lambda x, y: {**x, **y}, tdicts) acqf_inputs = input_constructor( training_data=training_data, @@ -311,8 +221,8 @@ def __init__( **{k: v for k, v in input_constructor_kwargs.items() if v is not None}, ) self.acqf = botorch_acqf_class(**acqf_inputs) # pyre-ignore [45] - self.X_pending: Tensor | None = unique_Xs_pending - self.X_observed: Tensor | None = unique_Xs_observed + self.X_pending: Tensor | None = X_pending + self.X_observed: Tensor | None = X_observed @property def botorch_acqf_class(self) -> type[AcquisitionFunction]: diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py index b9cdd29d31f..cc9338764ee 100644 --- a/ax/models/torch/tests/test_acquisition.py +++ b/ax/models/torch/tests/test_acquisition.py @@ -18,7 +18,7 @@ import numpy as np import torch from ax.core.search_space import SearchSpaceDigest -from ax.exceptions.core import AxWarning, SearchSpaceExhausted +from 
ax.exceptions.core import AxWarning, SearchSpaceExhausted, UnsupportedError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse from ax.models.torch.botorch_modular.surrogate import Surrogate @@ -776,6 +776,7 @@ def test_init_moo( objective_weights=moo_objective_weights, outcome_constraints=outcome_constraints, objective_thresholds=moo_objective_thresholds, + is_moo=True, ) acquisition = Acquisition( surrogates={"surrogate": self.surrogate}, @@ -860,3 +861,18 @@ def test_init_moo( def test_init_no_X_observed(self) -> None: self.test_init_moo(with_no_X_observed=True, with_outcome_constraints=False) + + def test_init_multiple_surrogates(self) -> None: + with self.assertRaisesRegex( + UnsupportedError, "currently only supports a single surrogate" + ): + Acquisition( + surrogates={ + "surrogate_1": self.surrogate, + "surrogate_2": self.surrogate, + }, + search_space_digest=self.search_space_digest, + torch_opt_config=self.torch_opt_config, + botorch_acqf_class=self.botorch_acqf_class, + options=self.options, + )