diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index dcf03a741db..212697ce760 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -63,6 +63,7 @@ class GenerationNode: model_specs: List[ModelSpec] should_deduplicate: bool + _node_name: str _model_spec_to_gen_from: Optional[ModelSpec] = None # [TODO] Handle experiment passing more eloquently by enforcing experiment # attribute is set in generation strategies class @@ -72,10 +73,12 @@ class GenerationNode: def __init__( self, + node_name: str, model_specs: List[ModelSpec], best_model_selector: Optional[BestModelSelector] = None, should_deduplicate: bool = False, ) -> None: + self._node_name = node_name # While `GenerationNode` only handles a single `ModelSpec` in the `gen` # and `_pick_fitted_model_to_gen_from` methods, we validate the # length of `model_specs` in `_pick_fitted_model_to_gen_from` in order @@ -85,6 +88,10 @@ def __init__( self.best_model_selector = best_model_selector self.should_deduplicate = should_deduplicate + @property + def node_name(self) -> str: + return self._node_name + @property def model_spec_to_gen_from(self) -> ModelSpec: """Returns the cached `_model_spec_to_gen_from` or gets it from @@ -463,6 +470,7 @@ def __post_init__(self) -> None: # Factory functions may not always have a model key defined. self.model_name = f"Unknown {model_spec.__class__.__name__}" super().__init__( + node_name=f"GenerationStep_{str(self.index)}", model_specs=[model_spec], should_deduplicate=self.should_deduplicate, ) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index e6ab5b203b2..4e4e08ef831 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -98,6 +98,9 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N "Maximum parallelism should be None (if no limit) or a positive" f" number. Got: {step.max_parallelism} for step {step.model_name}." ) + # TODO[mgarrard]: Validate node name uniqueness when adding node support, + # uniqueness is gaurenteed for steps currently due to list structure. + step._node_name = f"GenerationStep_{str(idx)}" step.index = idx step._generation_strategy = self if not isinstance(step.model, ModelRegistryBase): diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index e59a52d4683..e1dae665ea4 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -29,7 +29,9 @@ def setUp(self) -> None: model_kwargs={"init_position": 3}, model_gen_kwargs={"some_gen_kwarg": "some_value"}, ) - self.sobol_generation_node = GenerationNode(model_specs=[self.sobol_model_spec]) + self.sobol_generation_node = GenerationNode( + node_name="test", model_specs=[self.sobol_model_spec] + ) self.branin_experiment = get_branin_experiment(with_completed_trial=True) def test_init(self) -> None: @@ -72,7 +74,8 @@ def test_gen(self) -> None: def test_gen_validates_one_model_spec(self) -> None: generation_node = GenerationNode( - model_specs=[self.sobol_model_spec, self.sobol_model_spec] + node_name="test", + model_specs=[self.sobol_model_spec, self.sobol_model_spec], ) # Base generation node can only handle one model spec at the moment # (this might change in the future), so it should raise a `NotImplemented @@ -86,6 +89,7 @@ def test_gen_validates_one_model_spec(self) -> None: @fast_botorch_optimize def test_properties(self) -> None: node = GenerationNode( + node_name="test", model_specs=[ ModelSpec( model_enum=Models.GPEI, @@ -114,9 +118,11 @@ def test_properties(self) -> None: self.assertEqual(node.fixed_features, node.model_specs[0].fixed_features) self.assertEqual(node.cv_results, node.model_specs[0].cv_results) self.assertEqual(node.diagnostics, node.model_specs[0].diagnostics) + self.assertEqual(node.node_name, "test") def test_single_fixed_features(self) -> None: node = GenerationNode( + node_name="test", model_specs=[ ModelSpec( model_enum=Models.GPEI, @@ -132,6 +138,7 @@ def test_single_fixed_features(self) -> None: def test_multiple_same_fixed_features(self) -> None: node = GenerationNode( + node_name="test", model_specs=[ ModelSpec( model_enum=Models.GPEI, @@ -242,6 +249,7 @@ def setUp(self) -> None: self.fitted_model_specs = [ms_gpei, ms_gpkg] self.model_selection_node = GenerationNode( + node_name="test", model_specs=self.fitted_model_specs, best_model_selector=SingleDiagnosticBestModelSelector( diagnostic="Fisher exact test p", diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 2ae551be5dd..aedf7cef919 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -120,6 +120,20 @@ def tearDown(self) -> None: self.discrete_model_bridge_patcher.stop() self.registry_setup_dict_patcher.stop() + def test_unique_step_names(self) -> None: + """This tests the name of the steps on generation strategy. The name is + inherited from the GenerationNode class, and for GenerationSteps the + name should follow the format "GenerationNode"+Stepidx. + """ + gs = GenerationStrategy( + steps=[ + GenerationStep(model=Models.SOBOL, num_trials=5), + GenerationStep(model=Models.GPEI, num_trials=-1), + ] + ) + self.assertEqual(gs._steps[0].node_name, "GenerationStep_0") + self.assertEqual(gs._steps[1].node_name, "GenerationStep_1") + def test_name(self) -> None: self.sobol_GS.name = "SomeGSName" self.assertEqual(self.sobol_GS.name, "SomeGSName")