From 2ffd7aa5e4ef1f6808bddd3c2689416c5dc79a98 Mon Sep 17 00:00:00 2001 From: Lena Kashtelyan Date: Mon, 10 Feb 2025 13:05:54 -0800 Subject: [PATCH] Move `TrialStatus` out to its own file (preliminary cleanup) (#3325) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3325 Reviewed By: saitcakmak Differential Revision: D69223588 --- .../tests/test_can_generate_candidates.py | 2 +- ax/benchmark/benchmark.py | 2 +- ax/benchmark/tests/test_benchmark_runner.py | 2 +- ax/core/base_trial.py | 140 +--------------- ax/core/experiment.py | 12 +- ax/core/observation.py | 2 +- ax/core/tests/test_batch_trial.py | 2 +- ax/core/tests/test_observation.py | 2 +- ax/core/tests/test_utils.py | 2 +- ax/core/trial_status.py | 149 ++++++++++++++++++ ax/early_stopping/strategies/base.py | 2 +- ax/early_stopping/tests/test_strategies.py | 2 +- ax/early_stopping/utils.py | 2 +- ax/generation_strategy/generation_node.py | 2 +- .../generation_node_input_constructors.py | 2 +- ax/generation_strategy/generation_strategy.py | 2 +- .../tests/test_aepsych_criterion.py | 2 +- .../tests/test_generation_node.py | 4 +- .../tests/test_generation_strategy.py | 2 +- .../tests/test_transition_criterion.py | 2 +- .../transition_criterion.py | 4 +- ax/global_stopping/strategies/base.py | 3 +- ax/global_stopping/tests/test_strategies.py | 2 +- ax/modelbridge/base.py | 2 +- ax/modelbridge/map_torch.py | 2 +- .../tests/test_map_torch_modelbridge.py | 4 +- ax/preview/api/client.py | 2 +- ax/preview/api/protocols/runner.py | 2 +- ax/preview/api/tests/test_client.py | 2 +- ax/preview/modelbridge/dispatch_utils.py | 2 +- .../tests/test_preview_dispatch_utils.py | 2 +- .../tests/test_single_running_trial_mixin.py | 2 +- .../tests/test_with_db_settings_base.py | 3 +- ax/service/utils/report_utils.py | 2 +- ax/storage/json_store/decoders.py | 2 +- ax/storage/json_store/registry.py | 2 +- ax/storage/sqa_store/sqa_classes.py | 3 +- ax/storage/sqa_store/tests/test_sqa_store.py | 2 +- ax/utils/testing/backend_simulator.py | 2 +- ax/utils/testing/modeling_stubs.py | 2 +- .../testing/tests/test_backend_simulator.py | 2 +- sphinx/source/core.rst | 9 ++ 42 files changed, 209 insertions(+), 186 deletions(-) create mode 100644 ax/core/trial_status.py diff --git a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py index 6ccd9754308..3ecd71faf23 100644 --- a/ax/analysis/healthcheck/tests/test_can_generate_candidates.py +++ b/ax/analysis/healthcheck/tests/test_can_generate_candidates.py @@ -13,7 +13,7 @@ CanGenerateCandidatesAnalysis, ) from ax.analysis.healthcheck.healthcheck_analysis import HealthcheckStatus -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from pandas import testing as pdt diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index 75e22899f77..c1a58be0534 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -35,11 +35,11 @@ from ax.benchmark.benchmark_test_function import BenchmarkTestFunction from ax.benchmark.methods.sobol import get_sobol_generation_strategy from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.objective import MultiObjective from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus from ax.core.types import TParameterization, TParamValue from ax.core.utils import get_model_times from ax.service.scheduler import Scheduler diff --git a/ax/benchmark/tests/test_benchmark_runner.py b/ax/benchmark/tests/test_benchmark_runner.py index 3b0698bda28..c21bd5d37d7 100644 --- a/ax/benchmark/tests/test_benchmark_runner.py +++ b/ax/benchmark/tests/test_benchmark_runner.py @@ -25,11 +25,11 @@ Jenatton, ) from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.experiment import Experiment from ax.core.search_space import SearchSpace from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase from ax.utils.testing.benchmark_stubs import ( diff --git a/ax/core/base_trial.py b/ax/core/base_trial.py index 8cefd4cec38..a5945b68043 100644 --- a/ax/core/base_trial.py +++ b/ax/core/base_trial.py @@ -12,7 +12,6 @@ from collections.abc import Callable from copy import deepcopy from datetime import datetime, timedelta -from enum import Enum from typing import Any, TYPE_CHECKING from ax.core.arm import Arm @@ -23,6 +22,7 @@ from ax.core.map_metric import MapMetric from ax.core.metric import Metric, MetricFetchResult from ax.core.runner import Runner +from ax.core.trial_status import TrialStatus from ax.core.types import TCandidateMetadata, TEvaluationOutcome from ax.exceptions.core import UnsupportedError from ax.utils.common.base import SortableBase @@ -34,144 +34,6 @@ from ax import core # noqa F401 -class TrialStatus(int, Enum): - """Enum of trial status. - - General lifecycle of a trial is::: - - CANDIDATE --> STAGED --> RUNNING --> COMPLETED - -------------> --> FAILED (retryable) - --> EARLY_STOPPED (deemed unpromising) - -------------------------> ABANDONED (non-retryable) - - Trial is marked as a ``CANDIDATE`` immediately upon its creation. - - Trials may be abandoned at any time prior to completion or failure. - The difference between abandonment and failure is that the ``FAILED`` state - is meant to express a possibly transient or retryable error, so trials in - that state may be re-run and arm(s) in them may be resuggested by Ax models - to be added to new trials. - - ``ABANDONED`` trials on the other end, indicate - that the trial (and arms(s) in it) should not be rerun or added to new - trials. A trial might be marked ``ABANDONED`` as a result of human-initiated - action (if some trial in experiment is poorly-performing, deterministically - failing etc., and should not be run again in the experiment). It might also - be marked ``ABANDONED`` in an automated way if the trial's execution - encounters an error that indicates that the arm(s) in the trial should bot - be evaluated in the experiment again (e.g. the parameterization in a given - arm deterministically causes trial evaluation to fail). Note that it's also - possible to abandon a single arm in a `BatchTrial` via - ``batch.mark_arm_abandoned``. - - Early-stopped refers to trials that were deemed - unpromising by an early-stopping strategy and therefore terminated. - - Additionally, when trials are deployed, they may be in an intermediate - staged state (e.g. scheduled but waiting for resources) or immediately - transition to running. Note that ``STAGED`` trial status is not always - applicable and depends on the ``Runner`` trials are deployed with - (and whether a ``Runner`` is present at all; for example, in Ax Service - API, trials are marked as ``RUNNING`` immediately when generated from - ``get_next_trial``, skipping the ``STAGED`` status). - - NOTE: Data for abandoned trials (or abandoned arms in batch trials) is - not passed to the model as part of training data, unless ``fit_abandoned`` - option is specified to model bridge. Additionally, data from MapMetrics is - typically excluded unless the corresponding trial is completed. - """ - - CANDIDATE = 0 - STAGED = 1 - FAILED = 2 - COMPLETED = 3 - RUNNING = 4 - ABANDONED = 5 - DISPATCHED = 6 # Deprecated. - EARLY_STOPPED = 7 - - @property - def is_terminal(self) -> bool: - """True if trial is completed.""" - return ( - self == TrialStatus.ABANDONED - or self == TrialStatus.COMPLETED - or self == TrialStatus.FAILED - or self == TrialStatus.EARLY_STOPPED - ) - - @property - def expecting_data(self) -> bool: - """True if trial is expecting data.""" - return self in STATUSES_EXPECTING_DATA - - @property - def is_deployed(self) -> bool: - """True if trial has been deployed but not completed.""" - return self == TrialStatus.STAGED or self == TrialStatus.RUNNING - - @property - def is_failed(self) -> bool: - """True if this trial is a failed one.""" - return self == TrialStatus.FAILED - - @property - def is_abandoned(self) -> bool: - """True if this trial is an abandoned one.""" - return self == TrialStatus.ABANDONED - - @property - def is_candidate(self) -> bool: - """True if this trial is a candidate.""" - return self == TrialStatus.CANDIDATE - - @property - def is_completed(self) -> bool: - """True if this trial is a successfully completed one.""" - return self == TrialStatus.COMPLETED - - @property - def is_running(self) -> bool: - """True if this trial is a running one.""" - return self == TrialStatus.RUNNING - - @property - def is_early_stopped(self) -> bool: - """True if this trial is an early stopped one.""" - return self == TrialStatus.EARLY_STOPPED - - def __format__(self, fmt: str) -> str: - """Define `__format__` to avoid pulling the `__format__` from the `int` - mixin (since its better for statuses to show up as `RUNNING` than as - just an int that is difficult to interpret). - - E.g. batch trial representation with the overridden method is: - "BatchTrial(experiment_name='test', index=0, status=TrialStatus.CANDIDATE)". - - Docs on enum formatting: https://docs.python.org/3/library/enum.html#others. - """ - return f"{self!s}" - - def __repr__(self) -> str: - return f"{self.__class__}.{self.name}" - - -DEFAULT_STATUSES_TO_WARM_START: list[TrialStatus] = [ - TrialStatus.RUNNING, - TrialStatus.COMPLETED, - TrialStatus.ABANDONED, - TrialStatus.EARLY_STOPPED, -] - -NON_ABANDONED_STATUSES: set[TrialStatus] = set(TrialStatus) - {TrialStatus.ABANDONED} - -STATUSES_EXPECTING_DATA: list[TrialStatus] = [ - TrialStatus.RUNNING, - TrialStatus.COMPLETED, - TrialStatus.EARLY_STOPPED, -] - - def immutable_once_run(func: Callable) -> Callable: """Decorator for methods that should throw Error when trial is running or has ever run and immutable. diff --git a/ax/core/experiment.py b/ax/core/experiment.py index b2509a11047..9eea5b7e6a6 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -22,12 +22,7 @@ import pandas as pd from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose -from ax.core.base_trial import ( - BaseTrial, - DEFAULT_STATUSES_TO_WARM_START, - STATUSES_EXPECTING_DATA, - TrialStatus, -) +from ax.core.base_trial import BaseTrial from ax.core.batch_trial import BatchTrial, LifecycleStage from ax.core.data import Data from ax.core.formatting_utils import DATA_TYPE_LOOKUP, DataType @@ -41,6 +36,11 @@ from ax.core.runner import Runner from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.core.trial import Trial +from ax.core.trial_status import ( + DEFAULT_STATUSES_TO_WARM_START, + STATUSES_EXPECTING_DATA, + TrialStatus, +) from ax.core.types import ComparisonOp, TParameterization from ax.exceptions.core import ( AxError, diff --git a/ax/core/observation.py b/ax/core/observation.py index 45e6c186ea1..c6269a59a7a 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -19,11 +19,11 @@ import numpy.typing as npt import pandas as pd from ax.core.arm import Arm -from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.map_data import MapData from ax.core.map_metric import MapMetric +from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus from ax.core.types import TCandidateMetadata, TParameterization from ax.utils.common.base import Base from ax.utils.common.constants import Keys diff --git a/ax/core/tests/test_batch_trial.py b/ax/core/tests/test_batch_trial.py index cfd2ef175af..1f625735cee 100644 --- a/ax/core/tests/test_batch_trial.py +++ b/ax/core/tests/test_batch_trial.py @@ -13,12 +13,12 @@ import numpy as np from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial, GeneratorRunStruct from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun, GeneratorRunType from ax.core.parameter import FixedParameter, ParameterType from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 2c304353502..459bcfbdfa3 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -12,7 +12,6 @@ import numpy as np import pandas as pd from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.generator_run import GeneratorRun @@ -29,6 +28,7 @@ separate_observations, ) from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.core.types import TParameterization from ax.utils.common.testutils import TestCase from pyre_extensions import none_throws diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py index 7ff4980010e..9798ab8de1e 100644 --- a/ax/core/tests/test_utils.py +++ b/ax/core/tests/test_utils.py @@ -12,7 +12,6 @@ import numpy as np import pandas as pd from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric @@ -20,6 +19,7 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.outcome_constraint import OutcomeConstraint +from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.core.utils import ( best_feasible_objective, diff --git a/ax/core/trial_status.py b/ax/core/trial_status.py new file mode 100644 index 00000000000..f4c0e8d0da8 --- /dev/null +++ b/ax/core/trial_status.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from enum import Enum + + +class TrialStatus(int, Enum): + """Enum of trial status. + + General lifecycle of a trial is::: + + CANDIDATE --> STAGED --> RUNNING --> COMPLETED + -------------> --> FAILED (retryable) + --> EARLY_STOPPED (deemed unpromising) + -------------------------> ABANDONED (non-retryable) + + Trial is marked as a ``CANDIDATE`` immediately upon its creation. + + Trials may be abandoned at any time prior to completion or failure. + The difference between abandonment and failure is that the ``FAILED`` state + is meant to express a possibly transient or retryable error, so trials in + that state may be re-run and arm(s) in them may be resuggested by Ax models + to be added to new trials. + + ``ABANDONED`` trials on the other end, indicate + that the trial (and arms(s) in it) should not be rerun or added to new + trials. A trial might be marked ``ABANDONED`` as a result of human-initiated + action (if some trial in experiment is poorly-performing, deterministically + failing etc., and should not be run again in the experiment). It might also + be marked ``ABANDONED`` in an automated way if the trial's execution + encounters an error that indicates that the arm(s) in the trial should bot + be evaluated in the experiment again (e.g. the parameterization in a given + arm deterministically causes trial evaluation to fail). Note that it's also + possible to abandon a single arm in a `BatchTrial` via + ``batch.mark_arm_abandoned``. + + Early-stopped refers to trials that were deemed + unpromising by an early-stopping strategy and therefore terminated. + + Additionally, when trials are deployed, they may be in an intermediate + staged state (e.g. scheduled but waiting for resources) or immediately + transition to running. Note that ``STAGED`` trial status is not always + applicable and depends on the ``Runner`` trials are deployed with + (and whether a ``Runner`` is present at all; for example, in Ax Service + API, trials are marked as ``RUNNING`` immediately when generated from + ``get_next_trial``, skipping the ``STAGED`` status). + + NOTE: Data for abandoned trials (or abandoned arms in batch trials) is + not passed to the model as part of training data, unless ``fit_abandoned`` + option is specified to model bridge. Additionally, data from MapMetrics is + typically excluded unless the corresponding trial is completed. + """ + + CANDIDATE = 0 + STAGED = 1 + FAILED = 2 + COMPLETED = 3 + RUNNING = 4 + ABANDONED = 5 + DISPATCHED = 6 # Deprecated. + EARLY_STOPPED = 7 + + @property + def is_terminal(self) -> bool: + """True if trial is completed.""" + return ( + self == TrialStatus.ABANDONED + or self == TrialStatus.COMPLETED + or self == TrialStatus.FAILED + or self == TrialStatus.EARLY_STOPPED + ) + + @property + def expecting_data(self) -> bool: + """True if trial is expecting data.""" + return self in STATUSES_EXPECTING_DATA + + @property + def is_deployed(self) -> bool: + """True if trial has been deployed but not completed.""" + return self == TrialStatus.STAGED or self == TrialStatus.RUNNING + + @property + def is_failed(self) -> bool: + """True if this trial is a failed one.""" + return self == TrialStatus.FAILED + + @property + def is_abandoned(self) -> bool: + """True if this trial is an abandoned one.""" + return self == TrialStatus.ABANDONED + + @property + def is_candidate(self) -> bool: + """True if this trial is a candidate.""" + return self == TrialStatus.CANDIDATE + + @property + def is_completed(self) -> bool: + """True if this trial is a successfully completed one.""" + return self == TrialStatus.COMPLETED + + @property + def is_running(self) -> bool: + """True if this trial is a running one.""" + return self == TrialStatus.RUNNING + + @property + def is_early_stopped(self) -> bool: + """True if this trial is an early stopped one.""" + return self == TrialStatus.EARLY_STOPPED + + def __format__(self, fmt: str) -> str: + """Define `__format__` to avoid pulling the `__format__` from the `int` + mixin (since its better for statuses to show up as `RUNNING` than as + just an int that is difficult to interpret). + + E.g. batch trial representation with the overridden method is: + "BatchTrial(experiment_name='test', index=0, status=TrialStatus.CANDIDATE)". + + Docs on enum formatting: https://docs.python.org/3/library/enum.html#others. + """ + return f"{self!s}" + + def __repr__(self) -> str: + return f"{self.__class__}.{self.name}" + + +DEFAULT_STATUSES_TO_WARM_START: list[TrialStatus] = [ + TrialStatus.RUNNING, + TrialStatus.COMPLETED, + TrialStatus.ABANDONED, + TrialStatus.EARLY_STOPPED, +] + +NON_ABANDONED_STATUSES: set[TrialStatus] = set(TrialStatus) - {TrialStatus.ABANDONED} + +STATUSES_EXPECTING_DATA: list[TrialStatus] = [ + TrialStatus.RUNNING, + TrialStatus.COMPLETED, + TrialStatus.EARLY_STOPPED, +] diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index cc0222121f1..ca402ca82bc 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -14,12 +14,12 @@ import numpy.typing as npt import pandas as pd -from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.map_data import MapData from ax.core.map_metric import MapMetric from ax.core.objective import MultiObjective +from ax.core.trial_status import TrialStatus from ax.early_stopping.utils import estimate_early_stopping_savings from ax.modelbridge.map_torch import MapTorchAdapter diff --git a/ax/early_stopping/tests/test_strategies.py b/ax/early_stopping/tests/test_strategies.py index 68af9378d5a..313bafa00c6 100644 --- a/ax/early_stopping/tests/test_strategies.py +++ b/ax/early_stopping/tests/test_strategies.py @@ -11,10 +11,10 @@ import numpy as np from ax.core import OptimizationConfig -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.map_data import MapData from ax.core.objective import MultiObjective +from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies import ( BaseEarlyStoppingStrategy, ModelBasedEarlyStoppingStrategy, diff --git a/ax/early_stopping/utils.py b/ax/early_stopping/utils.py index dfacd55ccfe..26c410c6e18 100644 --- a/ax/early_stopping/utils.py +++ b/ax/early_stopping/utils.py @@ -10,9 +10,9 @@ from logging import Logger import pandas as pd -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.map_data import MapData +from ax.core.trial_status import TrialStatus from ax.utils.common.logger import get_logger from pyre_extensions import assert_is_instance diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index e62b6f813fb..ae84f452467 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -14,13 +14,13 @@ from typing import Any, TYPE_CHECKING from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints from ax.generation_strategy.best_model_selector import BestModelSelector diff --git a/ax/generation_strategy/generation_node_input_constructors.py b/ax/generation_strategy/generation_node_input_constructors.py index 881fa299414..c2089d7ca14 100644 --- a/ax/generation_strategy/generation_node_input_constructors.py +++ b/ax/generation_strategy/generation_node_input_constructors.py @@ -11,8 +11,8 @@ from typing import Any from ax.core import ObservationFeatures -from ax.core.base_trial import STATUSES_EXPECTING_DATA from ax.core.experiment import Experiment +from ax.core.trial_status import STATUSES_EXPECTING_DATA from ax.core.utils import get_target_trial_index from ax.exceptions.generation_strategy import AxGenerationException diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index dcea426bc6e..0b04daaced8 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -16,12 +16,12 @@ from typing import Any, TypeVar import pandas as pd -from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures +from ax.core.trial_status import TrialStatus from ax.core.utils import extend_pending_observations, extract_pending_observations from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError from ax.exceptions.generation_strategy import ( diff --git a/ax/generation_strategy/tests/test_aepsych_criterion.py b/ax/generation_strategy/tests/test_aepsych_criterion.py index 6d488fa3615..1eb92b84d8b 100644 --- a/ax/generation_strategy/tests/test_aepsych_criterion.py +++ b/ax/generation_strategy/tests/test_aepsych_criterion.py @@ -8,8 +8,8 @@ from unittest.mock import patch import pandas as pd -from ax.core.base_trial import TrialStatus from ax.core.data import Data +from ax.core.trial_status import TrialStatus from ax.generation_strategy.generation_strategy import ( GenerationStep, GenerationStrategy, diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 3be75b01ffb..3e447773ee3 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -10,9 +10,9 @@ from unittest.mock import MagicMock, patch import torch - -from ax.core.base_trial import TrialStatus from ax.core.observation import ObservationFeatures + +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError from ax.generation_strategy.best_model_selector import ( ReductionCriterion, diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 3ef13f543f0..dea4f0b7e14 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -13,12 +13,12 @@ import numpy as np from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, ParameterType from ax.core.search_space import HierarchicalSearchSpace, SearchSpace +from ax.core.trial_status import TrialStatus from ax.core.utils import ( get_pending_observation_features_based_on_trial_status as get_pending, ) diff --git a/ax/generation_strategy/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py index a62a7fb9d03..6cdffa9f113 100644 --- a/ax/generation_strategy/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -12,8 +12,8 @@ import pandas as pd from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose -from ax.core.base_trial import TrialStatus from ax.core.data import Data +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError from ax.generation_strategy.generation_strategy import ( GenerationNode, diff --git a/ax/generation_strategy/transition_criterion.py b/ax/generation_strategy/transition_criterion.py index 79cc279c49a..f7236e6c0d6 100644 --- a/ax/generation_strategy/transition_criterion.py +++ b/ax/generation_strategy/transition_criterion.py @@ -15,9 +15,9 @@ from ax.core import MultiObjectiveOptimizationConfig from ax.core.auxiliary import AuxiliaryExperimentPurpose - -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment + +from ax.core.trial_status import TrialStatus from ax.exceptions.core import DataRequiredError, UserInputError from ax.exceptions.generation_strategy import MaxParallelismReachedException diff --git a/ax/global_stopping/strategies/base.py b/ax/global_stopping/strategies/base.py index 80cae9e68cc..b1f638f95b5 100644 --- a/ax/global_stopping/strategies/base.py +++ b/ax/global_stopping/strategies/base.py @@ -9,8 +9,9 @@ from abc import ABC, abstractmethod from typing import Any -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment + +from ax.core.trial_status import TrialStatus from ax.utils.common.base import Base diff --git a/ax/global_stopping/tests/test_strategies.py b/ax/global_stopping/tests/test_strategies.py index 4d9adad25fb..0679333e46a 100644 --- a/ax/global_stopping/tests/test_strategies.py +++ b/ax/global_stopping/tests/test_strategies.py @@ -10,7 +10,6 @@ import numpy as np import pandas as pd from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment @@ -24,6 +23,7 @@ from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.global_stopping.strategies.improvement import ( constraint_satisfaction, diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 7c6179b73ad..c9a969907ee 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -19,7 +19,6 @@ from typing import Any from ax.core.arm import Arm -from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import extract_arm_predictions, GeneratorRun @@ -34,6 +33,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace +from ax.core.trial_status import NON_ABANDONED_STATUSES, TrialStatus from ax.core.types import ( TCandidateMetadata, TModelCov, diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index dce8ac0a947..5fdaa6a4d9d 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -11,7 +11,6 @@ import numpy.typing as npt import torch -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment @@ -25,6 +24,7 @@ ) from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus from ax.core.types import TCandidateMetadata from ax.modelbridge.base import GenResults from ax.modelbridge.modelbridge_utils import ( diff --git a/ax/modelbridge/tests/test_map_torch_modelbridge.py b/ax/modelbridge/tests/test_map_torch_modelbridge.py index 7ade1db5386..9848b5142fd 100644 --- a/ax/modelbridge/tests/test_map_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_map_torch_modelbridge.py @@ -11,13 +11,13 @@ import numpy as np import torch - -from ax.core.base_trial import TrialStatus from ax.core.observation import ( ObservationData, ObservationFeatures, recombine_observations, ) + +from ax.core.trial_status import TrialStatus from ax.modelbridge.map_torch import MapTorchAdapter from ax.models.torch_base import TorchGenerator, TorchGenResults from ax.utils.common.constants import Keys diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index f489a3664dc..8d4ad6ed150 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ -20,7 +20,6 @@ markdown_analysis_card_from_analysis_e, ) from ax.analysis.utils import choose_analyses -from ax.core.base_trial import TrialStatus # Used as a return type from ax.core.experiment import Experiment from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective, ScalarizedObjective @@ -28,6 +27,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.runner import Runner from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus # Used as a return type from ax.core.utils import get_pending_observation_features_based_on_trial_status from ax.early_stopping.strategies import ( BaseEarlyStoppingStrategy, diff --git a/ax/preview/api/protocols/runner.py b/ax/preview/api/protocols/runner.py index b5bf99ca468..91219c0c3e8 100644 --- a/ax/preview/api/protocols/runner.py +++ b/ax/preview/api/protocols/runner.py @@ -8,7 +8,7 @@ from typing import Any, Mapping -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.preview.api.protocols.utils import _APIRunner from ax.preview.api.types import TParameterization from pyre_extensions import override diff --git a/ax/preview/api/tests/test_client.py b/ax/preview/api/tests/test_client.py index 0561bd8f5aa..16696680029 100644 --- a/ax/preview/api/tests/test_client.py +++ b/ax/preview/api/tests/test_client.py @@ -11,7 +11,6 @@ import pandas as pd from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.formatting_utils import DataType @@ -28,6 +27,7 @@ from ax.core.parameter_constraint import ParameterConstraint from ax.core.search_space import SearchSpace from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy from ax.exceptions.core import UnsupportedError from ax.preview.api.client import Client diff --git a/ax/preview/modelbridge/dispatch_utils.py b/ax/preview/modelbridge/dispatch_utils.py index d7853290d64..84d3d0453aa 100644 --- a/ax/preview/modelbridge/dispatch_utils.py +++ b/ax/preview/modelbridge/dispatch_utils.py @@ -7,7 +7,7 @@ # pyre-unsafe import torch -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UnsupportedError from ax.generation_strategy.generation_strategy import ( GenerationNode, diff --git a/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py b/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py index a655632580f..8f64c66f506 100644 --- a/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py +++ b/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py @@ -6,8 +6,8 @@ import torch -from ax.core.base_trial import TrialStatus from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.generation_strategy.transition_criterion import MinTrials from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec diff --git a/ax/runners/tests/test_single_running_trial_mixin.py b/ax/runners/tests/test_single_running_trial_mixin.py index 29b0aeb1521..83d95eafb10 100644 --- a/ax/runners/tests/test_single_running_trial_mixin.py +++ b/ax/runners/tests/test_single_running_trial_mixin.py @@ -7,7 +7,7 @@ # pyre-strict -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin from ax.runners.synthetic import SyntheticRunner from ax.utils.common.testutils import TestCase diff --git a/ax/service/tests/test_with_db_settings_base.py b/ax/service/tests/test_with_db_settings_base.py index 4fa00749fbf..b02f2b42fde 100644 --- a/ax/service/tests/test_with_db_settings_base.py +++ b/ax/service/tests/test_with_db_settings_base.py @@ -10,8 +10,9 @@ import string from unittest.mock import patch -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment + +from ax.core.trial_status import TrialStatus from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.service.utils.with_db_settings_base import ( try_load_generation_strategy, diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index cb1730149c9..92f2d7ccd94 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -21,7 +21,6 @@ import numpy.typing as npt import pandas as pd import plotly.graph_objects as go -from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRunType @@ -33,6 +32,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import Parameter from ax.core.trial import BaseTrial +from ax.core.trial_status import TrialStatus from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy from ax.exceptions.core import DataRequiredError, UserInputError from ax.generation_strategy.generation_strategy import GenerationStrategy diff --git a/ax/storage/json_store/decoders.py b/ax/storage/json_store/decoders.py index d5fb01ca152..d8f233e4f15 100644 --- a/ax/storage/json_store/decoders.py +++ b/ax/storage/json_store/decoders.py @@ -17,7 +17,6 @@ import torch from ax.core.arm import Arm -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import ( AbandonedArm, BatchTrial, @@ -27,6 +26,7 @@ from ax.core.generator_run import GeneratorRun from ax.core.runner import Runner from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.exceptions.storage import JSONDecodeError from ax.modelbridge.transforms.base import Transform from ax.storage.botorch_modular_registry import ( diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 7f05d636e11..d1996f0e48f 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -23,7 +23,6 @@ from ax.core import Experiment, ObservationFeatures from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import ( AbandonedArm, BatchTrial, @@ -58,6 +57,7 @@ from ax.core.risk_measures import RiskMeasure from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace from ax.core.trial import Trial +from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.early_stopping.strategies import ( PercentileEarlyStoppingStrategy, diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index db6eb8a51fe..560f286207d 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -12,9 +12,10 @@ from decimal import Decimal from typing import Any, List -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import LifecycleStage from ax.core.parameter import ParameterType + +from ax.core.trial_status import TrialStatus from ax.core.types import ( ComparisonOp, TModelPredict, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 84984875d13..0c03cd49dd7 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -21,7 +21,6 @@ from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose -from ax.core.base_trial import TrialStatus from ax.core.batch_trial import LifecycleStage from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun @@ -30,6 +29,7 @@ from ax.core.outcome_constraint import OutcomeConstraint from ax.core.parameter import ParameterType, RangeParameter from ax.core.runner import Runner +from ax.core.trial_status import TrialStatus from ax.core.types import ComparisonOp from ax.exceptions.core import ObjectNotFoundError from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError diff --git a/ax/utils/testing/backend_simulator.py b/ax/utils/testing/backend_simulator.py index eab0d4a5812..4acd17d821f 100644 --- a/ax/utils/testing/backend_simulator.py +++ b/ax/utils/testing/backend_simulator.py @@ -13,7 +13,7 @@ from logging import Logger -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.utils.common.base import Base from ax.utils.common.logger import get_logger from pyre_extensions import none_throws diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 5eb93e38c66..710ce08ab37 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -10,12 +10,12 @@ from typing import Any import numpy as np -from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import FixedParameter, RangeParameter from ax.core.search_space import SearchSpace +from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError from ax.generation_strategy.best_model_selector import ( ReductionCriterion, diff --git a/ax/utils/testing/tests/test_backend_simulator.py b/ax/utils/testing/tests/test_backend_simulator.py index 0df49bdfa1a..cec79f8e7b0 100644 --- a/ax/utils/testing/tests/test_backend_simulator.py +++ b/ax/utils/testing/tests/test_backend_simulator.py @@ -8,7 +8,7 @@ from unittest.mock import Mock, patch -from ax.core.base_trial import TrialStatus +from ax.core.trial_status import TrialStatus from ax.utils.common.testutils import TestCase from ax.utils.testing.backend_simulator import BackendSimulator, BackendSimulatorOptions from ax.utils.testing.utils_testing_stubs import get_backend_simulator_with_trials diff --git a/sphinx/source/core.rst b/sphinx/source/core.rst index 9075d61d693..1e32a7b3b6f 100644 --- a/sphinx/source/core.rst +++ b/sphinx/source/core.rst @@ -204,6 +204,15 @@ Core Classes :show-inheritance: +`TrialStatus` +~~~~~~~~~~~~ + +.. automodule:: ax.core.trial_status + :members: + :undoc-members: + :show-inheritance: + + Core Types ----------