Skip to content

Commit

Permalink
Rework TransitionCriterion storage to remove circular dep (facebook#2320
Browse files Browse the repository at this point in the history
)

Summary:

In D52982664 I introduced a circular dep by moving some storage related logic directly onto transitioncriterion class. In the past week and half, this has been pretty annoying for folks on the team (sorry everyone!). This diff fixes that dep by:

1. Keeping all storage related logic in the storage files
2. Adding a specific check in object_from_json for ```TrialBasedCriterion``. These criterion contain lists of TrialStatuses, specifically in not_in_statuses and only_in_statuses args, which makes the standard deserialization fail to properly unpack these lists in a way that maintains the TrialStatuses type. Here we add a special method to do so.
3. All other transitioncriterion can continue using deserialize method

There is a larger question here about ideally if a class inherits from SerializationMixin and all it's fields also inherit from SerializationMixin that the deserialization should eloquently handle this. I tried to find a solution for that for a bit in this period, but i kept introducing circular deps and it seems to be a larger undertaking than i have scope for at the time.

Differential Revision: D55727618
  • Loading branch information
mgarrard authored and facebook-github-bot committed Apr 4, 2024
1 parent 1b29a78 commit ee768de
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 34 deletions.
36 changes: 2 additions & 34 deletions ax/modelbridge/transition_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import abstractmethod
from logging import Logger
from typing import Any, Dict, List, Optional, Set
from typing import List, Optional, Set

from ax.core.base_trial import TrialStatus
from ax.core.experiment import Experiment
Expand All @@ -16,13 +16,7 @@

from ax.utils.common.base import SortableBase
from ax.utils.common.logger import get_logger
from ax.utils.common.serialization import (
SerializationMixin,
serialize_init_args,
TClassDecoderRegistry,
TDecoderRegistry,
)
from ax.utils.common.typeutils import not_none
from ax.utils.common.serialization import SerializationMixin, serialize_init_args

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -146,32 +140,6 @@ def __init__(
block_gen_if_met=block_gen_if_met,
)

@classmethod
def deserialize_init_args(
cls,
args: Dict[str, Any],
decoder_registry: Optional[TDecoderRegistry] = None,
class_decoder_registry: Optional[TClassDecoderRegistry] = None,
) -> Dict[str, Any]:
"""Given a dictionary, extract the properties needed to initialize the object.
Used for storage.
"""
# import here to avoid circular import
from ax.storage.json_store.decoder import object_from_json

decoder_registry = not_none(decoder_registry)
class_decoder_registry = class_decoder_registry or {}
init_args = super().deserialize_init_args(args=args)

return {
key: object_from_json(
object_json=value,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
for key, value in init_args.items()
}

def experiment_trials_by_status(
self, experiment: Experiment, statuses: List[TrialStatus]
) -> Set[int]:
Expand Down
43 changes: 43 additions & 0 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import _decode_callables_from_references
from ax.modelbridge.transition_criterion import TransitionCriterion, TrialBasedCriterion
from ax.models.torch.botorch_modular.model import SurrogateSpec
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.storage.json_store.decoders import (
Expand Down Expand Up @@ -249,6 +250,15 @@ def object_from_json(
object_json["outcome_transform_options"] = (
outcome_transform_options_json
)
elif isclass(_class) and issubclass(_class, TrialBasedCriterion):
# TrialBasedCriterion contain a list of `TrialStatus` for args.
# This list needs to be unpacked by hand to properly retain the types.
return trial_transition_criteria_from_json(
class_=_class,
transition_criteria_json=object_json,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
elif isclass(_class) and issubclass(_class, SerializationMixin):
return _class(
**_class.deserialize_init_args(
Expand Down Expand Up @@ -343,6 +353,39 @@ def generator_run_from_json(
return generator_run


def trial_transition_criteria_from_json(
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use `typing.Type` to
# avoid runtime subscripting errors.
class_: Type,
transition_criteria_json: Dict[str, Any],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
class_decoder_registry: Dict[
str, Callable[[Dict[str, Any]], Any]
] = CORE_CLASS_DECODER_REGISTRY,
) -> Optional[TransitionCriterion]:
"""Load Ax transition criteria that depend on Trials from JSON.
Since ```TrialBasedCriterion``` contain lists of ```TrialStatus``,
the json for these criterion needs to be carefully unpacked and
re-processed via ```object_from_json``` in order to maintain correct
typing. We pass in ```class_``` in order to correctly handle all classes
which inherit from ```TrialBasedCriterion``` (ex: ```MaxTrials```).
"""
new_dict = {}
for key, value in transition_criteria_json.items():
new_val = object_from_json(
object_json=value,
decoder_registry=decoder_registry,
class_decoder_registry=class_decoder_registry,
)
new_dict[key] = new_val

return class_(**new_dict)


def search_space_from_json(
search_space_json: Dict[str, Any],
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
Expand Down
4 changes: 4 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@
get_surrogate,
get_synthetic_runner,
get_threshold_early_stopping_strategy,
get_transition_criterion_list,
get_trial,
get_trial_status,
get_winsorization_config,
)
from ax.utils.testing.modeling_stubs import (
Expand Down Expand Up @@ -218,8 +220,10 @@
("Type[Transform]", get_transform_type),
("Type[InputTransform]", get_input_transform_type),
("Type[OutcomeTransform]", get_outcome_transfrom_type),
("TransitionCriterionList", get_transition_criterion_list),
("ThresholdEarlyStoppingStrategy", get_threshold_early_stopping_strategy),
("Trial", get_trial),
("TrialStatus", get_trial_status),
("WinsorizationConfig", get_winsorization_config),
("SEBOAcquisition", get_sebo_acquisition_class),
]
Expand Down
26 changes: 26 additions & 0 deletions ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
from ax.metrics.hartmann6 import AugmentedHartmann6Metric, Hartmann6Metric
from ax.modelbridge.factory import Cont_X_trans, get_factorial, get_sobol
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transition_criterion import (
MaxGenerationParallelism,
MaxTrials,
TransitionCriterion,
)
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
from ax.models.torch.botorch_modular.sebo import SEBOAcquisition
Expand Down Expand Up @@ -154,6 +159,27 @@ def get_experiment_with_map_data_type() -> Experiment:
)


def get_trial_status() -> List[TrialStatus]:
return [TrialStatus.CANDIDATE, TrialStatus.RUNNING, TrialStatus.COMPLETED]


def get_transition_criterion_list() -> List[TransitionCriterion]:
return [
MaxTrials(
threshold=3,
only_in_statuses=[TrialStatus.RUNNING, TrialStatus.COMPLETED],
not_in_statuses=None,
),
MaxGenerationParallelism(
threshold=5,
only_in_statuses=None,
not_in_statuses=[
TrialStatus.RUNNING,
],
),
]


def get_experiment_with_custom_runner_and_metric(
constrain_search_space: bool = True,
immutable: bool = False,
Expand Down

0 comments on commit ee768de

Please sign in to comment.