Skip to content

Commit

Permalink
Add should_skip setting to input constructors (facebook#2894)
Browse files Browse the repository at this point in the history
Summary:




**Context:** There are some cases where the input constructor returns n=o, specifically for repeat arms, this is currently causing an issue in generation strategy gen method because if we don't generate from a node we can't meet the TC, and move forward, however, we don't actually want to generate from that node.

**This diff** adds the setting of should skip to repeat_arm_n input constructor as it's currently the only input constructor that could enter a condition that creates a skippable node

Following diffs will:
- update the gs to appropiately handle resetting to false after a gen is completed

Reviewed By: lena-kashtelyan

Differential Revision: D64475782
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 17, 2024
1 parent 77c2418 commit ab59eea
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
) = None,
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand Down Expand Up @@ -172,6 +173,7 @@ def __init__(
)
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip

@property
def node_name(self) -> str:
Expand Down
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def repeat_arm_n(
if total_n < 6:
# if the next trial is small, we don't want to waste allocation on repeat arms
# users can still manually add repeat arms if they want before allocation
# and we need to designated this node as skipped for proper transition
next_node._should_skip = True
return 0
elif total_n <= 10:
return 1
Expand Down
17 changes: 10 additions & 7 deletions ax/modelbridge/tests/test_generation_node_input_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,6 @@ def test_consume_all_n_constructor(self) -> None:

def test_repeat_arm_n_constructor(self) -> None:
"""Test that the repeat_arm_n_constructor returns a small percentage of n."""
small_n = NodeInputConstructors.REPEAT_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 5},
experiment=self.experiment,
)
medium_n = NodeInputConstructors.REPEAT_N(
previous_node=None,
next_node=self.sobol_generation_node,
Expand All @@ -79,10 +73,19 @@ def test_repeat_arm_n_constructor(self) -> None:
gs_gen_call_kwargs={"n": 11},
experiment=self.experiment,
)
self.assertEqual(small_n, 0)
self.assertEqual(medium_n, 1)
self.assertEqual(large_n, 2)

def test_repeat_arm_n_constructor_return_0(self) -> None:
small_n = NodeInputConstructors.REPEAT_N(
previous_node=None,
next_node=self.sobol_generation_node,
gs_gen_call_kwargs={"n": 5},
experiment=self.experiment,
)
self.assertEqual(small_n, 0)
self.assertTrue(self.sobol_generation_node._should_skip)

def test_remaining_n_constructor_expect_1(self) -> None:
"""Test that the remaining_n_constructor returns the remaining n."""
# should return 1 because 4 arms already exist and 5 are requested
Expand Down

0 comments on commit ab59eea

Please sign in to comment.