From 406e3e495d0694ea617ac02e29835980b25c1b7b Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 16 Oct 2024 21:46:55 -0700 Subject: [PATCH 1/3] Add should_skip property to GenerationNode (#2891) Summary: There are some cases where the input constructor returns n=o, specifically for repeat arms, this is currently causing an issue in generaiton strategy gen method because if we don't generate from a node we can't meet the TC, and move forward, however, we don't actually want to generate from that node. This simply exposes a prop on gen node to expose this, and default it to false Following diffs will: - update TC.is_met() to accept a GenNode instead of just the name - update input constructor to properly set should skip - update the gs to appropiately handle resetting to false after a gen is completed Reviewed By: lena-kashtelyan Differential Revision: D64444258 --- ax/modelbridge/generation_node.py | 3 +++ ax/modelbridge/tests/test_generation_strategy.py | 11 ++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 3310822f3c6..a557fe76843 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -96,6 +96,8 @@ class GenerationNode(SerializationMixin, SortableBase): set during transition from one ``GenerationNode`` to the next. Can be overwritten if multiple transitions occur between nodes, and will always store the most recent previous ``GenerationNode`` name. + should_skip: Whether to skip this node during generation time. Defaults to + False, and can only currently be set to True via ``NodeInputConstructors`` Note for developers: by "model" here we really mean an Ax ModelBridge object, which contains an Ax Model under the hood. We call it "model" here to simplify and focus @@ -118,6 +120,7 @@ class GenerationNode(SerializationMixin, SortableBase): ] _previous_node_name: str | None = None _trial_type: str | None = None + _should_skip: bool = False # [TODO] Handle experiment passing more eloquently by enforcing experiment # attribute is set in generation strategies class diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index eab2187b2ff..073f31a8799 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -1380,7 +1380,7 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: # check first call is 6 (from the previous trial having 6 arms) self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 6) - def test_gs_initializes_all_previous_node_to_none(self) -> None: + def test_gs_initializes_default_props_correctly(self) -> None: """Test that all previous nodes are initialized to None""" node_1 = GenerationNode( node_name="node_1", @@ -1401,13 +1401,18 @@ def test_gs_initializes_all_previous_node_to_none(self) -> None: node_3, ], ) - with self.subTest("after initialization all should be none"): + with self.subTest("after initialization all previous nodes should be none"): for node in gs._nodes: self.assertIsNone(node._previous_node_name) self.assertIsNone(node.previous_node) - with self.subTest("check previous node nodes after being set"): + with self.subTest("check previous node after it is set"): gs._nodes[1]._previous_node_name = "node_1" self.assertEqual(gs._nodes[1].previous_node, node_1) + with self.subTest( + "after initialization all nodes should have should_skip set to False" + ): + for node in gs._nodes: + self.assertFalse(node._should_skip) def test_gs_with_generation_nodes(self) -> None: "Simple test of a SOBOL + MBM GenerationStrategy composed of GenerationNodes" From 7c4e0235cddc809550b626d4a896ef15110ebe4d Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 16 Oct 2024 21:46:55 -0700 Subject: [PATCH 2/3] Add current node to TC.is_met() method to access node properties (#2893) Summary: **Context:** There are some cases where the input constructor returns n=o, specifically for repeat arms, this is currently causing an issue in generation strategy gen method because if we don't generate from a node we can't meet the TC, and move forward, however, we don't actually want to generate from that node. **This diff** updates tc.is_met() to accept a full generation node which allows us to access the should_skip property we added in the previous diff Following diffs will: - update input constructor to properly set should skip - update the gs to appropiately handle resetting to false after a gen is completed Reviewed By: lena-kashtelyan Differential Revision: D64445871 --- .../tests/test_transition_criterion.py | 30 +++++++++++++++++-- ax/modelbridge/transition_criterion.py | 17 +++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index 55d4cb3c5cd..7c0a00356ed 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -344,7 +344,33 @@ def test_auto_transition(self) -> None: gs.gen(experiment=experiment) self.assertEqual(gs.current_node_name, "sobol_2") - def test_is_single_obejective_does_not_transition(self) -> None: + def test_auto_with_should_skip_node(self) -> None: + experiment = self.branin_experiment + gs = GenerationStrategy( + name="test", + nodes=[ + GenerationNode( + node_name="sobol_1", + model_specs=[self.sobol_model_spec], + transition_criteria=[ + AutoTransitionAfterGen(transition_to="sobol_2") + ], + ), + GenerationNode( + node_name="sobol_2", model_specs=[self.sobol_model_spec] + ), + ], + ) + gs._nodes[0]._should_skip = True + self.assertTrue( + gs._nodes[0] + .transition_criteria[0] + .is_met( + experiment=experiment, curr_node_name="sobol_1", curr_node=gs._nodes[0] + ) + ) + + def test_is_single_objective_does_not_transition(self) -> None: exp = self.branin_experiment exp.optimization_config = get_branin_multi_objective_optimization_config() gs = GenerationStrategy( @@ -368,7 +394,7 @@ def test_is_single_obejective_does_not_transition(self) -> None: self.assertEqual(gr2._generation_node_name, "sobol_1") self.assertEqual(gs.current_node_name, "sobol_1") - def test_is_single_obejective_transitions(self) -> None: + def test_is_single_objective_transitions(self) -> None: exp = self.branin_experiment gs = GenerationStrategy( name="test", diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index e0c3c2b8388..26afe4e6d5a 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -11,6 +11,8 @@ from logging import Logger from typing import Collection +from ax import modelbridge + from ax.core import MultiObjectiveOptimizationConfig from ax.core.auxiliary import AuxiliaryExperimentPurpose @@ -82,7 +84,9 @@ def is_met( experiment: Experiment, trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, + # todo @mgarrard remove this once we no longer have steps curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: """If the criterion of this TransitionCriterion is met, returns True.""" pass @@ -149,10 +153,18 @@ def is_met( trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: """Return True as soon as any GeneratorRun is generated by this GenerationNode. """ + # Handle edge case where the InputConstructor for a GenerationNode + # with this criterion requests no arms to be generated, therefore, indicating + # that this GenerationNode should be skipped and so we can transition to the + # next node as defined by this criterion. + if curr_node is not None and curr_node._should_skip: + return True + return node_that_generated_last_gr == curr_node_name def block_continued_generation_error( @@ -203,6 +215,7 @@ def is_met( trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: """Return True if the optimization config is not of type ``MultiObjectiveOptimizationConfig``.""" @@ -360,6 +373,7 @@ def is_met( block_continued_generation: bool | None = False, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: """Returns if this criterion has been met given its constraints. Args: @@ -662,6 +676,7 @@ def is_met( trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: # TODO: @mgarrard replace fetch_data with lookup_data data = experiment.fetch_data(metrics=[experiment.metrics[self.metric_name]]) @@ -781,6 +796,7 @@ def is_met( trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: """Check if the experiment has auxiliary experiments for certain purpose.""" aux_exp_purposes = set(experiment.auxiliary_experiments_by_purpose.keys()) @@ -835,6 +851,7 @@ def is_met( trials_from_node: set[int] | None = None, node_that_generated_last_gr: str | None = None, curr_node_name: str | None = None, + curr_node: modelbridge.generation_node.GenerationNode | None = None, ) -> bool: return len(experiment.trial_indices_by_status[self.status]) >= self.threshold From 791b95fc02a0f8a171ea6cb77435dc56fd7db592 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 16 Oct 2024 21:46:55 -0700 Subject: [PATCH 3/3] Add should_skip setting to input constructors (#2894) Summary: **Context:** There are some cases where the input constructor returns n=o, specifically for repeat arms, this is currently causing an issue in generation strategy gen method because if we don't generate from a node we can't meet the TC, and move forward, however, we don't actually want to generate from that node. **This diff** adds the setting of should skip to repeat_arm_n input constructor as it's currently the only input constructor that could enter a condition that creates a skippable node Following diffs will: - update the gs to appropiately handle resetting to false after a gen is completed Reviewed By: lena-kashtelyan Differential Revision: D64475782 --- ax/modelbridge/generation_node.py | 2 ++ .../generation_node_input_constructors.py | 2 ++ .../test_generation_node_input_constructors.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index a557fe76843..7e42b67bd5b 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -144,6 +144,7 @@ def __init__( ) = None, previous_node_name: str | None = None, trial_type: str | None = None, + should_skip: bool = False, ) -> None: self._node_name = node_name # Check that the model specs have unique model keys. @@ -172,6 +173,7 @@ def __init__( ) self._previous_node_name = previous_node_name self._trial_type = trial_type + self._should_skip = should_skip @property def node_name(self) -> str: diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index 0f9fbb02bb5..7c1a0820a46 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -176,6 +176,8 @@ def repeat_arm_n( if total_n < 6: # if the next trial is small, we don't want to waste allocation on repeat arms # users can still manually add repeat arms if they want before allocation + # and we need to designated this node as skipped for proper transition + next_node._should_skip = True return 0 elif total_n <= 10: return 1 diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index f1bec49352f..e69fc51960c 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -61,12 +61,6 @@ def test_consume_all_n_constructor(self) -> None: def test_repeat_arm_n_constructor(self) -> None: """Test that the repeat_arm_n_constructor returns a small percentage of n.""" - small_n = NodeInputConstructors.REPEAT_N( - previous_node=None, - next_node=self.sobol_generation_node, - gs_gen_call_kwargs={"n": 5}, - experiment=self.experiment, - ) medium_n = NodeInputConstructors.REPEAT_N( previous_node=None, next_node=self.sobol_generation_node, @@ -79,10 +73,19 @@ def test_repeat_arm_n_constructor(self) -> None: gs_gen_call_kwargs={"n": 11}, experiment=self.experiment, ) - self.assertEqual(small_n, 0) self.assertEqual(medium_n, 1) self.assertEqual(large_n, 2) + def test_repeat_arm_n_constructor_return_0(self) -> None: + small_n = NodeInputConstructors.REPEAT_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 5}, + experiment=self.experiment, + ) + self.assertEqual(small_n, 0) + self.assertTrue(self.sobol_generation_node._should_skip) + def test_remaining_n_constructor_expect_1(self) -> None: """Test that the remaining_n_constructor returns the remaining n.""" # should return 1 because 4 arms already exist and 5 are requested