Skip to content

Commit

Permalink
Remove support(?) for multiple surrogates from Acquisition (#2916)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2916

The current implementation of `Acqusition.__init__` technically supports multiple surrogates. It combines the models from each surrogate into a `ModelDict` before passing them down to acquisition input constructors. `ModelDict` was added some time ago with the goal of using it to support failure aware BO, but those methods never got implemented and there is no current use case that supports it.

In its current form, the support for multiple surrugates is completely superficial and does not bring much value. It requires complex handling of arguments like `X_pending` and `X_observed`, which itself comes with TODOs to fix or improve it.

This diff removes support for multiple surrogates from `Acquisition` and cleans up some of the complicated argument handling that was necessitated by it. If we decide to support multiple surrogates again at a later date, we can do so with a better thought out design and implement it more cleanly.

Reviewed By: lena-kashtelyan

Differential Revision: D64610244

fbshipit-source-id: 70479e4ca7a5c7fc99ea36ed7c95261d1a2fd4a9
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 18, 2024
1 parent 6bdf0a5 commit 42be53a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 135 deletions.
176 changes: 43 additions & 133 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from __future__ import annotations

import functools
import operator
import warnings
from collections.abc import Callable
Expand All @@ -19,14 +18,11 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UnsupportedError
from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.utils import (
_tensor_difference,
get_post_processing_func,
)
from ax.models.torch.botorch_modular.utils import get_post_processing_func
from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds
from ax.models.torch.utils import (
_get_X_pending_and_observed,
Expand All @@ -43,14 +39,15 @@
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from botorch.models.model import Model, ModelDict
from botorch.models.model import Model
from botorch.optim.optimize import (
optimize_acqf,
optimize_acqf_discrete,
optimize_acqf_discrete_local_search,
optimize_acqf_mixed,
)
from botorch.utils.constraints import get_outcome_constraint_transforms
from pyre_extensions import none_throws
from torch import Tensor


Expand All @@ -72,6 +69,7 @@ class Acquisition(Base):
Args:
surrogates: Dict of name => Surrogate model pairs, with which this acquisition
function will be used.
NOTE: Only a single surrogate is currently supported!
search_space_digest: A SearchSpaceDigest object containing metadata
about the search space (e.g. bounds, parameter types).
torch_opt_config: A TorchOptConfig object containing optimization
Expand All @@ -90,112 +88,39 @@ class Acquisition(Base):

def __init__(
self,
# If using multiple Surrogates, must label primary Surrogate (typically the
# regression Surrogate) Keys.PRIMARY_SURROGATE
surrogates: dict[str, Surrogate],
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
botorch_acqf_class: type[AcquisitionFunction],
options: dict[str, Any] | None = None,
) -> None:
self.surrogates = surrogates
self.options = options or {}

# Compute pending and observed points for each surrogate
Xs_pending_and_observed = {
name: _get_X_pending_and_observed(
Xs=surrogate.Xs,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
pending_observations=torch_opt_config.pending_observations,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
if len(surrogates) > 1:
raise UnsupportedError(
"The `Acquisition` class currently only supports a single surrogate."
)
for name, surrogate in self.surrogates.items()
}

Xs_pending_list = [
Xs_pending
for Xs_pending, _ in Xs_pending_and_observed.values()
if Xs_pending is not None
]
unique_Xs_pending = (
torch.unique(
input=torch.cat(
tensors=Xs_pending_list,
dim=0,
),
dim=0,
)
if len(Xs_pending_list) > 0
else None
)
self.surrogates = surrogates
self.options = options or {}
primary_surrogate = next(iter(self.surrogates.values()))

# This tensor may have some Xs that are also in pending (because they are
# observed for some models but not others)
Xs_observed_maybe_pending_list = [
Xs_observed
for _, Xs_observed in Xs_pending_and_observed.values()
if Xs_observed is not None
]
unique_Xs_observed_maybe_pending = (
torch.unique(
input=torch.cat(
tensors=Xs_observed_maybe_pending_list,
dim=0,
),
dim=0,
)
if len(Xs_observed_maybe_pending_list) > 0
else None
# Extract pending and observed points.
X_pending, X_observed = _get_X_pending_and_observed(
Xs=primary_surrogate.Xs,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
pending_observations=torch_opt_config.pending_observations,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
)

# If a point is pending on any model do not count it as observed.
# Do this by stacking pending on top of observed, filtering repeats, then
# removing pending points.
# TODO[sdaulton] Is this a sound approach? Should we be doing something more
# sophisticated here?
if unique_Xs_pending is None:
unique_Xs_observed = unique_Xs_observed_maybe_pending
elif unique_Xs_observed_maybe_pending is None:
unique_Xs_observed = None
else:
unique_Xs_observed = _tensor_difference(
A=unique_Xs_pending, B=unique_Xs_observed_maybe_pending
)

if torch.numel(unique_Xs_observed_maybe_pending) != torch.numel(
unique_Xs_observed
):
logger.warning(
"Encountered Xs pending for some Surrogates but observed for "
"others. Considering these points to be pending."
)

# Store objective thresholds for all outcomes (including non-objectives).
self._objective_thresholds: Tensor | None = (
torch_opt_config.objective_thresholds
)
self._full_objective_weights: Tensor = torch_opt_config.objective_weights
full_outcome_constraints = torch_opt_config.outcome_constraints

# TODO[mpolson64] Handle more elegantly in the future. Since right now we
# only use one objective and posterior_transform this should be fine.
primary_surrogate = (
self.surrogates[Keys.PRIMARY_SURROGATE]
if len(self.surrogates) > 1
else next(iter(self.surrogates.values()))
)

primary_Xs_pending, primary_Xs_observed = Xs_pending_and_observed[
(
Keys.PRIMARY_SURROGATE
if len(self.surrogates) > 1
else next(iter(Xs_pending_and_observed.keys()))
)
]

# Subset model only to the outcomes we need for the optimization.
if self.options.pop(Keys.SUBSET_MODEL, True):
subset_model_results = subset_model(
Expand All @@ -215,30 +140,29 @@ def __init__(
outcome_constraints = torch_opt_config.outcome_constraints
objective_thresholds = torch_opt_config.objective_thresholds
subset_idcs = None
# If objective weights suggest multiple objectives but objective
# thresholds are not specified, infer them using the model that
# has already been subset to avoid re-subsetting it within
# `inter_objective_thresholds`.

# If MOO and some objective thresholds are not specified, infer them using
# the model that has already been subset to avoid re-subsetting it within
# `infer_objective_thresholds`.
if (
objective_weights.nonzero().numel() > 1
torch_opt_config.is_moo
and (
self._objective_thresholds is None
or self._objective_thresholds[torch_opt_config.objective_weights != 0]
.isnan()
.any()
)
and primary_Xs_observed is not None
and X_observed is not None
):
if torch_opt_config.risk_measure is not None:
# TODO[T131759263]: modify the heuristic to support risk measures.
raise NotImplementedError(
"Objective thresholds must be provided when using risk measures."
)
self._objective_thresholds = infer_objective_thresholds(
model=model,
objective_weights=self._full_objective_weights,
outcome_constraints=full_outcome_constraints,
X_observed=primary_Xs_observed,
X_observed=X_observed,
subset_idcs=subset_idcs,
objective_thresholds=self._objective_thresholds,
)
Expand All @@ -253,66 +177,52 @@ def __init__(
objective_weights=objective_weights,
objective_thresholds=objective_thresholds,
outcome_constraints=outcome_constraints,
X_observed=primary_Xs_observed,
X_observed=X_observed,
risk_measure=torch_opt_config.risk_measure,
)
acqf_model_kwarg = (
{
"model_dict": ModelDict(
**{
name: surrogate.model
for name, surrogate in self.surrogates.items()
}
)
}
if len(self.surrogates) > 1
else {"model": model}
)
target_fidelities = {
k: v
for k, v in search_space_digest.target_values.items()
if k in search_space_digest.fidelity_features
}
input_constructor_kwargs = {
"X_baseline": unique_Xs_observed,
"X_pending": unique_Xs_pending,
"model": model,
"X_baseline": X_observed,
"X_pending": X_pending,
"objective_thresholds": objective_thresholds,
"constraints": get_outcome_constraint_transforms(
outcome_constraints=outcome_constraints
),
"objective": objective,
"posterior_transform": posterior_transform,
**acqf_model_kwarg,
**self.options,
}

if len(target_fidelities) > 0:
input_constructor_kwargs["target_fidelities"] = target_fidelities

input_constructor = get_acqf_input_constructor(botorch_acqf_class)
# Handle multi-dataset surrogates - TODO: Improve this
# If there is only one SupervisedDataset return it alone
if (
len(self.surrogates) == 1
and len(next(iter(self.surrogates.values())).training_data) == 1
):
training_data = next(iter(self.surrogates.values())).training_data[0]

# Extract the training data from the surrogate.
# If there is a single dataset, this will be the dataset itself.
# If there are multiple datasets, this will be a dict mapping the outcome names
# to the corresponding datasets.
training_data = primary_surrogate.training_data
if len(training_data) == 1:
training_data = training_data[0]
else:
tdicts = (
dict(zip(not_none(surrogate._outcomes), surrogate.training_data))
for surrogate in self.surrogates.values()
training_data = dict(
zip(none_throws(primary_surrogate._outcomes), training_data)
)
# outcome_name => Dataset
training_data = functools.reduce(lambda x, y: {**x, **y}, tdicts)

acqf_inputs = input_constructor(
training_data=training_data,
bounds=search_space_digest.bounds,
**{k: v for k, v in input_constructor_kwargs.items() if v is not None},
)
self.acqf = botorch_acqf_class(**acqf_inputs) # pyre-ignore [45]
self.X_pending: Tensor | None = unique_Xs_pending
self.X_observed: Tensor | None = unique_Xs_observed
self.X_pending: Tensor | None = X_pending
self.X_observed: Tensor | None = X_observed

@property
def botorch_acqf_class(self) -> type[AcquisitionFunction]:
Expand Down
2 changes: 1 addition & 1 deletion ax/models/torch/botorch_modular/sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _transform_torch_config(
model_gen_options=torch_opt_config.model_gen_options,
rounding_func=torch_opt_config.rounding_func,
opt_config_metrics=torch_opt_config.opt_config_metrics,
is_moo=torch_opt_config.is_moo,
is_moo=True, # SEBO adds an objective, so it'll always be MOO.
)

def optimize(
Expand Down
18 changes: 17 additions & 1 deletion ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, SearchSpaceExhausted
from ax.exceptions.core import AxWarning, SearchSpaceExhausted, UnsupportedError
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand Down Expand Up @@ -776,6 +776,7 @@ def test_init_moo(
objective_weights=moo_objective_weights,
outcome_constraints=outcome_constraints,
objective_thresholds=moo_objective_thresholds,
is_moo=True,
)
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
Expand Down Expand Up @@ -860,3 +861,18 @@ def test_init_moo(

def test_init_no_X_observed(self) -> None:
self.test_init_moo(with_no_X_observed=True, with_outcome_constraints=False)

def test_init_multiple_surrogates(self) -> None:
with self.assertRaisesRegex(
UnsupportedError, "currently only supports a single surrogate"
):
Acquisition(
surrogates={
"surrogate_1": self.surrogate,
"surrogate_2": self.surrogate,
},
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
botorch_acqf_class=self.botorch_acqf_class,
options=self.options,
)

0 comments on commit 42be53a

Please sign in to comment.