From 0b64c697a5115ca73f675cea1b75d5c653377913 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Wed, 16 Oct 2024 21:08:44 -0700 Subject: [PATCH] Add should_skip setting to input constructors (#2894) Summary: **Context:** There are some cases where the input constructor returns n=o, specifically for repeat arms, this is currently causing an issue in generation strategy gen method because if we don't generate from a node we can't meet the TC, and move forward, however, we don't actually want to generate from that node. **This diff** adds the setting of should skip to repeat_arm_n input constructor as it's currently the only input constructor that could enter a condition that creates a skippable node Following diffs will: - update the gs to appropiately handle resetting to false after a gen is completed Reviewed By: lena-kashtelyan Differential Revision: D64475782 --- ax/modelbridge/generation_node.py | 2 ++ .../generation_node_input_constructors.py | 2 ++ .../test_generation_node_input_constructors.py | 17 ++++++++++------- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index a557fe76843..7e42b67bd5b 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -144,6 +144,7 @@ def __init__( ) = None, previous_node_name: str | None = None, trial_type: str | None = None, + should_skip: bool = False, ) -> None: self._node_name = node_name # Check that the model specs have unique model keys. @@ -172,6 +173,7 @@ def __init__( ) self._previous_node_name = previous_node_name self._trial_type = trial_type + self._should_skip = should_skip @property def node_name(self) -> str: diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index 0f9fbb02bb5..7c1a0820a46 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -176,6 +176,8 @@ def repeat_arm_n( if total_n < 6: # if the next trial is small, we don't want to waste allocation on repeat arms # users can still manually add repeat arms if they want before allocation + # and we need to designated this node as skipped for proper transition + next_node._should_skip = True return 0 elif total_n <= 10: return 1 diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index f1bec49352f..e69fc51960c 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -61,12 +61,6 @@ def test_consume_all_n_constructor(self) -> None: def test_repeat_arm_n_constructor(self) -> None: """Test that the repeat_arm_n_constructor returns a small percentage of n.""" - small_n = NodeInputConstructors.REPEAT_N( - previous_node=None, - next_node=self.sobol_generation_node, - gs_gen_call_kwargs={"n": 5}, - experiment=self.experiment, - ) medium_n = NodeInputConstructors.REPEAT_N( previous_node=None, next_node=self.sobol_generation_node, @@ -79,10 +73,19 @@ def test_repeat_arm_n_constructor(self) -> None: gs_gen_call_kwargs={"n": 11}, experiment=self.experiment, ) - self.assertEqual(small_n, 0) self.assertEqual(medium_n, 1) self.assertEqual(large_n, 2) + def test_repeat_arm_n_constructor_return_0(self) -> None: + small_n = NodeInputConstructors.REPEAT_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 5}, + experiment=self.experiment, + ) + self.assertEqual(small_n, 0) + self.assertTrue(self.sobol_generation_node._should_skip) + def test_remaining_n_constructor_expect_1(self) -> None: """Test that the remaining_n_constructor returns the remaining n.""" # should return 1 because 4 arms already exist and 5 are requested