diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index cf7c282029d..06c61afb5cf 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -189,6 +189,9 @@ def benchmark_replication( score_trace = np.full(len(optimization_trace), np.nan) fit_time, gen_time = get_model_times(experiment=experiment) + # Strip runner from experiment before returning, so that the experiment can + # be serialized (the runner can't be) + experiment.runner = None return BenchmarkResult( name=scheduler.experiment.name, diff --git a/ax/benchmark/runners/base.py b/ax/benchmark/runners/base.py index dfd55e814e1..772e509d419 100644 --- a/ax/benchmark/runners/base.py +++ b/ax/benchmark/runners/base.py @@ -18,6 +18,8 @@ from ax.core.search_space import SearchSpaceDigest from ax.core.trial import Trial from ax.core.types import TParamValue +from ax.exceptions.core import UnsupportedError +from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from ax.utils.common.typeutils import checked_cast from numpy import ndarray @@ -165,3 +167,31 @@ def poll_trial_status( self, trials: Iterable[BaseTrial] ) -> dict[TrialStatus, set[int]]: return {TrialStatus.COMPLETED: {t.index for t in trials}} + + @classmethod + # pyre-fixme [2]: Parameter `obj` must have a type other than `Any`` + def serialize_init_args(cls, obj: Any) -> dict[str, Any]: + """ + It is tricky to use SerializationMixin with instances that have Ax + objects as attributes, as BenchmarkRunners do. Therefore, serialization + is not supported. + """ + raise UnsupportedError( + "serialize_init_args is not a supported method for BenchmarkRunners." + ) + + @classmethod + def deserialize_init_args( + cls, + args: dict[str, Any], + decoder_registry: TDecoderRegistry | None = None, + class_decoder_registry: TClassDecoderRegistry | None = None, + ) -> dict[str, Any]: + """ + It is tricky to use SerializationMixin with instances that have Ax + objects as attributes, as BenchmarkRunners do. Therefore, serialization + is not supported. + """ + raise UnsupportedError( + "deserialize_init_args is not a supported method for BenchmarkRunners." + ) diff --git a/ax/benchmark/runners/botorch_test.py b/ax/benchmark/runners/botorch_test.py index 9c112333b0c..17d5d5b5398 100644 --- a/ax/benchmark/runners/botorch_test.py +++ b/ax/benchmark/runners/botorch_test.py @@ -5,7 +5,6 @@ # pyre-strict -import importlib from abc import ABC, abstractmethod from collections.abc import Mapping from dataclasses import dataclass @@ -17,10 +16,8 @@ from ax.core.types import TParamValue from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker -from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem from botorch.utils.transforms import normalize, unnormalize -from pyre_extensions import assert_is_instance from torch import Tensor @@ -165,42 +162,6 @@ def get_noise_stds(self) -> None | float | dict[str, float]: return noise_std_dict - @classmethod - # pyre-fixme [2]: Parameter `obj` must have a type other than `Any`` - def serialize_init_args(cls, obj: Any) -> dict[str, Any]: - """Serialize the properties needed to initialize the runner. - Used for storage. - """ - runner = assert_is_instance(obj, cls) - - return { - "test_problem_module": runner._test_problem_class.__module__, - "test_problem_class_name": runner._test_problem_class.__name__, - "test_problem_kwargs": runner._test_problem_kwargs, - "outcome_names": runner.outcome_names, - "modified_bounds": runner._modified_bounds, - } - - @classmethod - def deserialize_init_args( - cls, - args: dict[str, Any], - decoder_registry: TDecoderRegistry | None = None, - class_decoder_registry: TClassDecoderRegistry | None = None, - ) -> dict[str, Any]: - """Given a dictionary, deserialize the properties needed to initialize the - runner. Used for storage. - """ - - module = importlib.import_module(args["test_problem_module"]) - - return { - "test_problem_class": getattr(module, args["test_problem_class_name"]), - "test_problem_kwargs": args["test_problem_kwargs"], - "outcome_names": args["outcome_names"], - "modified_bounds": args["modified_bounds"], - } - class BotorchTestProblemRunner(SyntheticProblemRunner): """ diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index d864706efed..603b3175625 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -5,7 +5,6 @@ # pyre-strict -import warnings from collections.abc import Callable, Mapping from typing import Any @@ -17,7 +16,6 @@ from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker -from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry from botorch.utils.datasets import SupervisedDataset from pyre_extensions import assert_is_instance, none_throws from torch import Tensor @@ -132,33 +130,6 @@ def run(self, trial: BaseTrial) -> dict[str, Any]: run_metadata["outcome_names"] = self.outcome_names return run_metadata - @classmethod - # pyre-fixme[2]: Parameter annotation cannot be `Any`. - def serialize_init_args(cls, obj: Any) -> dict[str, Any]: - """Serialize the properties needed to initialize the runner. - Used for storage. - - WARNING: Because of issues with consistently saving and loading BoTorch and - GPyTorch modules the SurrogateRunner cannot be serialized at this time. - At load time the runner will be replaced with a SyntheticRunner. - """ - warnings.warn( - "Because of issues with consistently saving and loading BoTorch and " - f"GPyTorch modules, {cls.__name__} cannot be serialized at this time. " - "At load time the runner will be replaced with a SyntheticRunner.", - stacklevel=3, - ) - return {} - - @classmethod - def deserialize_init_args( - cls, - args: dict[str, Any], - decoder_registry: TDecoderRegistry | None = None, - class_decoder_registry: TClassDecoderRegistry | None = None, - ) -> dict[str, Any]: - return {} - @property def is_noiseless(self) -> bool: if self.noise_stds is None: diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py index aa9cb983b06..6aea2348777 100644 --- a/ax/benchmark/tests/runners/test_botorch_test_problem.py +++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py @@ -20,6 +20,7 @@ from ax.core.arm import Arm from ax.core.base_trial import TrialStatus from ax.core.trial import Trial +from ax.exceptions.core import UnsupportedError from ax.utils.common.testutils import TestCase from ax.utils.common.typeutils import checked_cast from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem @@ -174,30 +175,14 @@ def test_synthetic_runner(self) -> None: ) with self.subTest(f"test `serialize_init_args()`, {test_description}"): - serialize_init_args = runner_cls.serialize_init_args(obj=runner) - self.assertEqual( - serialize_init_args, - { - "test_problem_module": runner._test_problem_class.__module__, - "test_problem_class_name": runner._test_problem_class.__name__, - "test_problem_kwargs": runner._test_problem_kwargs, - "outcome_names": runner.outcome_names, - "modified_bounds": runner._modified_bounds, - }, - ) - # test deserialize args - deserialize_init_args = runner_cls.deserialize_init_args( - serialize_init_args - ) - self.assertEqual( - deserialize_init_args, - { - "test_problem_class": test_problem_class, - "test_problem_kwargs": test_problem_kwargs, - "outcome_names": outcome_names, - "modified_bounds": modified_bounds, - }, - ) + with self.assertRaisesRegex( + UnsupportedError, "serialize_init_args is not a supported method" + ): + runner_cls.serialize_init_args(obj=runner) + with self.assertRaisesRegex( + UnsupportedError, "deserialize_init_args is not a supported method" + ): + runner_cls.deserialize_init_args({}) def test_botorch_test_problem_runner_heterogeneous_noise(self) -> None: runner = BotorchTestProblemRunner( diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index d5f1611f741..0cdb49c32b6 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -13,14 +13,7 @@ import torch from ax.benchmark.benchmark_method import BenchmarkMethod from ax.benchmark.benchmark_metric import BenchmarkMetric -from ax.benchmark.benchmark_problem import BenchmarkProblem from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult -from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem -from ax.benchmark.runners.botorch_test import ( - BotorchTestProblemRunner, - ParamBasedTestProblemRunner, -) -from ax.benchmark.runners.surrogate import SurrogateRunner from ax.core import Experiment, ObservationFeatures from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose @@ -190,7 +183,6 @@ BatchTrial: batch_to_dict, BenchmarkMetric: metric_to_dict, BoTorchModel: botorch_model_to_dict, - BotorchTestProblemRunner: runner_to_dict, BraninMetric: metric_to_dict, BraninTimestampMapMetric: metric_to_dict, ChainedInputTransform: botorch_component_to_dict, @@ -236,7 +228,6 @@ OrEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict, OrderConstraint: order_parameter_constraint_to_dict, OutcomeConstraint: outcome_constraint_to_dict, - ParamBasedTestProblemRunner: runner_to_dict, ParameterConstraint: parameter_constraint_to_dict, ParameterDistribution: parameter_distribution_to_dict, pathlib.Path: pathlib_to_dict, @@ -257,7 +248,6 @@ SobolQMCNormalSampler: botorch_component_to_dict, SumConstraint: sum_parameter_constraint_to_dict, Surrogate: surrogate_to_dict, - SurrogateRunner: runner_to_dict, SyntheticRunner: runner_to_dict, ThresholdEarlyStoppingStrategy: threshold_early_stopping_strategy_to_dict, Trial: trial_to_dict, @@ -296,10 +286,8 @@ "BatchTrial": BatchTrial, "BenchmarkMethod": BenchmarkMethod, "BenchmarkMetric": BenchmarkMetric, - "BenchmarkProblem": BenchmarkProblem, "BenchmarkResult": BenchmarkResult, "BoTorchModel": BoTorchModel, - "BotorchTestProblemRunner": BotorchTestProblemRunner, "BraninMetric": BraninMetric, "BraninTimestampMapMetric": BraninTimestampMapMetric, "ChainedInputTransform": ChainedInputTransform, @@ -344,7 +332,6 @@ "ModelRegistryBase": ModelRegistryBase, "ModelSpec": ModelSpec, "MultiObjective": MultiObjective, - "MultiObjectiveBenchmarkProblem": BenchmarkProblem, # backward compatibility "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, "MultiTypeExperiment": MultiTypeExperiment, "NegativeBraninMetric": NegativeBraninMetric, @@ -357,7 +344,6 @@ "OrEarlyStoppingStrategy": OrEarlyStoppingStrategy, "OrderConstraint": OrderConstraint, "OutcomeConstraint": OutcomeConstraint, - "ParamBasedTestProblemRunner": ParamBasedTestProblemRunner, "ParameterConstraint": ParameterConstraint, "ParameterConstraintType": ParameterConstraintType, "ParameterDistribution": ParameterDistribution, @@ -369,7 +355,6 @@ "PurePosixPath": pathlib_from_json, "PureWindowsPath": pathlib_from_json, "PercentileEarlyStoppingStrategy": PercentileEarlyStoppingStrategy, - "PyTorchCNNTorchvisionParamBasedProblem": PyTorchCNNTorchvisionParamBasedProblem, "RangeParameter": RangeParameter, "ReductionCriterion": ReductionCriterion, "RiskMeasure": RiskMeasure, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index fade9cab58f..beb32779315 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -10,12 +10,9 @@ import os import tempfile from functools import partial -from unittest.mock import patch import numpy as np import torch -from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem -from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem from ax.core.metric import Metric from ax.core.objective import Objective from ax.core.runner import Runner @@ -49,10 +46,7 @@ from ax.utils.testing.benchmark_stubs import ( get_aggregated_benchmark_result, get_benchmark_result, - get_multi_objective_benchmark_problem, - get_single_objective_benchmark_problem, get_sobol_gpei_benchmark_method, - TestDataset, ) from ax.utils.testing.core_stubs import ( get_abandoned_arm, @@ -143,15 +137,10 @@ ("AuxiliaryExperiment", get_auxiliary_experiment), ("BatchTrial", get_batch_trial), ("BenchmarkMethod", get_sobol_gpei_benchmark_method), - ("BenchmarkProblem", get_single_objective_benchmark_problem), ("BenchmarkResult", get_benchmark_result), ("BoTorchModel", get_botorch_model), ("BoTorchModel", get_botorch_model_with_default_acquisition_class), ("BoTorchModel", get_botorch_model_with_surrogate_specs), - ( - "BoTorchTestProblemRunner", - lambda: get_single_objective_benchmark_problem().runner, - ), ("BraninMetric", get_branin_metric), ("ChainedInputTransform", get_chained_input_transform), ("ChoiceParameter", get_choice_parameter), @@ -234,7 +223,6 @@ ("MapKeyInfo", get_map_key_info), ("Metric", get_metric), ("MultiObjective", get_multi_objective), - ("MultiObjectiveBenchmarkProblem", get_multi_objective_benchmark_problem), ("MultiObjectiveOptimizationConfig", get_multi_objective_optimization_config), ("MultiTypeExperiment", get_multi_type_experiment), ("ObservationFeatures", get_observation_features), @@ -245,13 +233,11 @@ ("OrderConstraint", get_order_constraint), ("OutcomeConstraint", get_outcome_constraint), ("Path", get_pathlib_path), - ("Jenatton", get_jenatton_benchmark_problem), ("PercentileEarlyStoppingStrategy", get_percentile_early_stopping_strategy), ( "PercentileEarlyStoppingStrategy", get_percentile_early_stopping_strategy_with_non_objective_metric_name, ), - ("ParamBasedTestProblemRunner", lambda: get_jenatton_benchmark_problem().runner), ("ParameterConstraint", get_parameter_constraint), ("ParameterDistribution", get_parameter_distribution), ("RangeParameter", get_range_parameter), @@ -261,7 +247,6 @@ ("SchedulerOptions", get_default_scheduler_options), ("SchedulerOptions", get_scheduler_options_batch_trial), ("SearchSpace", get_search_space), - ("SingleObjectiveBenchmarkProblem", get_single_objective_benchmark_problem), ("SumConstraint", get_sum_constraint1), ("SumConstraint", get_sum_constraint2), ("Surrogate", get_surrogate), @@ -442,24 +427,6 @@ def __post_init__(self, doesnt_serialize: None) -> None: self.assertEqual(recovered.not_a_field, 1) self.assertEqual(obj, recovered) - def test_EncodeDecode_torchvision_problem(self) -> None: - registry_path = "ax.benchmark.problems.hpo.torchvision._REGISTRY" - mock_registry = {"MNIST": TestDataset} - with patch.dict(registry_path, mock_registry): - test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST") - - self.assertIsNotNone(test_problem.train_loader) - self.assertIsNotNone(test_problem.test_loader) - - as_json = object_to_json(obj=test_problem) - self.assertNotIn("train_loader", as_json) - - with patch.dict(registry_path, mock_registry): - recovered = object_from_json(as_json) - - self.assertIsNotNone(recovered.train_loader) - self.assertEqual(test_problem, recovered) - def test_EncodeDecodeTorchTensor(self) -> None: x = torch.tensor( [[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu") diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 2cec5b7b7f4..db89326219c 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -205,7 +205,6 @@ def get_benchmark_result() -> BenchmarkResult: name="test_benchmarking_experiment", search_space=problem.search_space, optimization_config=problem.optimization_config, - runner=problem.runner, is_test=True, ), inference_trace=np.ones(4),