diff --git a/ax/modelbridge/transition_criterion.py b/ax/modelbridge/transition_criterion.py index a3d0c93b0f3..b129267b0eb 100644 --- a/ax/modelbridge/transition_criterion.py +++ b/ax/modelbridge/transition_criterion.py @@ -7,7 +7,7 @@ from abc import abstractmethod from logging import Logger -from typing import Any, Dict, List, Optional, Set +from typing import List, Optional, Set from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment @@ -16,13 +16,7 @@ from ax.utils.common.base import SortableBase from ax.utils.common.logger import get_logger -from ax.utils.common.serialization import ( - SerializationMixin, - serialize_init_args, - TClassDecoderRegistry, - TDecoderRegistry, -) -from ax.utils.common.typeutils import not_none +from ax.utils.common.serialization import SerializationMixin, serialize_init_args logger: Logger = get_logger(__name__) @@ -146,32 +140,6 @@ def __init__( block_gen_if_met=block_gen_if_met, ) - @classmethod - def deserialize_init_args( - cls, - args: Dict[str, Any], - decoder_registry: Optional[TDecoderRegistry] = None, - class_decoder_registry: Optional[TClassDecoderRegistry] = None, - ) -> Dict[str, Any]: - """Given a dictionary, extract the properties needed to initialize the object. - Used for storage. - """ - # import here to avoid circular import - from ax.storage.json_store.decoder import object_from_json - - decoder_registry = not_none(decoder_registry) - class_decoder_registry = class_decoder_registry or {} - init_args = super().deserialize_init_args(args=args) - - return { - key: object_from_json( - object_json=value, - decoder_registry=decoder_registry, - class_decoder_registry=class_decoder_registry, - ) - for key, value in init_args.items() - } - def experiment_trials_by_status( self, experiment: Experiment, statuses: List[TrialStatus] ) -> Set[int]: diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 4dfb7e42429..872c28589f4 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -39,6 +39,7 @@ ) from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.registry import _decode_callables_from_references +from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion from ax.models.torch.botorch_modular.model import SurrogateSpec from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.storage.json_store.decoders import ( @@ -249,6 +250,15 @@ def object_from_json( object_json["outcome_transform_options"] = ( outcome_transform_options_json ) + elif isclass(_class) and issubclass(_class, TrialBasedCriterion): + # TrialBasedCriterion contain a list of `TrialStatus` for args. + # This list needs to be unpacked by hand to properly retain the types. + return trial_transition_criteria_from_json( + class_=_class, + transition_criteria_json=object_json, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) elif isclass(_class) and issubclass(_class, SerializationMixin): return _class( **_class.deserialize_init_args( @@ -343,6 +353,39 @@ def generator_run_from_json( return generator_run +def trial_transition_criteria_from_json( + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to + # avoid runtime subscripting errors. + class_: Type, + transition_criteria_json: Dict[str, Any], + # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use + # `typing.Type` to avoid runtime subscripting errors. + decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY, + # pyre-fixme[2]: Parameter annotation cannot contain `Any`. + class_decoder_registry: Dict[ + str, Callable[[Dict[str, Any]], Any] + ] = CORE_CLASS_DECODER_REGISTRY, +) -> Optional[TransitionCriterion]: + """Load Ax transition criteria that depend on Trials from JSON. + + Since ```TrialBasedCriterion``` contain lists of ```TrialStatus``, + the json for these criterion needs to be carefully unpacked and + re-processed via ```object_from_json``` in order to maintain correct + typing. We pass in ```class_``` in order to correctly handle all classes + which inherit from ```TrialBasedCriterion``` (ex: ```MaxTrials```). + """ + new_dict = {} + for key, value in transition_criteria_json.items(): + new_val = object_from_json( + object_json=value, + decoder_registry=decoder_registry, + class_decoder_registry=class_decoder_registry, + ) + new_dict[key] = new_val + + return class_(**new_dict) + + def search_space_from_json( search_space_json: Dict[str, Any], # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index aff4c89f9a4..bb83650f759 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -114,6 +114,7 @@ get_synthetic_runner, get_threshold_early_stopping_strategy, get_trial, + get_trial_status, get_winsorization_config, ) from ax.utils.testing.modeling_stubs import ( @@ -155,6 +156,7 @@ ("FixedParameter", get_fixed_parameter), ("GammaPrior", get_gamma_prior), ("GenerationStrategy", partial(get_generation_strategy, with_experiment=True)), + ("TrialStatus", get_trial_status), ( "GenerationStrategy", partial( diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index b6e69af956a..49e26413066 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -154,6 +154,10 @@ def get_experiment_with_map_data_type() -> Experiment: ) +def get_trial_status() -> List[TrialStatus]: + return [TrialStatus.CANDIDATE, TrialStatus.RUNNING, TrialStatus.COMPLETED] + + def get_experiment_with_custom_runner_and_metric( constrain_search_space: bool = True, immutable: bool = False,