diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index d8a70e4a5fc..e138fdadfdc 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -12,7 +12,7 @@ from functools import wraps from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, TypeVar import pandas as pd from ax.core.data import Data @@ -41,7 +41,8 @@ from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase from ax.modelbridge.transition_criterion import TrialBasedCriterion from ax.utils.common.logger import _round_floats_for_logging, get_logger -from ax.utils.common.typeutils import checked_cast, not_none +from ax.utils.common.typeutils import checked_cast, checked_cast_list, not_none +from pyre_extensions import none_throws logger: Logger = get_logger(__name__) @@ -58,15 +59,14 @@ def step_based_gs_only(f: Callable[..., T]) -> Callable[..., T]: """ - For use as a decorator on functions only implemented for GenerationStep based - GenerationStrategies. Mainly useful for older GenerationStrategies. + For use as a decorator on functions only implemented for ``GenerationStep``-based + ``GenerationStrategies``. Mainly useful for older ``GenerationStrategies``. """ @wraps(f) def impl( self: "GenerationStrategy", *args: List[Any], **kwargs: Dict[str, Any] ) -> T: - if self.is_node_based: raise UnsupportedError( f"{f.__name__} is not supported for GenerationNode based" @@ -111,8 +111,6 @@ class GenerationStrategy(GenerationStrategyInterface): # it exists. _experiment: Optional[Experiment] = None # Trial indices as last seen by the model; updated in `_model` property setter. - # pyre-fixme[4]: Attribute must be annotated. - _seen_trial_indices_by_status = None _model: Optional[ModelBridge] = None # Current model. 
def __init__( @@ -121,27 +119,31 @@ def __init__( name: Optional[str] = None, nodes: Optional[List[GenerationNode]] = None, ) -> None: - self._uses_registered_models = True - self._generator_runs = [] - - # validate that only one of steps or nodes is provided + # Validate that one and only one of steps or nodes is provided if not ((steps is None) ^ (nodes is None)): raise GenerationStrategyMisconfiguredException( error_info="GenerationStrategy must contain either steps or nodes." ) + # pyre-ignore[8] - self._nodes = steps if steps is not None else nodes - node_based_strategy = self.is_node_based + self._nodes = none_throws(nodes if steps is None else steps) - if isinstance(steps, list) and not node_based_strategy: + # Validate correctness of steps list or nodes graph + if isinstance(steps, list) and all( + isinstance(s, GenerationStep) for s in steps + ): self._validate_and_set_step_sequence(steps=self._nodes) - elif isinstance(nodes, list) and node_based_strategy: + elif isinstance(nodes, list) and self.is_node_based: self._validate_and_set_node_graph(nodes=nodes) else: + # TODO[mgarrard]: Allow mix of nodes and steps raise GenerationStrategyMisconfiguredException( - error_info="Steps must either be a GenerationStep list or a " - + "GenerationNode list." + "`GenerationStrategy` inputs are:\n" + "`steps` (list of `GenerationStep`) or\n" + "`nodes` (list of `GenerationNode`)." ) + + # Log warning if the GS uses a non-registered (factory function) model. self._uses_registered_models = not any( isinstance(ms, FactoryFunctionModelSpec) for node in self._nodes @@ -152,33 +154,28 @@ def __init__( "Using model via callable function, " "so optimization is not resumable if interrupted." 
) - self._seen_trial_indices_by_status = None + self._generator_runs = [] # Set name to an explicit value ahead of time to avoid # adding properties during equality checks super().__init__(name=name or self._make_default_name()) @property def is_node_based(self) -> bool: - """Whether this strategy consists of GenerationNodes or GenerationSteps. - This is useful for determining initialization properties and other logic. + """Whether this strategy consists of GenerationNodes only. + This is useful for determining initialization properties and + other logic. """ - if any(isinstance(n, GenerationStep) for n in self._nodes): - return False - return True + return not any(isinstance(n, GenerationStep) for n in self._nodes) and all( + isinstance(n, GenerationNode) for n in self._nodes + ) @property def name(self) -> str: """Name of this generation strategy. Defaults to a combination of model - names provided in generation steps. + names provided in generation steps, set at the time of the + ``GenerationStrategy`` creation. """ - if self._name is not None: - return not_none(self._name) - - factory_names = (node.model_spec_to_gen_from.model_key for node in self._nodes) - # Trim the "get_" beginning of the factory function if it's there. - factory_names = (n[4:] if n[:4] == "get_" else n for n in factory_names) - self._name = "+".join(factory_names) - return not_none(self._name) + return self._name @name.setter def name(self, name: str) -> None: @@ -346,104 +343,6 @@ def optimization_complete(self) -> bool: """Checks whether all nodes are completed in the generation strategy.""" return all(node.is_completed for node in self._nodes) - @step_based_gs_only - def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: - """Initialize and validate the steps provided to this GenerationStrategy. - - Some GenerationStrategies are composed of GenerationStep objects, but we also - need to initialize the correct GenerationNode representation for these steps. 
- This function validates: - 1. That only the last step has num_trials=-1, which indicates unlimited - trial generation is possible. - 2. That each step's num_trials attrivute is either positive or -1 - 3. That each step's max_parallelism attribute is either None or positive - It then sets the corect TransitionCriterion and node_name attributes on the - underlying GenerationNode objects. - """ - for idx, step in enumerate(steps): - if step.num_trials == -1 and len(step.completion_criteria) < 1: - if idx < len(self._steps) - 1: - raise UserInputError( - "Only last step in generation strategy can have " - "`num_trials` set to -1 to indicate that the model in " - "the step shouldbe used to generate new trials " - "indefinitely unless completion critera present." - ) - elif step.num_trials < 1 and step.num_trials != -1: - raise UserInputError( - "`num_trials` must be positive or -1 (indicating unlimited) " - "for all generation steps." - ) - if step.max_parallelism is not None and step.max_parallelism < 1: - raise UserInputError( - "Maximum parallelism should be None (if no limit) or " - f"a positive number. Got: {step.max_parallelism} for " - f"step {step.model_name}." - ) - - step._node_name = f"GenerationStep_{str(idx)}" - step.index = idx - - # Set transition_to field for all but the last step, which remains - # null. - if idx != len(self._steps): - for transition_criteria in step.transition_criteria: - if ( - transition_criteria.criterion_class - != "MaxGenerationParallelism" - ): - transition_criteria._transition_to = ( - f"GenerationStep_{str(idx + 1)}" - ) - step._generation_strategy = self - self._curr = steps[0] - - def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: - """Initialize and validate the node graph provided to this GenerationStrategy. - - This function validates: - 1. That all nodes have unique names. - 2. That there is at least one node with a transition_to field. - 3. 
That all `transition_to` attributes on a TransitionCriterion point to - another node in the same GenerationStrategy. - 4. Warns if no nodes contain a transition criterion - """ - node_names = [] - for node in self._nodes: - # validate that all node names are unique - if node.node_name in node_names: - raise GenerationStrategyMisconfiguredException( - error_info="All node names in a GenerationStrategy " - + "must be unique." - ) - - node_names.append(node.node_name) - node._generation_strategy = self - - # validate `transition_criterion` - contains_a_transition_to_argument = False - for node in self._nodes: - for transition_criteria in node.transition_criteria: - if transition_criteria.transition_to is not None: - contains_a_transition_to_argument = True - if transition_criteria.transition_to not in node_names: - raise GenerationStrategyMisconfiguredException( - error_info=f"`transition_to` argument " - f"{transition_criteria.transition_to} does not " - "correspond to any node in this GenerationStrategy." - ) - - # validate that at least one node has transition_to field - if len(self._nodes) > 1 and not contains_a_transition_to_argument: - logger.warning( - "None of the nodes in this GenerationStrategy " - "contain a `transition_to` argument in their transition_criteria. " - "Therefore, the GenerationStrategy will not be able to " - "move from one node to another. Please add a " - "`transition_to` argument." - ) - self._curr = nodes[0] - @property @step_based_gs_only def _steps(self) -> List[GenerationStep]: @@ -564,20 +463,17 @@ def current_generator_run_limit( def clone_reset(self) -> GenerationStrategy: """Copy this generation strategy without it's state.""" - if self.is_node_based: - nodes = deepcopy(self._nodes) - for n in nodes: - # Unset the generation strategy back-pointer, so the nodes are not - # associated with any generation strategy. 
- n._generation_strategy = None - return GenerationStrategy(name=self.name, nodes=nodes) - - steps = deepcopy(self._steps) - for s in steps: - # Unset the generation strategy back-pointer, so the steps are not + cloned_nodes = deepcopy(self._nodes) + for n in cloned_nodes: + # Unset the generation strategy back-pointer, so the nodes are not # associated with any generation strategy. - s._generation_strategy = None - return GenerationStrategy(name=self.name, steps=steps) + n._generation_strategy = None + if self.is_node_based: + return GenerationStrategy(name=self.name, nodes=cloned_nodes) + + return GenerationStrategy( + name=self.name, steps=checked_cast_list(GenerationStep, cloned_nodes) + ) def _unset_non_persistent_state_fields(self) -> None: """Utility for testing convenience: unset fields of generation strategy @@ -586,11 +482,108 @@ def _unset_non_persistent_state_fields(self) -> None: strategies; call this utility on the pre-storage one first. The rest of the fields should be identical. """ - self._seen_trial_indices_by_status = None self._model = None for s in self._nodes: s._model_spec_to_gen_from = None + @step_based_gs_only + def _validate_and_set_step_sequence(self, steps: List[GenerationStep]) -> None: + """Initialize and validate the steps provided to this GenerationStrategy. + + Some GenerationStrategies are composed of GenerationStep objects, but we also + need to initialize the correct GenerationNode representation for these steps. + This function validates: + 1. That only the last step has num_trials=-1, which indicates unlimited + trial generation is possible. + 2. That each step's num_trials attribute is either positive or -1 + 3. That each step's max_parallelism attribute is either None or positive + It then sets the correct TransitionCriterion and node_name attributes on the + underlying GenerationNode objects. 
+ """ + for idx, step in enumerate(steps): + if step.num_trials == -1 and len(step.completion_criteria) < 1: + if idx < len(self._steps) - 1: + raise UserInputError( + "Only last step in generation strategy can have " + "`num_trials` set to -1 to indicate that the model in " + "the step shouldbe used to generate new trials " + "indefinitely unless completion critera present." + ) + elif step.num_trials < 1 and step.num_trials != -1: + raise UserInputError( + "`num_trials` must be positive or -1 (indicating unlimited) " + "for all generation steps." + ) + if step.max_parallelism is not None and step.max_parallelism < 1: + raise UserInputError( + "Maximum parallelism should be None (if no limit) or " + f"a positive number. Got: {step.max_parallelism} for " + f"step {step.model_name}." + ) + + step._node_name = f"GenerationStep_{str(idx)}" + step.index = idx + + # Set transition_to field for all but the last step, which remains + # null. + if idx != len(self._steps): + for transition_criteria in step.transition_criteria: + if ( + transition_criteria.criterion_class + != "MaxGenerationParallelism" + ): + transition_criteria._transition_to = ( + f"GenerationStep_{str(idx + 1)}" + ) + step._generation_strategy = self + self._curr = steps[0] + + def _validate_and_set_node_graph(self, nodes: List[GenerationNode]) -> None: + """Initialize and validate the node graph provided to this GenerationStrategy. + + This function validates: + 1. That all nodes have unique names. + 2. That there is at least one node with a transition_to field. + 3. That all `transition_to` attributes on a TransitionCriterion point to + another node in the same GenerationStrategy. + 4. Warns if no nodes contain a transition criterion + """ + node_names = [] + for node in self._nodes: + # validate that all node names are unique + if node.node_name in node_names: + raise GenerationStrategyMisconfiguredException( + error_info="All node names in a GenerationStrategy " + + "must be unique." 
+ ) + + node_names.append(node.node_name) + node._generation_strategy = self + + # validate `transition_criterion` + contains_a_transition_to_argument = False + for node in self._nodes: + for transition_criteria in node.transition_criteria: + if transition_criteria.transition_to is not None: + contains_a_transition_to_argument = True + if transition_criteria.transition_to not in node_names: + raise GenerationStrategyMisconfiguredException( + error_info=f"`transition_to` argument " + f"{transition_criteria.transition_to} does not " + "correspond to any node in this GenerationStrategy." + ) + + # validate that at least one node has transition_to field + if len(self._nodes) > 1 and not contains_a_transition_to_argument: + logger.warning( + "None of the nodes in this GenerationStrategy " + "contain a `transition_to` argument in their transition_criteria. " + "Therefore, the GenerationStrategy will not be able to " + "move from one node to another. Please add a " + "`transition_to` argument." + ) + self._curr = nodes[0] + @step_based_gs_only def _step_repr(self, step_str_rep: str) -> str: """Return the string representation of the steps in a GenerationStrategy diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 63f34e8e08a..6e091c0027a 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -129,6 +129,53 @@ def setUp(self) -> None: ] ) + # Set up the node-based generation strategy for testing. 
+ self.sobol_criterion = [ + MaxTrials( + threshold=5, + transition_to="GPEI_node", + block_gen_if_met=True, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + ] + self.gpei_criterion = [ + MaxTrials( + threshold=2, + transition_to=None, + block_gen_if_met=True, + only_in_statuses=None, + not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], + ) + ] + self.sobol_model_spec = ModelSpec( + model_enum=Models.SOBOL, + model_kwargs=self.step_model_kwargs, + model_gen_kwargs={}, + ) + self.gpei_model_spec = ModelSpec( + model_enum=Models.GPEI, + model_kwargs=self.step_model_kwargs, + model_gen_kwargs={}, + ) + self.sobol_node = GenerationNode( + node_name="sobol_node", + transition_criteria=self.sobol_criterion, + model_specs=[self.sobol_model_spec], + gen_unlimited_trials=False, + ) + self.gpei_node = GenerationNode( + node_name="GPEI_node", + transition_criteria=self.gpei_criterion, + model_specs=[self.gpei_model_spec], + gen_unlimited_trials=False, + ) + + self.sobol_GPEI_GS_nodes = GenerationStrategy( + name="Sobol+GPEI_Nodes", + nodes=[self.sobol_node, self.gpei_node], + ) + def tearDown(self) -> None: self.torch_model_bridge_patcher.stop() self.discrete_model_bridge_patcher.stop() @@ -1163,58 +1210,13 @@ def test_gs_setup_with_nodes(self) -> None: def test_gs_with_generation_nodes(self) -> None: "Simple test of a SOBOL + GPEI GenerationStrategy composed of GenerationNodes" - sobol_criterion = [ - MaxTrials( - threshold=5, - transition_to="GPEI_node", - block_gen_if_met=True, - only_in_statuses=None, - not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - ) - ] - gpei_criterion = [ - MaxTrials( - threshold=2, - transition_to=None, - block_gen_if_met=True, - only_in_statuses=None, - not_in_statuses=[TrialStatus.FAILED, TrialStatus.ABANDONED], - ) - ] - sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, - model_kwargs=self.step_model_kwargs, - model_gen_kwargs={}, - ) - gpei_model_spec = ModelSpec( - 
model_enum=Models.GPEI, - model_kwargs=self.step_model_kwargs, - model_gen_kwargs={}, - ) - sobol_node = GenerationNode( - node_name="sobol_node", - transition_criteria=sobol_criterion, - model_specs=[sobol_model_spec], - gen_unlimited_trials=False, - ) - gpei_node = GenerationNode( - node_name="GPEI_node", - transition_criteria=gpei_criterion, - model_specs=[gpei_model_spec], - gen_unlimited_trials=False, - ) - - sobol_GPEI_GS_nodes = GenerationStrategy( - name="Sobol+GPEI_Nodes", - nodes=[sobol_node, gpei_node], - ) exp = get_branin_experiment() - self.assertEqual(sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") + self.assertEqual(self.sobol_GPEI_GS_nodes.name, "Sobol+GPEI_Nodes") for i in range(7): - g = sobol_GPEI_GS_nodes.gen(exp) + g = self.sobol_GPEI_GS_nodes.gen(exp) exp.new_trial(generator_run=g).run() - self.assertEqual(len(sobol_GPEI_GS_nodes._generator_runs), i + 1) + self.assertEqual(len(self.sobol_GPEI_GS_nodes._generator_runs), i + 1) if i > 4: self.mock_torch_model_bridge.assert_called() else: @@ -1263,6 +1265,19 @@ def test_gs_with_generation_nodes(self) -> None: del ms["generated_points"] self.assertEqual(ms, {"init_position": i + 1}) + def test_clone_reset_nodes(self) -> None: + """Test that node-based generation strategy is appropriately reset + when cloned with `clone_reset`. 
+ """ + exp = get_branin_experiment() + for i in range(7): + g = self.sobol_GPEI_GS_nodes.gen(exp) + exp.new_trial(generator_run=g).run() + self.assertEqual(len(self.sobol_GPEI_GS_nodes._generator_runs), i + 1) + gs_clone = self.sobol_GPEI_GS_nodes.clone_reset() + self.assertEqual(gs_clone.name, self.sobol_GPEI_GS_nodes.name) + self.assertEqual(gs_clone._generator_runs, []) + def test_gs_with_nodes_and_blocking_criteria(self) -> None: sobol_model_spec = ModelSpec( model_enum=Models.SOBOL, @@ -1380,8 +1395,6 @@ def test_generation_strategy_eq_no_print(self) -> None: GenerationStep(model=Models.GPEI, num_trials=-1), ] ) - print(gs1) - print(gs2) self.assertEqual(gs1, gs2) # ------------- Testing helpers (put tests above this line) ------------- diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 096eb1529f1..d7760e5807b 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1405,7 +1405,6 @@ def test_EncodeDecodeGenerationStrategyReducedState(self) -> None: # Reloaded generation strategy will not have attributes associated with fitting # the model until after it's used to fit the model or generate candidates, so # we unset those attributes here and compare equality of the rest. - generation_strategy._seen_trial_indices_by_status = None generation_strategy._model = None self.assertEqual(new_generation_strategy, generation_strategy) # Model should be successfully restored in generation strategy even with @@ -1520,7 +1519,6 @@ def test_UpdateGenerationStrategy(self) -> None: # Reloaded generation strategy will not have attributes associated with fitting # the model until after it's used to fit the model or generate candidates, so # we unset those attributes here and compare equality of the rest. 
- generation_strategy._seen_trial_indices_by_status = None generation_strategy._model = None self.assertEqual(generation_strategy, loaded_generation_strategy) self.assertIsNotNone(loaded_generation_strategy._experiment)