Skip to content

Commit

Permalink
Light GenStrategy cleanup (#2258)
Browse files Browse the repository at this point in the history
Summary:

As titled, grab-bag of no-op cleanups + coverage improvement

Differential Revision: D54639831
  • Loading branch information
mgarrard authored and facebook-github-bot committed Mar 18, 2024
1 parent 639b731 commit 0a5b6af
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 196 deletions.
281 changes: 137 additions & 144 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import wraps

from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar

import pandas as pd
from ax.core.data import Data
Expand Down Expand Up @@ -41,7 +41,8 @@
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.transition_criterion import TrialBasedCriterion
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast, not_none
from ax.utils.common.typeutils import checked_cast, checked_cast_list, not_none
from pyre_extensions import none_throws

logger: Logger = get_logger(__name__)

Expand All @@ -58,15 +59,14 @@

def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]:
"""
For use as a decorator on functions only implemented for GenerationStep based
GenerationStrategies. Mainly useful for older GenerationStrategies.
For use as a decorator on functions only implemented for ``GenerationStep``-based
``GenerationStrategies``. Mainly useful for older ``GenerationStrategies``.
"""

@wraps(f)
def impl(
self: "GenerationStrategy", *args: List[Any], **kwargs: Dict[str, Any]
) -> T:

if self.is_node_based:
raise UnsupportedError(
f"{f.__name__} is not supported for GenerationNode based"
Expand Down Expand Up @@ -111,8 +111,6 @@ class GenerationStrategy(GenerationStrategyInterface):
# it exists.
_experiment: Optional[Experiment] = None
# Trial indices as last seen by the model; updated in `_model` property setter.
# pyre-fixme[4]: Attribute must be annotated.
_seen_trial_indices_by_status = None
_model: Optional[ModelBridge] = None # Current model.

def __init__(
Expand All @@ -121,27 +119,31 @@ def __init__(
name: Optional[str] = None,
nodes: Optional[List[GenerationNode]] = None,
) -> None:
self._uses_registered_models = True
self._generator_runs = []

# validate that only one of steps or nodes is provided
# Validate that one and only one of steps or nodes is provided
if not ((steps is None) ^ (nodes is None)):
raise GenerationStrategyMisconfiguredException(
error_info="GenerationStrategy must contain either steps or nodes."
)

# pyre-ignore[8]
self._nodes = steps if steps is not None else nodes
node_based_strategy = self.is_node_based
self._nodes = none_throws(nodes if steps is None else steps)

if isinstance(steps, list) and not node_based_strategy:
# Validate correctness of steps list or nodes graph
if isinstance(steps, list) and all(
isinstance(s, GenerationStep) for s in steps
):
self._validate_and_set_step_sequence(steps=self._nodes)
elif isinstance(nodes, list) and node_based_strategy:
elif isinstance(nodes, list) and self.is_node_based:
self._validate_and_set_node_graph(nodes=nodes)
else:
# TODO[mgarrard]: Allow mix of nodes and steps
raise GenerationStrategyMisconfiguredException(
error_info="Steps must either be a GenerationStep list or a "
+ "GenerationNode list."
"`GenerationStrategy` inputs are:\n"
"`steps` (list of `GenerationStep`) or\n"
"`nodes` (list of `GenerationNode`)."
)

# Log warning if the GS uses a non-registered (factory function) model.
self._uses_registered_models = not any(
isinstance(ms, FactoryFunctionModelSpec)
for node in self._nodes
Expand All @@ -152,33 +154,28 @@ def __init__(
"Using model via callable function, "
"so optimization is not resumable if interrupted."
)
self._seen_trial_indices_by_status = None
self._generator_runs = []
# Set name to an explicit value ahead of time to avoid
# adding properties during equality checks
super().__init__(name=name or self._make_default_name())

@property
def is_node_based(self) -> bool:
    """Whether this strategy consists of ``GenerationNode``-s only.

    Returns ``True`` when every entry in ``self._nodes`` is a
    ``GenerationNode`` and none of them is a ``GenerationStep``; an empty
    ``self._nodes`` therefore also yields ``True``. This is useful for
    determining initialization properties and other logic.
    """
    # NOTE(review): `GenerationStep` is presumably a `GenerationNode`
    # subclass, which is why it has to be excluded explicitly -- confirm.
    return all(
        isinstance(n, GenerationNode) and not isinstance(n, GenerationStep)
        for n in self._nodes
    )

@property
def name(self) -> str:
    """Name of this generation strategy.

    Defaults to a combination of model names provided in generation steps,
    set at the time of the ``GenerationStrategy`` creation, so reading it
    never mutates the strategy.
    """
    return self._name

@name.setter
def name(self, name: str) -> None:
Expand Down Expand Up @@ -346,104 +343,6 @@ def optimization_complete(self) -> bool:
"""Checks whether all nodes are completed in the generation strategy."""
return all(node.is_completed for node in self._nodes)

@step_based_gs_only
def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None:
    """Validate the given steps and wire them up as this strategy's sequence.

    For every step, in order, this checks that:
        1. ``num_trials == -1`` without any completion criteria only occurs
           on the last step.
        2. ``num_trials`` is otherwise positive (or -1).
        3. ``max_parallelism`` is either ``None`` or positive.

    A step that passes gets its ``GenerationStep_<idx>`` node name and index,
    has ``_transition_to`` of its transition criteria (other than
    ``MaxGenerationParallelism`` ones) pointed at ``GenerationStep_<idx + 1>``,
    and receives a back-pointer to this strategy. The first step finally
    becomes the current one.

    Raises:
        UserInputError: If any of the checks above fails.
    """
    for position, gen_step in enumerate(steps):
        trial_budget = gen_step.num_trials
        if trial_budget == -1:
            # Unlimited generation without completion criteria is only
            # acceptable at the very end of the sequence.
            if (
                len(gen_step.completion_criteria) < 1
                and position < len(self._steps) - 1
            ):
                raise UserInputError(
                    "Only last step in generation strategy can have "
                    "`num_trials` set to -1 to indicate that the model in "
                    "the step shouldbe used to generate new trials "
                    "indefinitely unless completion critera present."
                )
        elif trial_budget < 1:
            raise UserInputError(
                "`num_trials` must be positive or -1 (indicating unlimited) "
                "for all generation steps."
            )

        parallelism_cap = gen_step.max_parallelism
        if parallelism_cap is not None and parallelism_cap < 1:
            raise UserInputError(
                "Maximum parallelism should be None (if no limit) or "
                f"a positive number. Got: {gen_step.max_parallelism} for "
                f"step {gen_step.model_name}."
            )

        gen_step._node_name = f"GenerationStep_{position}"
        gen_step.index = position

        # Intended to set `transition_to` for all but the last step.
        # NOTE(review): if `steps` is the same list as `self._steps`,
        # `position` never equals its length, so this guard always passes and
        # the last step is pointed at a non-existent successor -- confirm
        # whether `len(...) - 1` was intended before changing it.
        if position != len(self._steps):
            successor_name = f"GenerationStep_{position + 1}"
            for criterion in gen_step.transition_criteria:
                if criterion.criterion_class == "MaxGenerationParallelism":
                    continue
                criterion._transition_to = successor_name
        gen_step._generation_strategy = self
    self._curr = steps[0]

def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None:
    """Validate the node graph of this strategy and select its first node.

    Working over ``self._nodes``, this:
        1. Raises if two nodes share a name, and gives every node a
           back-pointer to this strategy.
        2. Raises if any transition criterion's ``transition_to`` names a
           node that is not part of this strategy.
        3. Logs a warning (does not raise) when there is more than one node
           and no criterion carries a ``transition_to``.

    The first element of ``nodes`` then becomes the current node.

    Raises:
        GenerationStrategyMisconfiguredException: On duplicate node names or
            on a ``transition_to`` that matches no node.
    """
    known_names = []
    for graph_node in self._nodes:
        current_name = graph_node.node_name
        if current_name in known_names:
            raise GenerationStrategyMisconfiguredException(
                error_info="All node names in a GenerationStrategy "
                "must be unique."
            )
        known_names.append(current_name)
        graph_node._generation_strategy = self

    # Second pass: every name is known by now, so targets can be resolved.
    found_transition_target = False
    for graph_node in self._nodes:
        for criterion in graph_node.transition_criteria:
            target = criterion.transition_to
            if target is None:
                continue
            found_transition_target = True
            if target not in known_names:
                raise GenerationStrategyMisconfiguredException(
                    error_info=f"`transition_to` argument "
                    f"{target} does not "
                    "correspond to any node in this GenerationStrategy."
                )

    if not found_transition_target and len(self._nodes) > 1:
        logger.warning(
            "None of the nodes in this GenerationStrategy "
            "contain a `transition_to` argument in their transition_criteria. "
            "Therefore, the GenerationStrategy will not be able to "
            "move from one node to another. Please add a "
            "`transition_to` argument."
        )
    self._curr = nodes[0]

@property
@step_based_gs_only
def _steps(self) -> List[GenerationStep]:
Expand Down Expand Up @@ -564,20 +463,17 @@ def current_generator_run_limit(

def clone_reset(self) -> GenerationStrategy:
    """Copy this generation strategy without its state.

    Returns:
        A new ``GenerationStrategy`` with the same name, built from deep
        copies of this strategy's nodes (passed as ``nodes`` when this
        strategy is node-based, as ``steps`` otherwise).
    """
    cloned_nodes = deepcopy(self._nodes)
    for n in cloned_nodes:
        # Unset the generation strategy back-pointer, so the nodes are not
        # associated with any generation strategy.
        n._generation_strategy = None
    if self.is_node_based:
        return GenerationStrategy(name=self.name, nodes=cloned_nodes)

    # NOTE(review): `checked_cast_list` presumably verifies that every clone
    # is a `GenerationStep` before handing them over as `steps` -- confirm.
    return GenerationStrategy(
        name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes)
    )

def _unset_non_persistent_state_fields(self) -> None:
"""Utility for testing convenience: unset fields of generation strategy
Expand All @@ -586,11 +482,108 @@ def _unset_non_persistent_state_fields(self) -> None:
strategies; call this utility on the pre-storage one first. The rest
of the fields should be identical.
"""
self._seen_trial_indices_by_status = None
self._model = None
for s in self._nodes:
s._model_spec_to_gen_from = None

@step_based_gs_only
def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None:
    """Validate the given steps and wire them up as this strategy's sequence.

    For every step, in order, this checks that:
        1. ``num_trials == -1`` without any completion criteria only occurs
           on the last step.
        2. ``num_trials`` is otherwise positive (or -1).
        3. ``max_parallelism`` is either ``None`` or positive.

    A step that passes gets its ``GenerationStep_<idx>`` node name and index,
    has ``_transition_to`` of its transition criteria (other than
    ``MaxGenerationParallelism`` ones) pointed at ``GenerationStep_<idx + 1>``,
    and receives a back-pointer to this strategy. The first step finally
    becomes the current one.

    Raises:
        UserInputError: If any of the checks above fails.
    """
    for position, gen_step in enumerate(steps):
        trial_budget = gen_step.num_trials
        if trial_budget == -1:
            # Unlimited generation without completion criteria is only
            # acceptable at the very end of the sequence.
            if (
                len(gen_step.completion_criteria) < 1
                and position < len(self._steps) - 1
            ):
                raise UserInputError(
                    "Only last step in generation strategy can have "
                    "`num_trials` set to -1 to indicate that the model in "
                    "the step shouldbe used to generate new trials "
                    "indefinitely unless completion critera present."
                )
        elif trial_budget < 1:
            raise UserInputError(
                "`num_trials` must be positive or -1 (indicating unlimited) "
                "for all generation steps."
            )

        parallelism_cap = gen_step.max_parallelism
        if parallelism_cap is not None and parallelism_cap < 1:
            raise UserInputError(
                "Maximum parallelism should be None (if no limit) or "
                f"a positive number. Got: {gen_step.max_parallelism} for "
                f"step {gen_step.model_name}."
            )

        gen_step._node_name = f"GenerationStep_{position}"
        gen_step.index = position

        # Intended to set `transition_to` for all but the last step.
        # NOTE(review): if `steps` is the same list as `self._steps`,
        # `position` never equals its length, so this guard always passes and
        # the last step is pointed at a non-existent successor -- confirm
        # whether `len(...) - 1` was intended before changing it.
        if position != len(self._steps):
            successor_name = f"GenerationStep_{position + 1}"
            for criterion in gen_step.transition_criteria:
                if criterion.criterion_class == "MaxGenerationParallelism":
                    continue
                criterion._transition_to = successor_name
        gen_step._generation_strategy = self
    self._curr = steps[0]

def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None:
    """Validate the node graph of this strategy and select its first node.

    Working over ``self._nodes``, this:
        1. Raises if two nodes share a name, and gives every node a
           back-pointer to this strategy.
        2. Raises if any transition criterion's ``transition_to`` names a
           node that is not part of this strategy.
        3. Logs a warning (does not raise) when there is more than one node
           and no criterion carries a ``transition_to``.

    The first element of ``nodes`` then becomes the current node.

    Raises:
        GenerationStrategyMisconfiguredException: On duplicate node names or
            on a ``transition_to`` that matches no node.
    """
    known_names = []
    for graph_node in self._nodes:
        current_name = graph_node.node_name
        if current_name in known_names:
            raise GenerationStrategyMisconfiguredException(
                error_info="All node names in a GenerationStrategy "
                "must be unique."
            )
        known_names.append(current_name)
        graph_node._generation_strategy = self

    # Second pass: every name is known by now, so targets can be resolved.
    found_transition_target = False
    for graph_node in self._nodes:
        for criterion in graph_node.transition_criteria:
            target = criterion.transition_to
            if target is None:
                continue
            found_transition_target = True
            if target not in known_names:
                raise GenerationStrategyMisconfiguredException(
                    error_info=f"`transition_to` argument "
                    f"{target} does not "
                    "correspond to any node in this GenerationStrategy."
                )

    if not found_transition_target and len(self._nodes) > 1:
        logger.warning(
            "None of the nodes in this GenerationStrategy "
            "contain a `transition_to` argument in their transition_criteria. "
            "Therefore, the GenerationStrategy will not be able to "
            "move from one node to another. Please add a "
            "`transition_to` argument."
        )
    self._curr = nodes[0]

@step_based_gs_only
def _step_repr(self, step_str_rep: str) -> str:
"""Return the string representation of the steps in a GenerationStrategy
Expand Down
Loading

0 comments on commit 0a5b6af

Please sign in to comment.