Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove multi-surrogate support in MBM #2957

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 5 additions & 12 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel
from ax.utils.common.constants import Keys
from ax.utils.common.kwargs import get_function_argument_names
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -74,11 +73,9 @@ def test_botorch_modular(self) -> None:
self.assertEqual(gpei.model.botorch_acqf_class, qExpectedImprovement)
self.assertEqual(gpei.model.acquisition_class, Acquisition)
self.assertEqual(gpei.model.acquisition_options, {"best_f": 0.0})
self.assertIsInstance(gpei.model.surrogates[Keys.AUTOSET_SURROGATE], Surrogate)
self.assertIsInstance(gpei.model.surrogate, Surrogate)
# SingleTaskGP should be picked.
self.assertIsInstance(
gpei.model.surrogates[Keys.AUTOSET_SURROGATE].model, SingleTaskGP
)
self.assertIsInstance(gpei.model.surrogate.model, SingleTaskGP)

gr = gpei.gen(n=1)
self.assertIsNotNone(gr.best_arm_predictions)
Expand All @@ -96,14 +93,10 @@ def test_SAASBO(self) -> None:
self.assertIsInstance(saasbo, TorchModelBridge)
self.assertEqual(saasbo._model_key, "SAASBO")
self.assertIsInstance(saasbo.model, BoTorchModel)
surrogate_specs = saasbo.model.surrogate_specs
surrogate_spec = saasbo.model.surrogate_spec
self.assertEqual(
surrogate_specs,
{
"SAASBO_Surrogate": SurrogateSpec(
botorch_model_class=SaasFullyBayesianSingleTaskGP
)
},
surrogate_spec,
SurrogateSpec(botorch_model_class=SaasFullyBayesianSingleTaskGP),
)
self.assertEqual(
saasbo.model.surrogate.botorch_model_class, SaasFullyBayesianSingleTaskGP
Expand Down
1 change: 1 addition & 0 deletions ax/models/torch/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __init__(
"instead. If you run into a use case that is not supported by MBM, "
"please raise this with an issue at https://github.com/facebook/Ax",
DeprecationWarning,
stacklevel=2,
)
self.model_constructor = model_constructor
self.model_predictor = model_predictor
Expand Down
5 changes: 2 additions & 3 deletions ax/models/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
# 2a. Handle the fully discrete search space.
# 2. Handle fully discrete search spaces.
if optimizer in (
"optimize_acqf_discrete",
"optimize_acqf_discrete_local_search",
Expand Down Expand Up @@ -384,7 +383,7 @@ def optimize(
)
return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
# 3. Handle mixed search spaces that have discrete and continuous features.
# Only sequential optimization is supported for `optimize_acqf_mixed`.
candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
Expand Down
Loading