Skip to content

Commit

Permalink
Remove multi-surrogate support from Acquisition classes (#2949)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2949

Multiple surrogates is not supported by any Acquisition classes and is being deprecated. This diff replaces the `surrogates: dict[str, Surrogate]` input to `Acquisition` classes with a simple `surrogate: Surrogate` input and cleans up the relevant code.

A following diff will remove multiple-surrogate support from the `BoTorchModel` class.

Reviewed By: Balandat

Differential Revision: D64875386

fbshipit-source-id: 19f26f3c69f1e3c91d1204b84b4a645c70f557d1
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 24, 2024
1 parent 8dcb808 commit b1f8a75
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 97 deletions.
54 changes: 12 additions & 42 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError
from ax.exceptions.core import SearchSpaceExhausted
from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand Down Expand Up @@ -65,9 +65,8 @@ class Acquisition(Base):
of `BoTorchModel` and is not meant to be used outside of it.
Args:
surrogates: Dict of name => Surrogate model pairs, with which this acquisition
surrogate: The Surrogate model, with which this acquisition
function will be used.
NOTE: Only a single surrogate is currently supported!
search_space_digest: A SearchSpaceDigest object containing metadata
about the search space (e.g. bounds, parameter types).
torch_opt_config: A TorchOptConfig object containing optimization
Expand All @@ -78,30 +77,24 @@ class Acquisition(Base):
Function` in BoTorch.
"""

surrogates: dict[str, Surrogate]
surrogate: Surrogate
acqf: AcquisitionFunction
options: dict[str, Any]

def __init__(
self,
surrogates: dict[str, Surrogate],
surrogate: Surrogate,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
botorch_acqf_class: type[AcquisitionFunction],
options: dict[str, Any] | None = None,
) -> None:
if len(surrogates) > 1:
raise UnsupportedError(
"The `Acquisition` class currently only supports a single surrogate."
)

self.surrogates = surrogates
self.surrogate = surrogate
self.options = options or {}
primary_surrogate = next(iter(self.surrogates.values()))

# Extract pending and observed points.
X_pending, X_observed = _get_X_pending_and_observed(
Xs=primary_surrogate.Xs,
Xs=surrogate.Xs,
objective_weights=torch_opt_config.objective_weights,
bounds=search_space_digest.bounds,
pending_observations=torch_opt_config.pending_observations,
Expand All @@ -120,7 +113,7 @@ def __init__(
# Subset model only to the outcomes we need for the optimization.
if self.options.pop(Keys.SUBSET_MODEL, True):
subset_model_results = subset_model(
model=primary_surrogate.model,
model=surrogate.model,
objective_weights=torch_opt_config.objective_weights,
outcome_constraints=torch_opt_config.outcome_constraints,
objective_thresholds=torch_opt_config.objective_thresholds,
Expand All @@ -131,7 +124,7 @@ def __init__(
objective_thresholds = subset_model_results.objective_thresholds
subset_idcs = subset_model_results.indices
else:
model = primary_surrogate.model
model = surrogate.model
objective_weights = torch_opt_config.objective_weights
outcome_constraints = torch_opt_config.outcome_constraints
objective_thresholds = torch_opt_config.objective_thresholds
Expand Down Expand Up @@ -203,13 +196,11 @@ def __init__(
# If there is a single dataset, this will be the dataset itself.
# If there are multiple datasets, this will be a dict mapping the outcome names
# to the corresponding datasets.
training_data = primary_surrogate.training_data
training_data = surrogate.training_data
if len(training_data) == 1:
training_data = training_data[0]
else:
training_data = dict(
zip(none_throws(primary_surrogate._outcomes), training_data)
)
training_data = dict(zip(none_throws(surrogate._outcomes), training_data))

acqf_inputs = input_constructor(
training_data=training_data,
Expand All @@ -230,35 +221,14 @@ def dtype(self) -> torch.dtype | None:
"""Torch data type of the tensors in the training data used in the model,
of which this ``Acquisition`` is a subcomponent.
"""
dtypes = {
label: surrogate.dtype for label, surrogate in self.surrogates.items()
}

dtypes_list = list(dtypes.values())
if dtypes_list.count(dtypes_list[0]) != len(dtypes_list):
raise ValueError(
f"Expected all Surrogates to have same dtype, found {dtypes}"
)

return dtypes_list[0]
return self.surrogate.dtype

@property
def device(self) -> torch.device | None:
"""Torch device type of the tensors in the training data used in the model,
of which this ``Acquisition`` is a subcomponent.
"""

devices = {
label: surrogate.device for label, surrogate in self.surrogates.items()
}

devices_list = list(devices.values())
if devices_list.count(devices_list[0]) != len(devices_list):
raise ValueError(
f"Expected all Surrogates to have same device, found {devices}"
)

return devices_list[0]
return self.surrogate.device

@property
def objective_thresholds(self) -> Tensor | None:
Expand Down
2 changes: 1 addition & 1 deletion ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def _instantiate_acquisition(
)

return self.acquisition_class(
surrogates=self.surrogates,
surrogate=self.surrogate,
botorch_acqf_class=self.botorch_acqf_class,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
Expand Down
9 changes: 2 additions & 7 deletions ax/models/torch/botorch_modular/sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from botorch.acquisition.acquisition import AcquisitionFunction
Expand Down Expand Up @@ -58,16 +57,12 @@ class SEBOAcquisition(Acquisition):

def __init__(
self,
surrogates: dict[str, Surrogate],
surrogate: Surrogate,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
botorch_acqf_class: type[AcquisitionFunction],
options: dict[str, Any] | None = None,
) -> None:
if len(surrogates) > 1:
raise ValueError("SEBO does not support support multiple surrogates.")
surrogate = surrogates[Keys.ONLY_SURROGATE]

tkwargs: dict[str, Any] = {"dtype": surrogate.dtype, "device": surrogate.device}
options = {} if options is None else options
self.penalty_name: str = options.pop("penalty", "L0_norm")
Expand Down Expand Up @@ -123,7 +118,7 @@ def __init__(
if self.penalty_name == "L0_norm":
self.deterministic_model._f.a.fill_(1e-6)
super().__init__(
surrogates={"sebo": surrogate_f},
surrogate=surrogate_f,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config_sebo,
botorch_acqf_class=qLogNoisyExpectedHypervolumeImprovement,
Expand Down
4 changes: 2 additions & 2 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,8 @@ def best_out_of_sample_point(
# Avoiding circular import between `Surrogate` and `Acquisition`.
from ax.models.torch.botorch_modular.acquisition import Acquisition

acqf = Acquisition( # TODO: For multi-fidelity, might need diff. class.
surrogates={"self": self},
acqf = Acquisition(
surrogate=self,
botorch_acqf_class=acqf_class,
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config,
Expand Down
39 changes: 12 additions & 27 deletions ax/models/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError
from ax.exceptions.core import SearchSpaceExhausted
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
Expand Down Expand Up @@ -178,7 +178,7 @@ def get_acquisition_function(
botorch_acqf_class=(
DummyOneShotAcquisitionFunction if one_shot else self.botorch_acqf_class
),
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
torch_opt_config=dataclasses.replace(
self.torch_opt_config, fixed_features=fixed_features or {}
Expand All @@ -194,7 +194,7 @@ def test_init_raises_when_missing_acqf_cls(self) -> None:
with self.assertRaisesRegex(TypeError, ".* missing .* 'botorch_acqf_class'"):
# pyre-ignore[20]: Argument `botorch_acqf_class` expected.
Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
)
Expand All @@ -210,7 +210,7 @@ def test_init(
mock_get_X: Mock,
) -> None:
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
botorch_acqf_class=self.botorch_acqf_class,
Expand All @@ -234,7 +234,7 @@ def test_init(

# Call `subset_model` only when needed
mock_subset_model.assert_called_with(
model=acquisition.surrogates["surrogate"].model,
model=acquisition.surrogate.model,
objective_weights=self.objective_weights,
outcome_constraints=self.outcome_constraints,
objective_thresholds=self.objective_thresholds,
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_init_with_subset_model_false(
return_value=self.constraints,
) as mock_get_outcome_constraint_transforms:
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
botorch_acqf_class=self.botorch_acqf_class,
Expand All @@ -275,14 +275,14 @@ def test_init_with_subset_model_false(
# Check `get_botorch_objective_and_transform` kwargs
mock_get_objective_and_transform.assert_called_once()
_, ckwargs = mock_get_objective_and_transform.call_args
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
self.assertIs(ckwargs["model"], acquisition.surrogate.model)
self.assertIs(ckwargs["objective_weights"], self.objective_weights)
self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints)
self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1]))
# Check final `acqf` creation
self.mock_input_constructor.assert_called_once()
_, ckwargs = self.mock_input_constructor.call_args
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
self.assertIs(ckwargs["model"], acquisition.surrogate.model)
self.assertIs(ckwargs["objective"], botorch_objective)
self.assertTrue(torch.equal(ckwargs["X_pending"], self.pending_observations[0]))
for k, v in self.options.items():
Expand Down Expand Up @@ -422,7 +422,7 @@ def test_optimize_discrete(self) -> None:
inequality_constraints=None,
)

expected_choices = torch.tensor([elt for elt in all_possible_choices])
expected_choices = torch.tensor(all_possible_choices)
expected_avoid = torch.cat([self.X, self.pending_observations[0]], dim=-2)

kwargs = mock_optimize_acqf_discrete.call_args.kwargs
Expand Down Expand Up @@ -702,7 +702,7 @@ def test_init_moo(
is_moo=True,
)
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
botorch_acqf_class=acqf_class,
search_space_digest=self.search_space_digest,
torch_opt_config=torch_opt_config,
Expand Down Expand Up @@ -736,7 +736,7 @@ def test_init_moo(
)
)
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
botorch_acqf_class=acqf_class,
torch_opt_config=dataclasses.replace(
Expand All @@ -757,7 +757,7 @@ def test_init_moo(
self.assertTrue(np.isnan(acquisition.objective_thresholds[2].item()))
# With partial thresholds.
acquisition = Acquisition(
surrogates={"surrogate": self.surrogate},
surrogate=self.surrogate,
search_space_digest=self.search_space_digest,
botorch_acqf_class=acqf_class,
torch_opt_config=dataclasses.replace(
Expand All @@ -784,18 +784,3 @@ def test_init_moo(

def test_init_no_X_observed(self) -> None:
self.test_init_moo(with_no_X_observed=True, with_outcome_constraints=False)

def test_init_multiple_surrogates(self) -> None:
with self.assertRaisesRegex(
UnsupportedError, "currently only supports a single surrogate"
):
Acquisition(
surrogates={
"surrogate_1": self.surrogate,
"surrogate_2": self.surrogate,
},
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
botorch_acqf_class=self.botorch_acqf_class,
options=self.options,
)
21 changes: 3 additions & 18 deletions ax/models/torch/tests/test_sebo.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def get_acquisition_function(
) -> SEBOAcquisition:
return SEBOAcquisition(
botorch_acqf_class=qNoisyExpectedHypervolumeImprovement,
surrogates={Keys.ONLY_SURROGATE: self.surrogates},
surrogate=self.surrogates,
search_space_digest=self.search_space_digest,
torch_opt_config=dataclasses.replace(
torch_opt_config or self.torch_opt_config,
Expand All @@ -133,7 +133,7 @@ def test_init(self) -> None:
options={"target_point": self.target_point},
)
# Check that determinstic metric is added to surrogate
surrogate = acquisition1.surrogates["sebo"]
surrogate = acquisition1.surrogate
model_list = not_none(surrogate._model)
self.assertIsInstance(model_list, ModelList)
self.assertIsInstance(model_list.models[0], SingleTaskGP)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_init(self) -> None:
options={"penalty": "L1_norm", "target_point": self.target_point},
)
self.assertEqual(acquisition2.penalty_name, "L1_norm")
surrogate = acquisition2.surrogates["sebo"]
surrogate = acquisition2.surrogate
model_list = not_none(surrogate._model)
self.assertIsInstance(model_list.models[1]._f, functools.partial)
self.assertIs(model_list.models[1]._f.func, L1_norm_func)
Expand All @@ -181,21 +181,6 @@ def test_init(self) -> None:
options={"penalty": "L2_norm", "target_point": self.target_point},
)

# assert error raise if multiple surrogates are given
with self.assertRaisesRegex(
ValueError, "SEBO does not support support multiple surrogates."
):
SEBOAcquisition(
botorch_acqf_class=qNoisyExpectedHypervolumeImprovement,
surrogates={
Keys.ONLY_SURROGATE: self.surrogates,
"sebo2": self.surrogates,
},
search_space_digest=self.search_space_digest,
torch_opt_config=self.torch_opt_config,
options=self.options,
)

# assert error raise if target point is not given
with self.assertRaisesRegex(ValueError, "please provide target point."):
self.get_acquisition_function(options={"penalty": "L1_norm"})
Expand Down

0 comments on commit b1f8a75

Please sign in to comment.