Skip to content

Commit

Permalink
Give generation nodes a required name (#1855)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1855

This diff adds a required name property in the GenerationNode class. This is important for (1) storage, most important and (2) being able to surface relevant errors to users that point to a specific node causing the error.

From a big picture, this is important for the GenerationStep -> GenerationNode migration

also thank lena-kashtelyan for doing the majority of this diff after a discussion re: the field unique_id failing tests for GenNodes

Reviewed By: lena-kashtelyan

Differential Revision: D49424161

fbshipit-source-id: ab30529d064834fd7e7001d3c2e166fce739e0be
  • Loading branch information
mgarrard authored and facebook-github-bot committed Sep 20, 2023
1 parent 5e8b5b2 commit b5e31e3
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 2 deletions.
8 changes: 8 additions & 0 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class GenerationNode:

model_specs: List[ModelSpec]
should_deduplicate: bool
_node_name: str
_model_spec_to_gen_from: Optional[ModelSpec] = None
# [TODO] Handle experiment passing more eloquently by enforcing experiment
# attribute is set in generation strategies class
Expand All @@ -72,10 +73,12 @@ class GenerationNode:

def __init__(
self,
node_name: str,
model_specs: List[ModelSpec],
best_model_selector: Optional[BestModelSelector] = None,
should_deduplicate: bool = False,
) -> None:
self._node_name = node_name
# While `GenerationNode` only handles a single `ModelSpec` in the `gen`
# and `_pick_fitted_model_to_gen_from` methods, we validate the
# length of `model_specs` in `_pick_fitted_model_to_gen_from` in order
Expand All @@ -85,6 +88,10 @@ def __init__(
self.best_model_selector = best_model_selector
self.should_deduplicate = should_deduplicate

@property
def node_name(self) -> str:
return self._node_name

@property
def model_spec_to_gen_from(self) -> ModelSpec:
"""Returns the cached `_model_spec_to_gen_from` or gets it from
Expand Down Expand Up @@ -463,6 +470,7 @@ def __post_init__(self) -> None:
# Factory functions may not always have a model key defined.
self.model_name = f"Unknown {model_spec.__class__.__name__}"
super().__init__(
node_name=f"GenerationStep_{str(self.index)}",
model_specs=[model_spec],
should_deduplicate=self.should_deduplicate,
)
Expand Down
3 changes: 3 additions & 0 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def __init__(self, steps: List[GenerationStep], name: Optional[str] = None) -> N
"Maximum parallelism should be None (if no limit) or a positive"
f" number. Got: {step.max_parallelism} for step {step.model_name}."
)
# TODO[mgarrard]: Validate node name uniqueness when adding node support,
# uniqueness is gaurenteed for steps currently due to list structure.
step._node_name = f"GenerationStep_{str(idx)}"
step.index = idx
step._generation_strategy = self
if not isinstance(step.model, ModelRegistryBase):
Expand Down
12 changes: 10 additions & 2 deletions ax/modelbridge/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def setUp(self) -> None:
model_kwargs={"init_position": 3},
model_gen_kwargs={"some_gen_kwarg": "some_value"},
)
self.sobol_generation_node = GenerationNode(model_specs=[self.sobol_model_spec])
self.sobol_generation_node = GenerationNode(
node_name="test", model_specs=[self.sobol_model_spec]
)
self.branin_experiment = get_branin_experiment(with_completed_trial=True)

def test_init(self) -> None:
Expand Down Expand Up @@ -72,7 +74,8 @@ def test_gen(self) -> None:

def test_gen_validates_one_model_spec(self) -> None:
generation_node = GenerationNode(
model_specs=[self.sobol_model_spec, self.sobol_model_spec]
node_name="test",
model_specs=[self.sobol_model_spec, self.sobol_model_spec],
)
# Base generation node can only handle one model spec at the moment
# (this might change in the future), so it should raise a `NotImplemented
Expand All @@ -86,6 +89,7 @@ def test_gen_validates_one_model_spec(self) -> None:
@fast_botorch_optimize
def test_properties(self) -> None:
node = GenerationNode(
node_name="test",
model_specs=[
ModelSpec(
model_enum=Models.GPEI,
Expand Down Expand Up @@ -114,9 +118,11 @@ def test_properties(self) -> None:
self.assertEqual(node.fixed_features, node.model_specs[0].fixed_features)
self.assertEqual(node.cv_results, node.model_specs[0].cv_results)
self.assertEqual(node.diagnostics, node.model_specs[0].diagnostics)
self.assertEqual(node.node_name, "test")

def test_single_fixed_features(self) -> None:
node = GenerationNode(
node_name="test",
model_specs=[
ModelSpec(
model_enum=Models.GPEI,
Expand All @@ -132,6 +138,7 @@ def test_single_fixed_features(self) -> None:

def test_multiple_same_fixed_features(self) -> None:
node = GenerationNode(
node_name="test",
model_specs=[
ModelSpec(
model_enum=Models.GPEI,
Expand Down Expand Up @@ -242,6 +249,7 @@ def setUp(self) -> None:
self.fitted_model_specs = [ms_gpei, ms_gpkg]

self.model_selection_node = GenerationNode(
node_name="test",
model_specs=self.fitted_model_specs,
best_model_selector=SingleDiagnosticBestModelSelector(
diagnostic="Fisher exact test p",
Expand Down
14 changes: 14 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ def tearDown(self) -> None:
self.discrete_model_bridge_patcher.stop()
self.registry_setup_dict_patcher.stop()

def test_unique_step_names(self) -> None:
"""This tests the name of the steps on generation strategy. The name is
inherited from the GenerationNode class, and for GenerationSteps the
name should follow the format "GenerationNode"+Stepidx.
"""
gs = GenerationStrategy(
steps=[
GenerationStep(model=Models.SOBOL, num_trials=5),
GenerationStep(model=Models.GPEI, num_trials=-1),
]
)
self.assertEqual(gs._steps[0].node_name, "GenerationStep_0")
self.assertEqual(gs._steps[1].node_name, "GenerationStep_1")

def test_name(self) -> None:
self.sobol_GS.name = "SomeGSName"
self.assertEqual(self.sobol_GS.name, "SomeGSName")
Expand Down

0 comments on commit b5e31e3

Please sign in to comment.