diff --git a/ax/core/objective.py b/ax/core/objective.py
index 653c0f9afc1..151e331d8ea 100644
--- a/ax/core/objective.py
+++ b/ax/core/objective.py
@@ -13,6 +13,7 @@
 from typing import Any, Iterable, List, Optional, Tuple
 
 from ax.core.metric import Metric
+from ax.exceptions.core import UserInputError
 from ax.utils.common.base import SortableBase
 from ax.utils.common.logger import get_logger
 from ax.utils.common.typeutils import not_none
@@ -34,36 +35,27 @@ def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None:
             metric: The metric to be optimized.
             minimize: If True, minimize metric. If None, will be set based on the
                 `lower_is_better` property of the metric (if that is not specified,
-                will raise a DeprecationWarning).
+                will raise a `UserInputError`).
         """
         lower_is_better = metric.lower_is_better
         if minimize is None:
             if lower_is_better is None:
-                warnings.warn(
-                    f"Defaulting to `minimize=False` for metric {metric.name} not "
-                    + "specifying `lower_is_better` property. This is a wild guess. "
-                    + "Specify either `lower_is_better` on the metric, or specify "
-                    + "`minimize` explicitly. This will become an error in the future.",
-                    DeprecationWarning,
+                raise UserInputError(
+                    f"Metric {metric.name} does not specify `lower_is_better` "
+                    "and `minimize` is not specified. At least one of these "
+                    "must be specified."
                 )
-                minimize = False
             else:
                 minimize = lower_is_better
-        if lower_is_better is not None:
-            if lower_is_better and not minimize:
-                warnings.warn(
-                    f"Attempting to maximize metric {metric.name} with property "
-                    "`lower_is_better=True`."
-                )
-            elif not lower_is_better and minimize:
-                warnings.warn(
-                    f"Attempting to minimize metric {metric.name} with property "
-                    "`lower_is_better=False`."
-                )
-        self._metric = metric
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.minimize = not_none(minimize)
+        elif lower_is_better is not None and lower_is_better != minimize:
+            raise UserInputError(
+                f"Metric {metric.name} specifies {lower_is_better=}, "
+                "which doesn't match the specified optimization direction "
+                f"{minimize=}."
+            )
+        self._metric: Metric = metric
+        self.minimize: bool = not_none(minimize)
 
     @property
     def metric(self) -> Metric:
@@ -130,18 +122,17 @@ def __init__(
                 "as input to `MultiObjective` constructor."
             )
             metrics = extra_kwargs["metrics"]
-            minimize = extra_kwargs.get("minimize", False)
+            minimize = extra_kwargs.get("minimize", None)
             warnings.warn(
                 "Passing `metrics` and `minimize` as input to the `MultiObjective` "
                 "constructor will soon be deprecated. Instead, pass a list of "
                 "`objectives`. This will become an error in the future.",
                 DeprecationWarning,
+                stacklevel=2,
            )
             objectives = []
             for metric in metrics:
-                lower_is_better = metric.lower_is_better or False
-                _minimize = not lower_is_better if minimize else lower_is_better
-                objectives.append(Objective(metric=metric, minimize=_minimize))
+                objectives.append(Objective(metric=metric, minimize=minimize))
         # pyre-fixme[4]: Attribute must be annotated.
         self._objectives = not_none(objectives)
diff --git a/ax/core/tests/test_objective.py b/ax/core/tests/test_objective.py
index 3c6e7a3059d..39bdc0d9fc2 100644
--- a/ax/core/tests/test_objective.py
+++ b/ax/core/tests/test_objective.py
@@ -6,10 +6,9 @@
 
 # pyre-strict
 
-import warnings
-
 from ax.core.metric import Metric
 from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
+from ax.exceptions.core import UserInputError
 from ax.utils.common.testutils import TestCase
@@ -21,7 +20,7 @@ def setUp(self) -> None:
             "m3": Metric(name="m3", lower_is_better=False),
         }
         self.objectives = {
-            "o1": Objective(metric=self.metrics["m1"]),
+            "o1": Objective(metric=self.metrics["m1"], minimize=True),
             "o2": Objective(metric=self.metrics["m2"], minimize=True),
             "o3": Objective(metric=self.metrics["m3"], minimize=False),
         }
@@ -38,6 +37,12 @@ def setUp(self) -> None:
         )
 
     def test_Init(self) -> None:
+        with self.assertRaisesRegex(UserInputError, "does not specify"):
+            Objective(metric=self.metrics["m1"])
+        with self.assertRaisesRegex(
+            UserInputError, "doesn't match the specified optimization direction"
+        ):
+            Objective(metric=self.metrics["m2"], minimize=False)
         with self.assertRaises(ValueError):
             ScalarizedObjective(
                 metrics=[self.metrics["m1"], self.metrics["m2"]], weights=[1.0]
@@ -52,20 +57,6 @@ def test_Init(self) -> None:
             metrics=[self.metrics["m1"], self.metrics["m2"]],
             minimize=False,
         )
-        warnings.resetwarnings()
-        warnings.simplefilter("always", append=True)
-        with warnings.catch_warnings(record=True) as ws:
-            Objective(metric=self.metrics["m1"])
-        self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
-        self.assertTrue(
-            any("Defaulting to `minimize=False`" in str(w.message) for w in ws)
-        )
-        with warnings.catch_warnings(record=True) as ws:
-            Objective(Metric(name="m4", lower_is_better=True), minimize=False)
-        self.assertTrue(any("Attempting to maximize" in str(w.message) for w in ws))
-        with warnings.catch_warnings(record=True) as ws:
-            Objective(Metric(name="m4", lower_is_better=False), minimize=True)
-        self.assertTrue(any("Attempting to minimize" in str(w.message) for w in ws))
         self.assertEqual(
             self.objective.get_unconstrainable_metrics(), [self.metrics["m1"]]
         )
@@ -77,7 +68,7 @@ def test_MultiObjective(self) -> None:
         self.assertEqual(self.multi_objective.metrics, list(self.metrics.values()))
 
         minimizes = [obj.minimize for obj in self.multi_objective.objectives]
-        self.assertEqual(minimizes, [False, True, False])
+        self.assertEqual(minimizes, [True, True, False])
         weights = [mw[1] for mw in self.multi_objective.objective_weights]
         self.assertEqual(weights, [1.0, 1.0, 1.0])
         self.assertEqual(self.multi_objective.clone(), self.multi_objective)
@@ -85,7 +76,7 @@ def test_MultiObjective(self) -> None:
             str(self.multi_objective),
             (
                 "MultiObjective(objectives="
-                '[Objective(metric_name="m1", minimize=False), '
+                '[Objective(metric_name="m1", minimize=True), '
                 'Objective(metric_name="m2", minimize=True), '
                 'Objective(metric_name="m3", minimize=False)])'
             ),
@@ -96,19 +87,26 @@ def test_MultiObjective(self) -> None:
         )
 
     def test_MultiObjectiveBackwardsCompatibility(self) -> None:
-        multi_objective = MultiObjective(
-            metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]]
-        )
+        metrics = [
+            Metric(name="m1", lower_is_better=False),
+            self.metrics["m2"],
+            self.metrics["m3"],
+        ]
+        multi_objective = MultiObjective(metrics=metrics)
         minimizes = [obj.minimize for obj in multi_objective.objectives]
-        self.assertEqual(multi_objective.metrics, list(self.metrics.values()))
+        self.assertEqual(multi_objective.metrics, metrics)
         self.assertEqual(minimizes, [False, True, False])
         multi_objective_min = MultiObjective(
-            metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]],
+            metrics=[
+                Metric(name="m1"),
+                Metric(name="m2"),
+                Metric(name="m3", lower_is_better=True),
+            ],
             minimize=True,
         )
         minimizes = [obj.minimize for obj in multi_objective_min.objectives]
-        self.assertEqual(minimizes, [True, False, True])
+        self.assertEqual(minimizes, [True, True, True])
 
     def test_ScalarizedObjective(self) -> None:
         with self.assertRaises(NotImplementedError):
diff --git a/ax/core/tests/test_optimization_config.py b/ax/core/tests/test_optimization_config.py
index 30739d8c512..f3a484c279a 100644
--- a/ax/core/tests/test_optimization_config.py
+++ b/ax/core/tests/test_optimization_config.py
@@ -277,7 +277,7 @@ def setUp(self) -> None:
             "o2": Objective(metric=self.metrics["m2"], minimize=False),
             "o3": Objective(metric=self.metrics["m3"], minimize=False),
         }
-        self.objective = Objective(metric=self.metrics["m1"], minimize=False)
+        self.objective = Objective(metric=self.metrics["m1"], minimize=True)
         self.multi_objective = MultiObjective(
             objectives=[self.objectives["o1"], self.objectives["o2"]]
         )
diff --git a/ax/core/tests/test_utils.py b/ax/core/tests/test_utils.py
index a5c5f5ab129..cede5005733 100644
--- a/ax/core/tests/test_utils.py
+++ b/ax/core/tests/test_utils.py
@@ -158,7 +158,7 @@ def setUp(self) -> None:
         self.data = Data(df=self.df)
 
         self.optimization_config = OptimizationConfig(
-            objective=Objective(metric=Metric(name="a")),
+            objective=Objective(metric=Metric(name="a"), minimize=False),
             outcome_constraints=[
                 OutcomeConstraint(
                     metric=Metric(name="b"),
diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py
index cb3ce7e9e53..f7ad8e5c233 100644
--- a/ax/modelbridge/tests/test_base_modelbridge.py
+++ b/ax/modelbridge/tests/test_base_modelbridge.py
@@ -156,7 +156,7 @@ def test_ModelBridge(
                 observation_features=[get_observation1trans().features], weights=[2]
             ),
         )
-        oc = OptimizationConfig(objective=Objective(metric=Metric(name="test_metric")))
+        oc = get_optimization_config_no_constraints()
         modelbridge._set_kwargs_to_save(
             model_key="TestModel", model_kwargs={}, bridge_kwargs={}
         )
@@ -322,7 +322,7 @@ def warn_and_return_mock_obs(
             fit_tracking_metrics=False,
         )
         new_oc = OptimizationConfig(
-            objective=Objective(metric=Metric(name="test_metric2"))
+            objective=Objective(metric=Metric(name="test_metric2"), minimize=False),
         )
         with self.assertRaisesRegex(UnsupportedError, "fit_tracking_metrics"):
             modelbridge.gen(n=1, optimization_config=new_oc)
diff --git a/ax/modelbridge/tests/test_cross_validation.py b/ax/modelbridge/tests/test_cross_validation.py
index 7ede3cd7e89..02d5e650e05 100644
--- a/ax/modelbridge/tests/test_cross_validation.py
+++ b/ax/modelbridge/tests/test_cross_validation.py
@@ -344,7 +344,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
 
         # Test single objective
         optimization_config = OptimizationConfig(
-            objective=Objective(metric=Metric("a"))
+            objective=Objective(metric=Metric("a"), minimize=True)
         )
         has_good_fit = has_good_opt_config_model_fit(
             optimization_config=optimization_config,
@@ -354,7 +354,12 @@ def test_HasGoodOptConfigModelFit(self) -> None:
 
         # Test multi objective
         optimization_config = MultiObjectiveOptimizationConfig(
-            objective=MultiObjective(metrics=[Metric("a"), Metric("b")])
+            objective=MultiObjective(
+                objectives=[
+                    Objective(Metric("a"), minimize=False),
+                    Objective(Metric("b"), minimize=False),
+                ]
+            )
         )
         has_good_fit = has_good_opt_config_model_fit(
             optimization_config=optimization_config,
@@ -364,7 +369,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
 
         # Test constraints
         optimization_config = OptimizationConfig(
-            objective=Objective(metric=Metric("a")),
+            objective=Objective(metric=Metric("a"), minimize=False),
             outcome_constraints=[
                 OutcomeConstraint(metric=Metric("b"), op=ComparisonOp.GEQ, bound=0.1)
             ],
diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py
index a0466356f66..2062fac283f 100644
--- a/ax/modelbridge/tests/test_torch_modelbridge.py
+++ b/ax/modelbridge/tests/test_torch_modelbridge.py
@@ -45,6 +45,7 @@
     get_branin_experiment,
     get_branin_search_space,
     get_experiment_with_observations,
+    get_optimization_config_no_constraints,
     get_search_space_for_range_value,
 )
 from ax.utils.testing.mock import fast_botorch_optimize
@@ -363,9 +364,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
             observation_features=[
                 ObservationFeatures(parameters={"x": 1.0, "y": 2.0})
             ],
-            optimization_config=OptimizationConfig(
-                objective=Objective(metric=Metric(name="test_metric"))
-            ),
+            optimization_config=get_optimization_config_no_constraints(),
         )
         self.assertEqual(acqf_vals, [5.0])
 
@@ -392,9 +391,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
                 ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
                 ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
             ],
-            optimization_config=OptimizationConfig(
-                objective=Objective(metric=Metric(name="test_metric"))
-            ),
+            optimization_config=get_optimization_config_no_constraints(),
         )
         t.transform_observation_features.assert_any_call(
             [ObservationFeatures(parameters={"x": 1.0, "y": 2.0})],
@@ -418,9 +415,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
                     ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
                 ]
             ],
-            optimization_config=OptimizationConfig(
-                objective=Objective(metric=Metric(name="test_metric"))
-            ),
+            optimization_config=get_optimization_config_no_constraints(),
         )
         t.transform_observation_features.assert_any_call(
             [
diff --git a/ax/modelbridge/tests/test_utils.py b/ax/modelbridge/tests/test_utils.py
index c90db16ba4d..3f09ae770b0 100644
--- a/ax/modelbridge/tests/test_utils.py
+++ b/ax/modelbridge/tests/test_utils.py
@@ -117,7 +117,9 @@ def test_extract_outcome_constraints(self) -> None:
     def test_extract_objective_thresholds(self) -> None:
         outcomes = ["m1", "m2", "m3", "m4"]
         objective = MultiObjective(
-            objectives=[Objective(metric=Metric(name)) for name in outcomes[:3]]
+            objectives=[
+                Objective(metric=Metric(name), minimize=False) for name in outcomes[:3]
+            ]
         )
         objective_thresholds = [
             ObjectiveThreshold(
@@ -159,7 +161,7 @@ def test_extract_objective_thresholds(self) -> None:
         self.assertTrue(np.isnan(obj_t[-2:]).all())
 
         # Fails if a threshold does not have a corresponding metric.
-        objective2 = Objective(Metric("m1"))
+        objective2 = Objective(Metric("m1"), minimize=False)
         with self.assertRaisesRegex(ValueError, "corresponding metrics"):
             extract_objective_thresholds(
                 objective_thresholds=objective_thresholds,
diff --git a/ax/modelbridge/transforms/tests/test_derelativize_transform.py b/ax/modelbridge/transforms/tests/test_derelativize_transform.py
index 924ee4d2f05..8a436709d3f 100644
--- a/ax/modelbridge/transforms/tests/test_derelativize_transform.py
+++ b/ax/modelbridge/transforms/tests/test_derelativize_transform.py
@@ -102,7 +102,7 @@ def test_DerelativizeTransform(
         )
 
         # Test with no relative constraints
-        objective = Objective(Metric("c"))
+        objective = Objective(Metric("c"), minimize=True)
         oc = OptimizationConfig(
             objective=objective,
             outcome_constraints=[
@@ -300,7 +300,7 @@ def test_Errors(self) -> None:
             observations=[],
         )
         oc = OptimizationConfig(
-            objective=Objective(Metric("c")),
+            objective=Objective(Metric("c"), minimize=False),
             outcome_constraints=[
                 OutcomeConstraint(Metric("a"), ComparisonOp.LEQ, bound=2, relative=True)
             ],
diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py
index 7489bfe707b..ee91a9e16b0 100644
--- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py
+++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py
@@ -581,7 +581,7 @@ def test_relative_constraints(
                 RangeParameter("y", ParameterType.FLOAT, 0, 20),
             ]
         )
-        objective = Objective(Metric("c"))
+        objective = Objective(Metric("c"), minimize=False)
 
         # Test with relative constraint, in-design status quo
         oc = OptimizationConfig(
diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py
index 29fc52be056..c8a4f4b8887 100644
--- a/ax/service/tests/scheduler_test_utils.py
+++ b/ax/service/tests/scheduler_test_utils.py
@@ -321,7 +321,7 @@ def setUp(self) -> None:
         self.branin_experiment_no_impl_runner_or_metrics = Experiment(
             search_space=get_branin_search_space(),
             optimization_config=OptimizationConfig(
-                objective=Objective(metric=Metric(name="branin"))
+                objective=Objective(metric=Metric(name="branin"), minimize=False)
             ),
             name="branin_experiment_no_impl_runner_or_metrics",
         )
diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py
index 0040f76f9d1..dd97206e8c1 100644
--- a/ax/service/tests/test_report_utils.py
+++ b/ax/service/tests/test_report_utils.py
@@ -560,11 +560,11 @@ def test_get_metric_name_pairs(self) -> None:
         exp._optimization_config = MultiObjectiveOptimizationConfig(
             objective=MultiObjective(
                 objectives=[
-                    Objective(metric=Metric("m0")),
-                    Objective(metric=Metric("m1")),
-                    Objective(metric=Metric("m2")),
-                    Objective(metric=Metric("m3")),
-                    Objective(metric=Metric("m4")),
+                    Objective(metric=Metric("m0"), minimize=False),
+                    Objective(metric=Metric("m1"), minimize=False),
+                    Objective(metric=Metric("m2"), minimize=False),
+                    Objective(metric=Metric("m3"), minimize=False),
+                    Objective(metric=Metric("m4"), minimize=False),
                 ]
             )
         )
@@ -1052,9 +1052,9 @@ def test_compare_to_baseline_moo(self) -> None:
         optimization_config = MultiObjectiveOptimizationConfig(
             objective=MultiObjective(
                 objectives=[
-                    Objective(metric=Metric("m0")),
+                    Objective(metric=Metric("m0"), minimize=False),
                     Objective(metric=Metric("m1"), minimize=True),
-                    Objective(metric=Metric("m3")),
+                    Objective(metric=Metric("m3"), minimize=False),
                 ]
             )
         )
diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py
index f917b0cee59..d1ef02f1e38 100644
--- a/ax/service/utils/report_utils.py
+++ b/ax/service/utils/report_utils.py
@@ -25,7 +25,6 @@
 )
 
 import gpytorch
-
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go
@@ -140,7 +139,7 @@ def _get_objective_trace_plot(
             plot_objective_value_vs_trial_index(
                 exp_df=exp_df,
                 metric_colname=metric_name,
-                minimize=(
+                minimize=not_none(
                     optimization_config.objective.minimize
                     if optimization_config.objective.metric.name == metric_name
                     else experiment.metrics[metric_name].lower_is_better
diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py
index 664eb4784ba..b35a467cc3c 100644
--- a/ax/storage/json_store/decoder.py
+++ b/ax/storage/json_store/decoder.py
@@ -12,7 +12,7 @@
 from inspect import isclass
 from io import StringIO
 from logging import Logger
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import pandas as pd
@@ -25,6 +25,7 @@
 from ax.core.experiment import Experiment
 from ax.core.generator_run import GeneratorRun
 from ax.core.multi_type_experiment import MultiTypeExperiment
+from ax.core.objective import Objective
 from ax.core.parameter import Parameter
 from ax.core.parameter_constraint import (
     OrderConstraint,
@@ -49,16 +50,20 @@
     tensor_from_json,
     trial_from_json,
 )
-
 from ax.storage.json_store.registry import (
     CORE_CLASS_DECODER_REGISTRY,
     CORE_DECODER_REGISTRY,
 )
 from ax.utils.common.logger import get_logger
-from ax.utils.common.serialization import SerializationMixin
+from ax.utils.common.serialization import (
+    SerializationMixin,
+    TClassDecoderRegistry,
+    TDecoderRegistry,
+)
 from ax.utils.common.typeutils import checked_cast, not_none
 from ax.utils.common.typeutils_torch import torch_type_from_str
 
+
 logger: Logger = get_logger(__name__)
@@ -66,13 +71,8 @@ def object_from_json(
     # pyre-fixme[2]: Parameter annotation cannot be `Any`.
     object_json: Any,
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Any:
     """Recursively load objects from a JSON-serializable dictionary."""
     if type(object_json) in (str, int, float, bool, type(None)) or isinstance(
@@ -219,6 +219,12 @@ def object_from_json(
             decoder_registry=decoder_registry,
             class_decoder_registry=class_decoder_registry,
         )
+    elif _class == Objective:
+        return objective_from_json(
+            object_json=object_json,
+            decoder_registry=decoder_registry,
+            class_decoder_registry=class_decoder_registry,
+        )
     elif _class == TorchvisionBenchmarkProblem:
         return TorchvisionBenchmarkProblem.from_dataset_name(
             name=object_json["name"],
@@ -289,13 +295,8 @@ def ax_class_from_json_dict(
     # `typing.Type` to avoid runtime subscripting errors.
     _class: Type,
     object_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Any:
     """Reinstantiates an Ax class registered in `DECODER_REGISTRY` from a JSON
     dict.
@@ -314,13 +315,8 @@ def ax_class_from_json_dict(
 def generator_run_from_json(
     object_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> GeneratorRun:
     """Load Ax GeneratorRun from JSON."""
     time_created_json = object_json.pop("time_created")
@@ -359,13 +355,8 @@ def trial_transition_criteria_from_json(
     # avoid runtime subscripting errors.
     class_: Type,
     transition_criteria_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Optional[TransitionCriterion]:
     """Load Ax transition criteria that depend on Trials from JSON.
@@ -389,13 +380,8 @@ def trial_transition_criteria_from_json(
 def search_space_from_json(
     search_space_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> SearchSpace:
     """Load a SearchSpace from JSON.
@@ -422,13 +408,8 @@ def search_space_from_json(
 def parameter_constraints_from_json(
     parameter_constraint_json: List[Dict[str, Any]],
     parameters: List[Parameter],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> List[ParameterConstraint]:
     """Load ParameterConstraints from JSON.
@@ -476,13 +457,8 @@ def parameter_constraints_from_json(
 def trials_from_json(
     experiment: Experiment,
     trials_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Dict[int, BaseTrial]:
     """Load Ax Trials from JSON."""
     loaded_trials = {}
@@ -507,13 +483,8 @@ def data_from_json(
     data_by_trial_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Dict[int, "OrderedDict[int, Data]"]:
     """Load Ax Data from JSON."""
     data_by_trial = object_from_json(
@@ -531,13 +502,8 @@ def multi_type_experiment_from_json(
     object_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> MultiTypeExperiment:
     """Load AE MultiTypeExperiment from JSON."""
     experiment_info = _get_experiment_info(object_json)
@@ -585,13 +551,8 @@ def experiment_from_json(
     object_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Experiment:
     """Load Ax Experiment from JSON."""
     experiment_info = _get_experiment_info(object_json)
@@ -630,13 +591,8 @@ def _get_experiment_info(object_json: Dict[str, Any]) -> Dict[str, Any]:
 def _load_experiment_info(
     exp: Experiment,
     exp_info: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> None:
     """Loads `Experiment` object with basic information."""
     exp._time_created = object_from_json(
@@ -690,13 +646,8 @@ def _convert_generation_step_keys_for_backwards_compatibility(
 def generation_node_from_json(
     generation_node_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> GenerationNode:
     """Load GenerationNode object from JSON."""
     return GenerationNode(
@@ -724,13 +675,8 @@ def generation_node_from_json(
 def generation_step_from_json(
     generation_step_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> GenerationStep:
     """Load generation step from JSON."""
     generation_step_json = _convert_generation_step_keys_for_backwards_compatibility(
@@ -790,13 +736,8 @@ def generation_step_from_json(
 def model_spec_from_json(
     model_spec_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> ModelSpec:
     """Load ModelSpec from JSON."""
     kwargs = model_spec_json.pop("model_kwargs", None)
@@ -834,14 +775,9 @@ def model_spec_from_json(
 def generation_strategy_from_json(
     generation_strategy_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
     experiment: Optional[Experiment] = None,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> GenerationStrategy:
     """Load generation strategy from JSON."""
     nodes = (
@@ -890,13 +826,8 @@ def generation_strategy_from_json(
 def surrogate_from_list_surrogate_json(
     list_surrogate_json: Dict[str, Any],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Surrogate:
     logger.warning(
         "`ListSurrogate` has been deprecated. Reconstructing a `Surrogate` "
@@ -970,13 +901,8 @@ def surrogate_from_list_surrogate_json(
 def get_input_transform_json_components(
     input_transforms_json: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Tuple[Optional[List[Dict[str, Any]]], Optional[Dict[str, Any]]]:
     if input_transforms_json is None:
         return None, None
@@ -1003,13 +929,8 @@ def get_input_transform_json_components(
 def get_outcome_transform_json_components(
     outcome_transforms_json: Optional[List[Dict[str, Any]]],
-    # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
-    # `typing.Type` to avoid runtime subscripting errors.
-    decoder_registry: Dict[str, Type] = CORE_DECODER_REGISTRY,
-    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
-    class_decoder_registry: Dict[
-        str, Callable[[Dict[str, Any]], Any]
-    ] = CORE_CLASS_DECODER_REGISTRY,
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
 ) -> Tuple[Optional[List[Dict[str, Any]]], Optional[Dict[str, Any]]]:
     if outcome_transforms_json is None:
         return None, None
@@ -1030,3 +951,39 @@ def get_outcome_transform_json_components(
         for outcome_transform_json in outcome_transforms_json
     }
     return outcome_transform_classes_json, outcome_transform_options_json
+
+
+def objective_from_json(
+    object_json: Dict[str, Any],
+    decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY,
+    class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY,
+) -> Objective:
+    """Load an ``Objective`` from JSON in a backwards compatible way.
+
+    If both ``minimize`` and ``lower_is_better`` are specified but have conflicting
+    values, this will overwrite ``lower_is_better=minimize`` to resolve the conflict.
+
+    # TODO: Do we need to do this for scalarized objective as well?
+    """
+    input_args = {
+        k: object_from_json(
+            v,
+            decoder_registry=decoder_registry,
+            class_decoder_registry=class_decoder_registry,
+        )
+        for k, v in object_json.items()
+    }
+    metric = input_args.pop("metric")
+    minimize = input_args.pop("minimize")
+    if metric.lower_is_better is not None and metric.lower_is_better != minimize:
+        logger.warning(
+            f"Metric {metric.name} has {metric.lower_is_better=} but objective "
+            f"specifies {minimize=}. Overwriting ``lower_is_better`` to match "
+            f"the optimization direction {minimize=}."
+        )
+        metric.lower_is_better = minimize
+    return Objective(
+        metric=metric,
+        minimize=minimize,
+        **input_args,  # For future compatibility.
+    )
diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py
index 8ca5b520b2d..a6b73277aa4 100644
--- a/ax/storage/json_store/tests/test_json_store.py
+++ b/ax/storage/json_store/tests/test_json_store.py
@@ -14,6 +14,7 @@
 import torch
 from ax.benchmark.metrics.jenatton import JenattonMetric
 from ax.core.metric import Metric
+from ax.core.objective import Objective
 from ax.core.runner import Runner
 from ax.exceptions.core import AxStorageWarning
 from ax.exceptions.storage import JSONDecodeError, JSONEncodeError
@@ -652,3 +653,17 @@ def test_BadStateDict(self) -> None:
         expected_json = botorch_component_to_dict(interval)
         del expected_json["state_dict"]["lower_bound"]
         botorch_component_from_json(interval.__class__, expected_json)
+
+    def test_objective_backwards_compatibility(self) -> None:
+        # Test that we can load an objective that has conflicting
+        # ``lower_is_better`` and ``minimize`` fields.
+        objective = get_objective(minimize=True)
+        objective.metric.lower_is_better = False  # for conflict!
+        objective_json = object_to_json(objective)
+        self.assertTrue(objective_json["minimize"])
+        self.assertFalse(objective_json["metric"]["lower_is_better"])
+        objective_loaded = object_from_json(objective_json)
+        self.assertIsInstance(objective_loaded, Objective)
+        self.assertNotEqual(objective, objective_loaded)
+        self.assertTrue(objective_loaded.minimize)
+        self.assertTrue(objective_loaded.metric.lower_is_better)
diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py
index 66d703e274e..38ed9a950c4 100644
--- a/ax/storage/sqa_store/decoder.py
+++ b/ax/storage/sqa_store/decoder.py
@@ -1028,7 +1028,16 @@ def _objective_from_sqa(self, metric: Metric, metric_sqa: SQAMetric) -> Objectiv
                 f"The metric {metric.name} corresponding to regular objective does not "
                 "have weight attribute"
             )
-        return Objective(metric=metric, minimize=metric_sqa.minimize)
+        # Resolve any conflicts between ``lower_is_better`` and ``minimize``.
+        minimize = metric_sqa.minimize
+        if metric.lower_is_better is not None and metric.lower_is_better != minimize:
+            logger.warning(
+                f"Metric {metric.name} has {metric.lower_is_better=} but objective "
+                f"specifies {minimize=}. Overwriting ``lower_is_better`` to match "
+                f"the optimization direction {minimize=}."
+            )
+            metric.lower_is_better = minimize
+        return Objective(metric=metric, minimize=minimize)
 
     def _multi_objective_from_sqa(self, parent_metric_sqa: SQAMetric) -> Objective:
         try:
@@ -1054,9 +1063,9 @@ def _multi_objective_from_sqa(self, parent_metric_sqa: SQAMetric) -> Objective:
 
         # Extracting metric and weight for each child
         objectives = [
-            Objective(
+            self._objective_from_sqa(
                 metric=self._metric_from_sqa_util(parent_metric_sqa),
-                minimize=parent_metric_sqa.minimize,
+                metric_sqa=parent_metric_sqa,
             )
             for parent_metric_sqa in metrics_sqa_children
         ]
diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py
index 498dedfe2b4..fd104292f27 100644
--- a/ax/storage/sqa_store/tests/test_sqa_store.py
+++ b/ax/storage/sqa_store/tests/test_sqa_store.py
@@ -751,7 +751,9 @@ def test_ExperimentObjectiveUpdates(self) -> None:
 
         # replace objective
         # (old one should become tracking metric)
-        optimization_config.objective = Objective(metric=Metric(name="objective"))
+        optimization_config.objective = Objective(
+            metric=Metric(name="objective"), minimize=False
+        )
         experiment.optimization_config = optimization_config
         save_experiment(experiment)
         self.assertEqual(
diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py
index 37343452497..aa3ec939821 100644
--- a/ax/utils/testing/core_stubs.py
+++ b/ax/utils/testing/core_stubs.py
@@ -75,6 +75,7 @@
     TModelPredict,
     TModelPredictArm,
     TParameterization,
+    TParamValue,
 )
 from ax.early_stopping.strategies import (
     BaseEarlyStoppingStrategy,
@@ -439,7 +440,7 @@ def get_test_map_data_experiment(
 def get_multi_type_experiment(
     add_trial_type: bool = True, add_trials: bool = False, num_arms: int = 10
 ) -> MultiTypeExperiment:
-    oc = OptimizationConfig(Objective(BraninMetric("m1", ["x1", "x2"])))
+    oc = OptimizationConfig(Objective(BraninMetric("m1", ["x1", "x2"]), minimize=True))
     experiment = MultiTypeExperiment(
         name="test_exp",
         search_space=get_branin_search_space(),
@@ -505,7 +506,9 @@ def get_factorial_experiment(
         name="factorial_test_experiment",
         search_space=get_factorial_search_space(),
         optimization_config=(
-            OptimizationConfig(objective=Objective(metric=get_factorial_metric()))
+            OptimizationConfig(
+                objective=Objective(metric=get_factorial_metric(), minimize=False)
+            )
             if has_optimization_config
             else None
         ),
@@ -1489,21 +1492,12 @@ def get_augmented_hartmann_metric(
 def get_factorial_metric(name: str = "success_metric") -> FactorialMetric:
-    coefficients = {
+    coefficients: Dict[str, Dict[TParamValue, float]] = {
         "factor1": {"level11": 0.1, "level12": 0.2, "level13": 0.3},
         "factor2": {"level21": 0.1, "level22": 0.2},
         "factor3": {"level31": 0.1, "level32": 0.2, "level33": 0.3, "level34": 0.4},
     }
-    return FactorialMetric(
-        name=name,
-        # Expected `Dict[str, Dict[typing.Optional[typing.Union[bool, float, str]],
-        # float]]` for 3rd parameter `coefficients` to call
-        # `ax.metrics.factorial.FactorialMetric.__init__` but got `Dict[str,
-        # Dict[str, float]]`.
-        # pyre-fixme[6]:
-        coefficients=coefficients,
-        batch_size=int(1e4),
-    )
+    return FactorialMetric(name=name, coefficients=coefficients, batch_size=int(1e4))
 
 
 def get_dict_lookup_metric() -> DictLookupMetric:
@@ -1556,18 +1550,18 @@ def get_branin_outcome_constraint() -> OutcomeConstraint:
 ##############################
 
 
-def get_objective() -> Objective:
-    return Objective(metric=Metric(name="m1"), minimize=False)
+def get_objective(minimize: bool = False) -> Objective:
+    return Objective(metric=Metric(name="m1"), minimize=minimize)
 
 
-def get_map_objective() -> Objective:
-    return Objective(metric=MapMetric(name="m1"), minimize=False)
+def get_map_objective(minimize: bool = False) -> Objective:
+    return Objective(metric=MapMetric(name="m1"), minimize=minimize)
 
 
 def get_multi_objective() -> Objective:
     return MultiObjective(
         objectives=[
-            Objective(metric=Metric(name="m1")),
+            Objective(metric=Metric(name="m1"), minimize=False),
             Objective(metric=Metric(name="m3", lower_is_better=True), minimize=True),
         ],
     )
@@ -1576,7 +1570,10 @@ def get_custom_multi_objective() -> Objective:
     return MultiObjective(
         objectives=[
-            Objective(metric=CustomTestMetric(name="m1", test_attribute="test")),
+            Objective(
+                metric=CustomTestMetric(name="m1", test_attribute="test"),
+                minimize=False,
+            ),
             Objective(
                 metric=CustomTestMetric(
                     name="m3", lower_is_better=True, test_attribute="test"
                 ),
@@ -1606,7 +1603,9 @@ def get_scalarized_objective() -> Objective:
 def get_branin_objective(name: str = "branin", minimize: bool = False) -> Objective:
-    return Objective(metric=get_branin_metric(name=name), minimize=minimize)
+    return Objective(
+        metric=get_branin_metric(name=name, lower_is_better=minimize), minimize=minimize
+    )
 
 
 def get_branin_multi_objective(num_objectives: int = 2) -> Objective:
@@ -1677,8 +1676,12 @@ def get_multi_objective_optimization_config(
     )
 
 
-def get_optimization_config_no_constraints() -> OptimizationConfig:
-    return OptimizationConfig(objective=Objective(metric=Metric("test_metric")))
+def get_optimization_config_no_constraints(
+    minimize: bool = False,
+) -> OptimizationConfig:
+    return OptimizationConfig(
+        objective=Objective(metric=Metric("test_metric"), minimize=minimize)
+    )
 
 
 def get_branin_optimization_config(minimize: bool = False) -> OptimizationConfig:
diff --git a/tutorials/factorial.ipynb b/tutorials/factorial.ipynb
index df7ad16d80f..63c559635d6 100644
--- a/tutorials/factorial.ipynb
+++ b/tutorials/factorial.ipynb
@@ -270,7 +270,7 @@
 "    name=\"my_factorial_closed_loop_experiment\",\n",
 "    search_space=search_space,\n",
 "    optimization_config=OptimizationConfig(\n",
-"        objective=Objective(metric=FactorialMetric(name=\"success_metric\"))\n",
+"        objective=Objective(metric=FactorialMetric(name=\"success_metric\"), minimize=False)\n",
 "    ),\n",
 "    runner=MyRunner(),\n",
 ")\n",
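
A minimal usage sketch of the stricter `Objective` construction introduced by this patch (not part of the patch itself; the metric name "latency" is illustrative):

    from ax.core.metric import Metric
    from ax.core.objective import Objective
    from ax.exceptions.core import UserInputError

    # `minimize` may be omitted only when the metric declares `lower_is_better`;
    # the objective then inherits that direction.
    obj = Objective(metric=Metric(name="latency", lower_is_better=True))
    assert obj.minimize is True

    # Omitting both now raises instead of warning and defaulting to `minimize=False`.
    try:
        Objective(metric=Metric(name="latency"))
    except UserInputError:
        pass

    # A direction that conflicts with `lower_is_better` also raises.
    try:
        Objective(metric=Metric(name="latency", lower_is_better=True), minimize=False)
    except UserInputError:
        pass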