diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index a557fe76843..7e42b67bd5b 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -144,6 +144,7 @@ def __init__( ) = None, previous_node_name: str | None = None, trial_type: str | None = None, + should_skip: bool = False, ) -> None: self._node_name = node_name # Check that the model specs have unique model keys. @@ -172,6 +173,7 @@ def __init__( ) self._previous_node_name = previous_node_name self._trial_type = trial_type + self._should_skip = should_skip @property def node_name(self) -> str: diff --git a/ax/modelbridge/generation_node_input_constructors.py b/ax/modelbridge/generation_node_input_constructors.py index 0f9fbb02bb5..7c1a0820a46 100644 --- a/ax/modelbridge/generation_node_input_constructors.py +++ b/ax/modelbridge/generation_node_input_constructors.py @@ -176,6 +176,8 @@ def repeat_arm_n( if total_n < 6: # if the next trial is small, we don't want to waste allocation on repeat arms # users can still manually add repeat arms if they want before allocation + # and we need to designated this node as skipped for proper transition + next_node._should_skip = True return 0 elif total_n <= 10: return 1 diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index f1bec49352f..e69fc51960c 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -61,12 +61,6 @@ def test_consume_all_n_constructor(self) -> None: def test_repeat_arm_n_constructor(self) -> None: """Test that the repeat_arm_n_constructor returns a small percentage of n.""" - small_n = NodeInputConstructors.REPEAT_N( - previous_node=None, - next_node=self.sobol_generation_node, - gs_gen_call_kwargs={"n": 5}, - experiment=self.experiment, - ) medium_n = NodeInputConstructors.REPEAT_N( previous_node=None, next_node=self.sobol_generation_node, @@ -79,10 +73,19 @@ def test_repeat_arm_n_constructor(self) -> None: gs_gen_call_kwargs={"n": 11}, experiment=self.experiment, ) - self.assertEqual(small_n, 0) self.assertEqual(medium_n, 1) self.assertEqual(large_n, 2) + def test_repeat_arm_n_constructor_return_0(self) -> None: + small_n = NodeInputConstructors.REPEAT_N( + previous_node=None, + next_node=self.sobol_generation_node, + gs_gen_call_kwargs={"n": 5}, + experiment=self.experiment, + ) + self.assertEqual(small_n, 0) + self.assertTrue(self.sobol_generation_node._should_skip) + def test_remaining_n_constructor_expect_1(self) -> None: """Test that the remaining_n_constructor returns the remaining n.""" # should return 1 because 4 arms already exist and 5 are requested