From ee768decae9c7b283f93c8ac6783004c75438ded Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 4 Apr 2024 09:37:59 -0700 Subject: [PATCH] Rework TransitionCriterion storage to remove circular dep (#2320) Summary: In D52982664 I introduced a circular dep by moving some storage related logic directly onto transitioncriterion class. In the past week and half, this has been pretty annoying for folks on the team (sorry everyone!). This diff fixes that dep by: 1. Keeping all storage related logic in the storage files 2. Adding a specific check in object_from_json for ```TrialBasedCriterion``. These criterion contain lists of TrialStatuses, specifically in not_in_statuses and only_in_statuses args, which makes the standard deserialization fail to properly unpack these lists in a way that maintains the TrialStatuses type. Here we add a special method to do so. 3. All other transitioncriterion can continue using deserialize method There is a larger question here about ideally if a class inherits from SerializationMixin and all it's fields also inherit from SerializationMixin that the deserialization should eloquently handle this. I tried to find a solution for that for a bit in this period, but i kept introducing circular deps and it seems to be a larger undertaking than i have scope for at the time. Differential Revision: D55727618 --- ax/modelbridge/transition_criterion.py | 36 +--------------- ax/storage/json_store/decoder.py | 43 +++++++++++++++++++ .../json_store/tests/test_json_store.py | 4 ++ ax/utils/testing/core_stubs.py | 26 +++++++++++ 4 files changed, 75 insertions(+), 34 deletions(-) diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index a3d0c93b0f3..b129267b0eb 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -7,7 +7,7 @@ from abc import abstractmethod from logging import Logger -from typing import Any, Dict, List, Optional, Set +from typing import List, Optional, Set from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment @@ -16,13 +16,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.serialization import ( - SerializationMixin, - serialize_init_args, - TClassDecoderRegistry, - TDecoderRegistry, -) -from ax.utils.common.typeutils import not_none +from ax.utils.common.serialization import SerializationMixin, serialize_init_args logger: Logger = get_logger(__name__) @@ -146,32 +140,6 @@ def __init__( block_gen_if_met=block_gen_if_met, ) - @classmethod - def deserialize_init_args( - cls, - args: Dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, - ) -> Dict[str, Any]: - """Given a dictionary, extract the properties needed to initialize the object. - Used for storage. - """ - # import here to avoid circular import - from ax.storage.json_store.decoder import object_from_json - - decoder_registry = not_none(decoder_registry) - class_decoder_registry = class_decoder_registry or {} - init_args = super().deserialize_init_args(args=args) - - return { - key: object_from_json( - object_json=value, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for key, value in init_args.items() - } - def experiment_trials_by_status( self, experiment: Experiment, statuses: List[TrialStatus] ) -> Set[int]: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 4dfb7e42429..872c28589f4 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -39,6 +39,7 @@ ) from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import _decode_callables_from_references +from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -249,6 +250,15 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) + elif isclass(_class) and issubclass(_class, TrialBasedCriterion): + # TrialBasedCriterion contain a list of `TrialStatus` for args. + # This list needs to be unpacked by hand to properly retain the types. + return trial_transition_criteria_from_json( + class_=_class, + transition_criteria_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) elif isclass(_class) and issubclass(_class, SerializationMixin): return _class( **_class.deserialize_init_args( @@ -343,6 +353,39 @@ def generator_run_from_json( return generator_run +def trial_transition_criteria_from_json( + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to + # avoid runtime subscripting errors. + class_: Type, + transition_criteria_json: Dict[str, Any], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type` to avoid runtime subscripting errors. + decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. + class_decoder_registry: Dict[ + str, Callable[[Dict[str, Any]], Any] + ] = CORE_CLASS_DECODER_REGISTRY, +) -> Optional[TransitionCriterion]: + """Load Ax transition criteria that depend on Trials from JSON. + + Since ```TrialBasedCriterion``` contain lists of ```TrialStatus``, + the json for these criterion needs to be carefully unpacked and + re-processed via ```object_from_json``` in order to maintain correct + typing. We pass in ```class_``` in order to correctly handle all classes + which inherit from ```TrialBasedCriterion``` (ex: ```MaxTrials```). + """ + new_dict = {} + for key, value in transition_criteria_json.items(): + new_val = object_from_json( + object_json=value, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + new_dict[key] = new_val + + return class_(**new_dict) + + def search_space_from_json( search_space_json: Dict[str, Any], # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index aff4c89f9a4..1a1c5d41557 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -113,7 +113,9 @@ get_surrogate, get_synthetic_runner, get_threshold_early_stopping_strategy, + get_transition_criterion_list, get_trial, + get_trial_status, get_winsorization_config, ) from ax.utils.testing.modeling_stubs import ( @@ -218,8 +220,10 @@ ("Type[Transform]", get_transform_type), ("Type[InputTransform]", get_input_transform_type), ("Type[OutcomeTransform]", get_outcome_transfrom_type), + ("TransitionCriterionList", get_transition_criterion_list), ("ThresholdEarlyStoppingStrategy", get_threshold_early_stopping_strategy), ("Trial", get_trial), + ("TrialStatus", get_trial_status), ("WinsorizationConfig", get_winsorization_config), ("SEBOAcquisition", get_sebo_acquisition_class), ] diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index b6e69af956a..9ef825819d3 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -96,6 +96,11 @@ from ax.metrics.hartmann6 import AugmentedHartmann6Metric, Hartmann6Metric from ax.modelbridge.factory import Cont_X_trans, get_factorial, get_sobol from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.modelbridge.transition_criterion import ( + MaxGenerationParallelism, + MaxTrials, + TransitionCriterion, +) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec from ax.models.torch.botorch_modular.sebo import SEBOAcquisition @@ -154,6 +159,27 @@ def get_experiment_with_map_data_type() -> Experiment: ) +def get_trial_status() -> List[TrialStatus]: + return [TrialStatus.CANDIDATE, TrialStatus.RUNNING, TrialStatus.COMPLETED] + + +def get_transition_criterion_list() -> List[TransitionCriterion]: + return [ + MaxTrials( + threshold=3, + only_in_statuses=[TrialStatus.RUNNING, TrialStatus.COMPLETED], + not_in_statuses=None, + ), + MaxGenerationParallelism( + threshold=5, + only_in_statuses=None, + not_in_statuses=[ + TrialStatus.RUNNING, + ], + ), + ] + + def get_experiment_with_custom_runner_and_metric( constrain_search_space: bool = True, immutable: bool = False,