Update selection of botorch_acqf_class in BoTorchModel._instantiate_acquisition (facebook#2909)

Summary:
Pull Request resolved: facebook#2909

The previous logic relied on imperfect proxies for `is_moo`. This information is readily available on `TorchOptConfig`, so we now use it directly.

Also simplified the function signature and updated the error message for robust optimization.
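
For illustration, a minimal sketch of the resulting selection flow, condensed from the diff below. The function name `select_acqf_class` is hypothetical, the risk-measure guard actually lives in `BoTorchModel._instantiate_acquisition` rather than in `choose_botorch_acqf_class`, and the BoTorch import paths are assumptions about the current module layout:

```python
# Sketch of the post-change selection flow (condensed from the diff below).
# NOTE: in the real code the risk-measure guard lives in
# BoTorchModel._instantiate_acquisition, not in choose_botorch_acqf_class,
# and the BoTorch import paths here are assumptions about the current layout.
from ax.exceptions.core import UnsupportedError
from ax.models.torch_base import TorchOptConfig
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.multi_objective.logei import (
    qLogNoisyExpectedHypervolumeImprovement,
)


def select_acqf_class(torch_opt_config: TorchOptConfig) -> type[AcquisitionFunction]:
    if torch_opt_config.risk_measure is not None:
        # Automated selection is not supported for robust optimization;
        # the user must pass `botorch_acqf_class` explicitly.
        raise UnsupportedError(
            "Automated selection of `botorch_acqf_class` is not supported "
            "for robust optimization with risk measures."
        )
    # MOO vs. SOO is read directly off the config rather than inferred
    # from objective_thresholds / objective_weights.
    if torch_opt_config.is_moo:
        return qLogNoisyExpectedHypervolumeImprovement
    return qLogNoisyExpectedImprovement
```

Usage mirrors the updated test below: a `TorchOptConfig` with `is_moo=True` yields qLogNEHVI, and `is_moo=False` yields qLogNEI.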

Reviewed By: Balandat

Differential Revision: D64545104

fbshipit-source-id: 3c517ccdff750cf5f156e83d1e9b3f2b393d80a7
saitcakmak authored and facebook-github-bot committed Oct 18, 2024
1 parent e146f52 commit caaef96
Showing 4 changed files with 32 additions and 34 deletions.
18 changes: 7 additions & 11 deletions ax/models/torch/botorch_modular/model.py
@@ -19,7 +19,7 @@
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata, TGenMetadata
from ax.exceptions.core import UserInputError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.models.torch.botorch import (
get_feature_importances_from_botorch_model,
get_rounding_func,
@@ -616,17 +616,13 @@ def _instantiate_acquisition(
"""
if not self._botorch_acqf_class:
if torch_opt_config.risk_measure is not None:
# TODO[T131759261]: Implement selection of acqf for robust opt.
# This will depend on the properties of the robust search space and
# the risk measure being used.
raise NotImplementedError
raise UnsupportedError(
"Automated selection of `botorch_acqf_class` is not supported "
"for robust optimization with risk measures. Please specify "
"`botorch_acqf_class` as part of `model_kwargs`."
)
self._botorch_acqf_class = choose_botorch_acqf_class(
pending_observations=torch_opt_config.pending_observations,
outcome_constraints=torch_opt_config.outcome_constraints,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
objective_thresholds=torch_opt_config.objective_thresholds,
objective_weights=torch_opt_config.objective_weights,
torch_opt_config=torch_opt_config
)

return self.acquisition_class(
21 changes: 8 additions & 13 deletions ax/models/torch/botorch_modular/utils.py
@@ -15,6 +15,7 @@
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxError, AxWarning, UnsupportedError
from ax.models.torch_base import TorchOptConfig
from ax.models.types import TConfig
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
@@ -129,20 +130,14 @@ def choose_model_class(


def choose_botorch_acqf_class(
pending_observations: list[Tensor] | None = None,
outcome_constraints: tuple[Tensor, Tensor] | None = None,
linear_constraints: tuple[Tensor, Tensor] | None = None,
fixed_features: dict[int, float] | None = None,
objective_thresholds: Tensor | None = None,
objective_weights: Tensor | None = None,
torch_opt_config: TorchOptConfig,
) -> type[AcquisitionFunction]:
"""Chooses a BoTorch `AcquisitionFunction` class."""
if objective_thresholds is not None or (
# using objective_weights is a less-than-ideal fix given its ambiguity,
# the real fix would be to revisit the information passed down via
# the modelbridge (and be explicit about whether we scalarize or perform MOO)
objective_weights is not None and objective_weights.nonzero().numel() > 1
):
"""Chooses a BoTorch ``AcquisitionFunction`` class.
Current logic relies on the ``TorchOptConfig.is_moo`` field to determine
whether to use qLogNEHVI (for MOO) or qLogNEI (for SOO).
"""
if torch_opt_config.is_moo:
acqf_class = qLogNoisyExpectedHypervolumeImprovement
else:
acqf_class = qLogNoisyExpectedImprovement
5 changes: 3 additions & 2 deletions ax/models/torch/tests/test_model.py
@@ -175,6 +175,7 @@ def setUp(self) -> None:
objective_weights=self.moo_objective_weights,
objective_thresholds=self.moo_objective_thresholds,
outcome_constraints=self.moo_outcome_constraints,
is_moo=True,
)

def test_init(self) -> None:
@@ -874,13 +875,13 @@ def test_MOO(self, _) -> None:
torch_opt_config=self.moo_torch_opt_config,
)
mock_get_outcome_constraint_transforms.assert_called_once()
ckwargs = mock_get_outcome_constraint_transforms.call_args[1]
ckwargs = mock_get_outcome_constraint_transforms.call_args.kwargs
oc = ckwargs["outcome_constraints"]
self.assertTrue(torch.equal(oc[0], subset_outcome_constraints[0]))
self.assertTrue(torch.equal(oc[1], subset_outcome_constraints[1]))

# Check input constructor args
ckwargs = mock_input_constructor.call_args[1]
ckwargs = mock_input_constructor.call_args.kwargs
expected_kwargs = {
"constraints",
"bounds",
22 changes: 14 additions & 8 deletions ax/models/torch/tests/test_utils.py
@@ -26,6 +26,7 @@
use_model_list,
)
from ax.models.torch.utils import _to_inequality_constraints
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast, not_none
@@ -178,18 +179,23 @@ def test_choose_model_class(self) -> None:
)

def test_choose_botorch_acqf_class(self) -> None:
self.assertEqual(qLogNoisyExpectedImprovement, choose_botorch_acqf_class())
self.assertEqual(
qLogNoisyExpectedHypervolumeImprovement,
choose_botorch_acqf_class(objective_thresholds=self.objective_thresholds),
qLogNoisyExpectedImprovement,
choose_botorch_acqf_class(
torch_opt_config=TorchOptConfig(
objective_weights=torch.tensor([1.0, 0.0]),
is_moo=False,
)
),
)
self.assertEqual(
qLogNoisyExpectedHypervolumeImprovement,
choose_botorch_acqf_class(objective_weights=torch.tensor([0.5, 0.5])),
)
self.assertEqual(
qLogNoisyExpectedImprovement,
choose_botorch_acqf_class(objective_weights=torch.tensor([1.0, 0.0])),
choose_botorch_acqf_class(
torch_opt_config=TorchOptConfig(
objective_weights=torch.tensor([1.0, -1.0]),
is_moo=True,
)
),
)

def test_construct_acquisition_and_optimizer_options(self) -> None:
