From e5d25df6f9d9ab82739e7888e890617acb61de4c Mon Sep 17 00:00:00 2001 From: Andy Lin Date: Fri, 1 Nov 2024 15:49:00 -0700 Subject: [PATCH] Support MultiTypeExperiment in Instantiation (#2939) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2939 1. **InstatntiationBase:** Add support returning MultiTypeExperiment in InstatntiationBase._make_experiment. 2. **MultiTypeExperiment:** Add add_tracking_metrics function in MultiTypeExperiment to support batch adding metrics when creating a MultiTypeExperiment. 3. **AxClient**: Add support for creating MultiTypeExperiment, add_trial_type and add_tracking_metrics. Reviewed By: sdaulton Differential Revision: D64612495 --- ax/core/multi_type_experiment.py | 42 ++++++++ ax/core/tests/test_multi_type_experiment.py | 34 ++++++ ax/service/ax_client.py | 49 +++++++-- ax/service/tests/test_ax_client.py | 106 ++++++++++++++++++- ax/service/tests/test_instantiation_utils.py | 19 ++++ ax/service/utils/instantiation.py | 32 ++++++ 6 files changed, 273 insertions(+), 9 deletions(-) diff --git a/ax/core/multi_type_experiment.py b/ax/core/multi_type_experiment.py index 5b601617695..23cf6eb87ee 100644 --- a/ax/core/multi_type_experiment.py +++ b/ax/core/multi_type_experiment.py @@ -49,6 +49,7 @@ def __init__( default_trial_type: str, default_runner: Runner, optimization_config: OptimizationConfig | None = None, + tracking_metrics: list[Metric] | None = None, status_quo: Arm | None = None, description: str | None = None, is_test: bool = False, @@ -65,6 +66,7 @@ def __init__( default_runner: Default runner for trials of the default type. optimization_config: Optimization config of the experiment. tracking_metrics: Additional tracking metrics not used for optimization. + These are associated with the default trial type. runner: Default runner used for trials on this experiment. status_quo: Arm representing existing "control" arm. description: Description of the experiment. @@ -101,6 +103,7 @@ def __init__( experiment_type=experiment_type, properties=properties, default_data_type=default_data_type, + tracking_metrics=tracking_metrics, ) def add_trial_type(self, trial_type: str, runner: Runner) -> "MultiTypeExperiment": @@ -163,6 +166,45 @@ def add_tracking_metric( self._metric_to_canonical_name[metric.name] = canonical_name return self + def add_tracking_metrics( + self, + metrics: list[Metric], + metrics_to_trial_types: dict[str, str] | None = None, + canonical_names: dict[str, str] | None = None, + ) -> Experiment: + """Add a list of new metrics to the experiment. + + If any of the metrics are already defined on the experiment, + we raise an error and don't add any of them to the experiment + + Args: + metrics: Metrics to be added. + metrics_to_trial_types: The mapping from metric names to corresponding + trial types for each metric. If provided, the metrics will be + added to their trial types. If not provided, then the default + trial type will be used. + canonical_names: A mapping of metric names to their + canonical names(The default metrics for which the metrics are + proxies.) + + Returns: + The experiment with the added metrics. + """ + metrics_to_trial_types = metrics_to_trial_types or {} + canonical_name = None + for metric in metrics: + if canonical_names is not None: + canonical_name = none_throws(canonical_names).get(metric.name, None) + + self.add_tracking_metric( + metric=metric, + trial_type=metrics_to_trial_types.get( + metric.name, self._default_trial_type + ), + canonical_name=canonical_name, + ) + return self + # pyre-fixme[14]: `update_tracking_metric` overrides method defined in # `Experiment` inconsistently. def update_tracking_metric( diff --git a/ax/core/tests/test_multi_type_experiment.py b/ax/core/tests/test_multi_type_experiment.py index 8776e4cac19..b80c73dd8fa 100644 --- a/ax/core/tests/test_multi_type_experiment.py +++ b/ax/core/tests/test_multi_type_experiment.py @@ -171,6 +171,40 @@ def test_runner_for_trial_type(self) -> None: ): self.experiment.runner_for_trial_type(trial_type="invalid") + def test_add_tracking_metrics(self) -> None: + type1_metrics = [ + BraninMetric("m3_type1", ["x1", "x2"]), + BraninMetric("m4_type1", ["x1", "x2"]), + ] + type2_metrics = [ + BraninMetric("m3_type2", ["x1", "x2"]), + BraninMetric("m4_type2", ["x1", "x2"]), + ] + default_type_metrics = [ + BraninMetric("m5_default_type", ["x1", "x2"]), + ] + self.experiment.add_tracking_metrics( + metrics=type1_metrics + type2_metrics + default_type_metrics, + metrics_to_trial_types={ + "m3_type1": "type1", + "m4_type1": "type1", + "m3_type2": "type2", + "m4_type2": "type2", + }, + ) + self.assertDictEqual( + self.experiment._metric_to_trial_type, + { + "m1": "type1", + "m2": "type2", + "m3_type1": "type1", + "m4_type1": "type1", + "m3_type2": "type2", + "m4_type2": "type2", + "m5_default_type": "type1", + }, + ) + class MultiTypeExperimentUtilsTest(TestCase): def setUp(self) -> None: diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 03ea368b4e4..905f6258454 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -27,12 +27,14 @@ from ax.core.generator_run import GeneratorRun from ax.core.map_data import MapData from ax.core.map_metric import MapMetric +from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, OptimizationConfig, ) +from ax.core.runner import Runner from ax.core.trial import Trial from ax.core.types import ( TEvaluationOutcome, @@ -40,6 +42,7 @@ TParameterization, TParamValue, ) + from ax.core.utils import get_pending_observation_features_based_on_trial_status from ax.early_stopping.strategies import BaseEarlyStoppingStrategy from ax.early_stopping.utils import estimate_early_stopping_savings @@ -90,6 +93,7 @@ from ax.utils.common.typeutils import checked_cast from pyre_extensions import assert_is_instance, none_throws + logger: Logger = get_logger(__name__) @@ -251,6 +255,8 @@ def create_experiment( immutable_search_space_and_opt_config: bool = True, is_test: bool = False, metric_definitions: dict[str, dict[str, Any]] | None = None, + default_trial_type: str | None = None, + default_runner: Runner | None = None, ) -> None: """Create a new experiment and save it if DBSettings available. @@ -316,6 +322,15 @@ def create_experiment( to that metric. Note these are modified in-place. Each Metric must have its own dictionary (metrics cannot share a single dictionary object). + default_trial_type: The default trial type if multiple + trial types are intended to be used in the experiment. If specified, + a MultiTypeExperiment will be created. Otherwise, a single-type + Experiment will be created. + default_runner: The default runner in this experiment. + This applies to MultiTypeExperiment (when default_trial_type + is specified) and needs to be specified together with + default_trial_type. This will be ignored for single-type Experiment + (when default_trial_type is not specified). """ self._validate_early_stopping_strategy(support_intermediate_data) @@ -344,6 +359,8 @@ def create_experiment( support_intermediate_data=support_intermediate_data, immutable_search_space_and_opt_config=immutable_search_space_and_opt_config, is_test=is_test, + default_trial_type=default_trial_type, + default_runner=default_runner, **objective_kwargs, ) self._set_runner(experiment=experiment) @@ -416,6 +433,8 @@ def add_tracking_metrics( self, metric_names: list[str], metric_definitions: dict[str, dict[str, Any]] | None = None, + metrics_to_trial_types: dict[str, str] | None = None, + canonical_names: dict[str, str] | None = None, ) -> None: """Add a list of new metrics to the experiment. @@ -428,20 +447,34 @@ def add_tracking_metrics( to that metric. Note these are modified in-place. Each Metric must have its is own dictionary (metrics cannot share a single dictionary object). + metrics_to_trial_types: Only applicable to MultiTypeExperiment. + The mapping from metric names to corresponding + trial types for each metric. If provided, the metrics will be + added with their respective trial types. If not provided, then the + default trial type will be used. + canonical_names: A mapping from metric name (of a particular trial type) + to the metric name of the default trial type. Only applicable to + MultiTypeExperiment. """ metric_definitions = ( self.metric_definitions if metric_definitions is None else metric_definitions ) - self.experiment.add_tracking_metrics( - metrics=[ - self._make_metric( - name=metric_name, metric_definitions=metric_definitions - ) - for metric_name in metric_names - ] - ) + metric_objects = [ + self._make_metric(name=metric_name, metric_definitions=metric_definitions) + for metric_name in metric_names + ] + + if isinstance(self.experiment, MultiTypeExperiment): + experiment = assert_is_instance(self.experiment, MultiTypeExperiment) + experiment.add_tracking_metrics( + metrics=metric_objects, + metrics_to_trial_types=metrics_to_trial_types, + canonical_names=canonical_names, + ) + else: + self.experiment.add_tracking_metrics(metrics=metric_objects) @copy_doc(Experiment.remove_tracking_metric) def remove_tracking_metric(self, metric_name: str) -> None: diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 10a9cb1020b..133ad0903e4 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -21,6 +21,7 @@ from ax.core.arm import Arm from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric +from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.optimization_config import MultiObjectiveOptimizationConfig from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint from ax.core.parameter import ( @@ -57,6 +58,7 @@ from ax.modelbridge.model_spec import ModelSpec from ax.modelbridge.random import RandomModelBridge from ax.modelbridge.registry import Models +from ax.runners.synthetic import SyntheticRunner from ax.service.ax_client import AxClient, ObjectiveProperties from ax.service.utils.best_point import ( @@ -83,7 +85,7 @@ from ax.utils.testing.mock import mock_botorch_optimize from ax.utils.testing.modeling_stubs import get_observation1, get_observation1trans from botorch.test_functions.multi_objective import BraninCurrin -from pyre_extensions import none_throws +from pyre_extensions import assert_is_instance, none_throws if TYPE_CHECKING: from ax.core.types import TTrialEvaluation @@ -821,6 +823,7 @@ def test_create_experiment(self) -> None: is_test=True, ) assert ax_client._experiment is not None + self.assertEqual(ax_client.experiment.__class__.__name__, "Experiment") self.assertEqual(ax_client._experiment, ax_client.experiment) self.assertEqual( # pyre-fixme[16]: `Optional` has no attribute `search_space`. @@ -903,6 +906,107 @@ def test_create_experiment(self) -> None: {"test_objective", "some_metric", "test_tracking_metric"}, ) + def test_create_multitype_experiment(self) -> None: + """ + Test create multitype experiment, add trial type, and add metrics to + different trial types + """ + ax_client = AxClient( + GenerationStrategy( + steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + ) + ) + ax_client.create_experiment( + name="test_experiment", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [0.001, 0.1], + "value_type": "float", + "log_scale": True, + "digits": 6, + }, + { + "name": "y", + "type": "choice", + "values": [1, 2, 3], + "value_type": "int", + "is_ordered": True, + }, + {"name": "x3", "type": "fixed", "value": 2, "value_type": "int"}, + { + "name": "x4", + "type": "range", + "bounds": [1.0, 3.0], + "value_type": "int", + }, + { + "name": "x5", + "type": "choice", + "values": ["one", "two", "three"], + "value_type": "str", + }, + { + "name": "x6", + "type": "range", + "bounds": [1.0, 3.0], + "value_type": "int", + }, + ], + objectives={"test_objective": ObjectiveProperties(minimize=True)}, + outcome_constraints=["some_metric >= 3", "some_metric <= 4.0"], + parameter_constraints=["x4 <= x6"], + tracking_metric_names=["test_tracking_metric"], + is_test=True, + default_trial_type="test_trial_type", + default_runner=SyntheticRunner(), + ) + + self.assertEqual(ax_client.experiment.__class__.__name__, "MultiTypeExperiment") + experiment = assert_is_instance(ax_client.experiment, MultiTypeExperiment) + self.assertEqual( + experiment._trial_type_to_runner["test_trial_type"].__class__.__name__, + "SyntheticRunner", + ) + self.assertEqual( + experiment._metric_to_trial_type, + { + "test_tracking_metric": "test_trial_type", + "test_objective": "test_trial_type", + "some_metric": "test_trial_type", + }, + ) + experiment.add_trial_type( + trial_type="test_trial_type_2", + runner=SyntheticRunner(), + ) + ax_client.add_tracking_metrics( + metric_names=[ + "some_metric2_type1", + "some_metric3_type1", + "some_metric4_type2", + "some_metric5_type2", + ], + metrics_to_trial_types={ + "some_metric2_type1": "test_trial_type", + "some_metric4_type2": "test_trial_type_2", + "some_metric5_type2": "test_trial_type_2", + }, + ) + self.assertEqual( + experiment._metric_to_trial_type, + { + "test_tracking_metric": "test_trial_type", + "test_objective": "test_trial_type", + "some_metric": "test_trial_type", + "some_metric2_type1": "test_trial_type", + "some_metric3_type1": "test_trial_type", + "some_metric4_type2": "test_trial_type_2", + "some_metric5_type2": "test_trial_type_2", + }, + ) + def test_create_single_objective_experiment_with_objectives_dict(self) -> None: ax_client = AxClient( GenerationStrategy( diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index bdaddf20930..6dc0b73f958 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -17,6 +17,7 @@ RangeParameter, ) from ax.core.search_space import HierarchicalSearchSpace +from ax.runners.synthetic import SyntheticRunner from ax.service.utils.instantiation import InstantiationBase from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast @@ -431,3 +432,21 @@ def test_hss(self) -> None: self.assertIsInstance(search_space, HierarchicalSearchSpace) # pyre-fixme[16]: `SearchSpace` has no attribute `_root`. self.assertEqual(search_space._root.name, "root") + + def test_make_multitype_experiment_with_default_trial_type(self) -> None: + experiment = InstantiationBase.make_experiment( + name="test_make_experiment", + parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}], + tracking_metric_names=None, + default_trial_type="test_trial_type", + default_runner=SyntheticRunner(), + ) + self.assertEqual(experiment.__class__.__name__, "MultiTypeExperiment") + + def test_make_single_type_experiment_with_no_default_trial_type(self) -> None: + experiment = InstantiationBase.make_experiment( + name="test_make_experiment", + parameters=[{"name": "x", "type": "range", "bounds": [0, 1]}], + tracking_metric_names=None, + ) + self.assertEqual(experiment.__class__.__name__, "Experiment") diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index acc0f05f7b0..f741900630a 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -18,6 +18,7 @@ from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.experiment import DataType, Experiment from ax.core.metric import Metric +from ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import MultiObjective, Objective from ax.core.observation import ObservationFeatures from ax.core.optimization_config import ( @@ -40,6 +41,7 @@ ParameterConstraint, validate_constraint_parameters, ) +from ax.core.runner import Runner from ax.core.search_space import HierarchicalSearchSpace, SearchSpace from ax.core.types import ComparisonOp, TParameterization, TParamValue from ax.exceptions.core import UnsupportedError @@ -799,6 +801,8 @@ def make_experiment( immutable_search_space_and_opt_config: bool = True, auxiliary_experiments_by_purpose: None | (dict[AuxiliaryExperimentPurpose, list[AuxiliaryExperiment]]) = None, + default_trial_type: str | None = None, + default_runner: Runner | None = None, is_test: bool = False, ) -> Experiment: """Instantiation wrapper that allows for Ax `Experiment` creation @@ -854,10 +858,23 @@ def make_experiment( improve storage performance. auxiliary_experiments_by_purpose: Dictionary of auxiliary experiments for different use cases (e.g., transfer learning). + default_trial_type: The default trial type if multiple + trial types are intended to be used in the experiment. If specified, + a MultiTypeExperiment will be created. Otherwise, a single-type + Experiment will be created. + default_runner: The default runner in this experiment. + This only applies to MultiTypeExperiment (when default_trial_type + is specified). is_test: Whether this experiment will be a test experiment (useful for marking test experiments in storage etc). Defaults to False. """ + if (default_trial_type is None) != (default_runner is None): + raise ValueError( + "Must specify both default_trial_type and default_runner if " + "using a MultiTypeExperiment." + ) + status_quo_arm = None if status_quo is None else Arm(parameters=status_quo) objectives = objectives or cls._get_default_objectives() @@ -896,6 +913,21 @@ def make_experiment( if owners is not None: properties["owners"] = owners + if default_trial_type is not None: + return MultiTypeExperiment( + name=none_throws(name), + search_space=cls.make_search_space(parameters, parameter_constraints), + default_trial_type=none_throws(default_trial_type), + default_runner=none_throws(default_runner), + optimization_config=optimization_config, + tracking_metrics=tracking_metrics, + status_quo=status_quo_arm, + description=description, + is_test=is_test, + experiment_type=experiment_type, + properties=properties, + default_data_type=default_data_type, + ) return Experiment( name=name,