Skip to content

Commit

Permalink
Move TestGSNoMBMocks uptop (it's the better test class and we shoul…
Browse files Browse the repository at this point in the history
…d seek to add to it) (#2827)

Summary:
Pull Request resolved: #2827

As titled, just moving up a test so we use it more

Reviewed By: saitcakmak, Balandat

Differential Revision: D63914988

fbshipit-source-id: fe12282a71d617dcab730842d6783055ee10ff1b
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Oct 16, 2024
1 parent 0a46348 commit e36bbc4
Showing 1 changed file with 73 additions and 73 deletions.
146 changes: 73 additions & 73 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,79 @@
from ax.utils.testing.mock import fast_botorch_optimize


class TestGenerationStrategyWithoutModelBridgeMocks(TestCase):
"""The test class above heavily mocks the modelbridge. This makes it
difficult to test certain aspects of the GS. This is an alternative
test class that makes use of mocking rather sparingly.
"""

@fast_botorch_optimize
@patch(
"ax.modelbridge.generation_node._extract_model_state_after_gen",
wraps=_extract_model_state_after_gen,
)
def test_with_model_selection(self, mock_model_state: Mock) -> None:
"""Test that a GS with a model selection node functions correctly."""
best_model_selector = MagicMock(autospec=SingleDiagnosticBestModelSelector)
best_model_idx = 0
best_model_selector.best_model.side_effect = lambda model_specs: model_specs[
best_model_idx
]
gs = GenerationStrategy(
name="Sobol+MBM/BO_MIXED",
nodes=[
GenerationNode(
node_name="Sobol",
model_specs=[ModelSpec(model_enum=Models.SOBOL)],
transition_criteria=[
MaxTrials(threshold=2, transition_to="MBM/BO_MIXED")
],
),
GenerationNode(
node_name="MBM/BO_MIXED",
model_specs=[
ModelSpec(model_enum=Models.BOTORCH_MODULAR),
ModelSpec(model_enum=Models.BO_MIXED),
],
best_model_selector=best_model_selector,
),
],
)
exp = get_branin_experiment(with_completed_trial=True)
# Gen with Sobol.
exp.new_trial(gs.gen(experiment=exp))
# Model state is not extracted since there is no past GR.
mock_model_state.assert_not_called()
exp.new_trial(gs.gen(experiment=exp))
# Model state is extracted since there is a past GR.
mock_model_state.assert_called_once()
mock_model_state.reset_mock()
# Gen with MBM/BO_MIXED.
mbm_gr_1 = gs.gen(experiment=exp)
# Model state is not extracted since there is no past GR from this node.
mock_model_state.assert_not_called()
mbm_gr_2 = gs.gen(experiment=exp)
# Model state is extracted only once, since there is a GR from only
# one of these models.
mock_model_state.assert_called_once()
# Verify that it was extracted from the previous GR.
self.assertIs(mock_model_state.call_args.kwargs["generator_run"], mbm_gr_1)
# Change the best model and verify that it generates as well.
best_model_idx = 1
mixed_gr_1 = gs.gen(experiment=exp)
# Only one new call for the MBM model.
self.assertEqual(mock_model_state.call_count, 2)
gs.gen(experiment=exp)
# Two new calls, since we have a GR from the mixed model as well.
self.assertEqual(mock_model_state.call_count, 4)
self.assertIs(
mock_model_state.call_args_list[-2].kwargs["generator_run"], mbm_gr_2
)
self.assertIs(
mock_model_state.call_args_list[-1].kwargs["generator_run"], mixed_gr_1
)


class TestGenerationStrategy(TestCase):
def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -1954,76 +2027,3 @@ def _run_GS_for_N_rounds(
trial.mark_completed()

return could_gen


class TestGenerationStrategyWithoutModelBridgeMocks(TestCase):
"""The test class above heavily mocks the modelbridge. This makes it
difficult to test certain aspects of the GS. This is an alternative
test class that makes use of mocking rather sparingly.
"""

@fast_botorch_optimize
@patch(
"ax.modelbridge.generation_node._extract_model_state_after_gen",
wraps=_extract_model_state_after_gen,
)
def test_with_model_selection(self, mock_model_state: Mock) -> None:
"""Test that a GS with a model selection node functions correctly."""
best_model_selector = MagicMock(autospec=SingleDiagnosticBestModelSelector)
best_model_idx = 0
best_model_selector.best_model.side_effect = lambda model_specs: model_specs[
best_model_idx
]
gs = GenerationStrategy(
name="Sobol+MBM/BO_MIXED",
nodes=[
GenerationNode(
node_name="Sobol",
model_specs=[ModelSpec(model_enum=Models.SOBOL)],
transition_criteria=[
MaxTrials(threshold=2, transition_to="MBM/BO_MIXED")
],
),
GenerationNode(
node_name="MBM/BO_MIXED",
model_specs=[
ModelSpec(model_enum=Models.BOTORCH_MODULAR),
ModelSpec(model_enum=Models.BO_MIXED),
],
best_model_selector=best_model_selector,
),
],
)
exp = get_branin_experiment(with_completed_trial=True)
# Gen with Sobol.
exp.new_trial(gs.gen(experiment=exp))
# Model state is not extracted since there is no past GR.
mock_model_state.assert_not_called()
exp.new_trial(gs.gen(experiment=exp))
# Model state is extracted since there is a past GR.
mock_model_state.assert_called_once()
mock_model_state.reset_mock()
# Gen with MBM/BO_MIXED.
mbm_gr_1 = gs.gen(experiment=exp)
# Model state is not extracted since there is no past GR from this node.
mock_model_state.assert_not_called()
mbm_gr_2 = gs.gen(experiment=exp)
# Model state is extracted only once, since there is a GR from only
# one of these models.
mock_model_state.assert_called_once()
# Verify that it was extracted from the previous GR.
self.assertIs(mock_model_state.call_args.kwargs["generator_run"], mbm_gr_1)
# Change the best model and verify that it generates as well.
best_model_idx = 1
mixed_gr_1 = gs.gen(experiment=exp)
# Only one new call for the MBM model.
self.assertEqual(mock_model_state.call_count, 2)
gs.gen(experiment=exp)
# Two new calls, since we have a GR from the mixed model as well.
self.assertEqual(mock_model_state.call_count, 4)
self.assertIs(
mock_model_state.call_args_list[-2].kwargs["generator_run"], mbm_gr_2
)
self.assertIs(
mock_model_state.call_args_list[-1].kwargs["generator_run"], mixed_gr_1
)

0 comments on commit e36bbc4

Please sign in to comment.