From 7dcbfa3f92e0d981c4b56b137b34e53f41a04237 Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Thu, 4 Apr 2024 07:46:22 -0700 Subject: [PATCH] Rework TransitionCriterion storage to remove circular dep (#2320) Summary: In D52982664 I introduced a circular dep by moving some storage related logic directly onto transitioncriterion class. In the past week and half, this has been pretty annoying for folks on the team (sorry everyone!). This diff fixes that dep by: 1. Keeping all storage related logic in the storage files 2. Adding a specific check in object_from_json for ```TrialBasedCriterion``. These criterion contain lists of TrialStatuses, specifically in not_in_statuses and only_in_statuses args, which makes the standard deserialization fail to properly unpack these lists in a way that maintains the TrialStatuses type. Here we add a special method to do so. 3. All other transitioncriterion can continue using deserialize method There is a larger question here about ideally if a class inherits from SerializationMixin and all it's fields also inherit from SerializationMixin that the deserialization should eloquently handle this. I tried to find a solution for that for a bit in this period, but i kept introducing circular deps and it seems to be a larger undertaking than i have scope for at the time. Differential Revision: D55727618 --- ax/core/base_trial.py | 7 ++- ax/modelbridge/transition_criterion.py | 36 +--------------- ax/storage/json_store/decoder.py | 43 +++++++++++++++++++ .../json_store/tests/test_json_store.py | 2 + ax/utils/common/serialization.py | 14 ++++-- ax/utils/testing/core_stubs.py | 4 ++ 6 files changed, 67 insertions(+), 39 deletions(-) diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index dada3b0f42d..bfe457d7691 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -25,6 +25,7 @@ from ax.core.types import TCandidateMetadata, TEvaluationOutcome from ax.exceptions.core import UnsupportedError from ax.utils.common.base import SortableBase +from ax.utils.common.serialization import SerializationMixin from ax.utils.common.typeutils import not_none @@ -33,7 +34,11 @@ from ax import core # noqa F401 -class TrialStatus(int, Enum): +class TrialStatus( + SerializationMixin, + int, + Enum, +): """Enum of trial status. General lifecycle of a trial is::: diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index a3d0c93b0f3..b129267b0eb 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -7,7 +7,7 @@ from abc import abstractmethod from logging import Logger -from typing import Any, Dict, List, Optional, Set +from typing import List, Optional, Set from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment @@ -16,13 +16,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.serialization import ( - SerializationMixin, - serialize_init_args, - TClassDecoderRegistry, - TDecoderRegistry, -) -from ax.utils.common.typeutils import not_none +from ax.utils.common.serialization import SerializationMixin, serialize_init_args logger: Logger = get_logger(__name__) @@ -146,32 +140,6 @@ def __init__( block_gen_if_met=block_gen_if_met, ) - @classmethod - def deserialize_init_args( - cls, - args: Dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, - ) -> Dict[str, Any]: - """Given a dictionary, extract the properties needed to initialize the object. - Used for storage. - """ - # import here to avoid circular import - from ax.storage.json_store.decoder import object_from_json - - decoder_registry = not_none(decoder_registry) - class_decoder_registry = class_decoder_registry or {} - init_args = super().deserialize_init_args(args=args) - - return { - key: object_from_json( - object_json=value, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for key, value in init_args.items() - } - def experiment_trials_by_status( self, experiment: Experiment, statuses: List[TrialStatus] ) -> Set[int]: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 4dfb7e42429..872c28589f4 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -39,6 +39,7 @@ ) from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import _decode_callables_from_references +from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -249,6 +250,15 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) + elif isclass(_class) and issubclass(_class, TrialBasedCriterion): + # TrialBasedCriterion contain a list of `TrialStatus` for args. + # This list needs to be unpacked by hand to properly retain the types. + return trial_transition_criteria_from_json( + class_=_class, + transition_criteria_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) elif isclass(_class) and issubclass(_class, SerializationMixin): return _class( **_class.deserialize_init_args( @@ -343,6 +353,39 @@ def generator_run_from_json( return generator_run +def trial_transition_criteria_from_json( + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to + # avoid runtime subscripting errors. + class_: Type, + transition_criteria_json: Dict[str, Any], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type` to avoid runtime subscripting errors. + decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. + class_decoder_registry: Dict[ + str, Callable[[Dict[str, Any]], Any] + ] = CORE_CLASS_DECODER_REGISTRY, +) -> Optional[TransitionCriterion]: + """Load Ax transition criteria that depend on Trials from JSON. + + Since ```TrialBasedCriterion``` contain lists of ```TrialStatus``, + the json for these criterion needs to be carefully unpacked and + re-processed via ```object_from_json``` in order to maintain correct + typing. We pass in ```class_``` in order to correctly handle all classes + which inherit from ```TrialBasedCriterion``` (ex: ```MaxTrials```). + """ + new_dict = {} + for key, value in transition_criteria_json.items(): + new_val = object_from_json( + object_json=value, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + new_dict[key] = new_val + + return class_(**new_dict) + + def search_space_from_json( search_space_json: Dict[str, Any], # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index aff4c89f9a4..bb83650f759 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -114,6 +114,7 @@ get_synthetic_runner, get_threshold_early_stopping_strategy, get_trial, + get_trial_status, get_winsorization_config, ) from ax.utils.testing.modeling_stubs import ( @@ -155,6 +156,7 @@ ("FixedParameter", get_fixed_parameter), ("GammaPrior", get_gamma_prior), ("GenerationStrategy", partial(get_generation_strategy, with_experiment=True)), + ("TrialStatus", get_trial_status), ( "GenerationStrategy", partial( diff --git a/ax/utils/common/serialization.py b/ax/utils/common/serialization.py index 0fa936c9b4b..036b2dc4168 100644 --- a/ax/utils/common/serialization.py +++ b/ax/utils/common/serialization.py @@ -101,9 +101,12 @@ def serialize_init_args( return properties -# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to -# avoid runtime subscripting errors. -def extract_init_args(args: Dict[str, Any], class_: Type) -> Dict[str, Any]: +def extract_init_args( + args: Dict[str, Any], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to + # avoid runtime subscripting errors. + class_: Type, +) -> Dict[str, Any]: """Given a dictionary, extract the arguments required for the given class's constructor. """ @@ -147,4 +150,7 @@ def deserialize_init_args( """Given a dictionary, deserialize the properties needed to initialize the object. Used for storage. """ - return extract_init_args(args=args, class_=cls) + return extract_init_args( + args=args, + class_=cls, + ) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index b6e69af956a..49e26413066 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -154,6 +154,10 @@ def get_experiment_with_map_data_type() -> Experiment: ) +def get_trial_status() -> List[TrialStatus]: + return [TrialStatus.CANDIDATE, TrialStatus.RUNNING, TrialStatus.COMPLETED] + + def get_experiment_with_custom_runner_and_metric( constrain_search_space: bool = True, immutable: bool = False,