Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not support serializing BenchmarkRunners #2889

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def benchmark_replication(
score_trace = np.full(len(optimization_trace), np.nan)

fit_time, gen_time = get_model_times(experiment=experiment)
# Strip runner from experiment before returning, so that the experiment can
# be serialized (the runner can't be)
experiment.runner = None

return BenchmarkResult(
name=scheduler.experiment.name,
Expand Down
30 changes: 30 additions & 0 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from ax.core.search_space import SearchSpaceDigest
from ax.core.trial import Trial
from ax.core.types import TParamValue
from ax.exceptions.core import UnsupportedError
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry

from ax.utils.common.typeutils import checked_cast
from numpy import ndarray
Expand Down Expand Up @@ -165,3 +167,31 @@ def poll_trial_status(
self, trials: Iterable[BaseTrial]
) -> dict[TrialStatus, set[int]]:
return {TrialStatus.COMPLETED: {t.index for t in trials}}

@classmethod
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
"""
It is tricky to use SerializationMixin with instances that have Ax
objects as attributes, as BenchmarkRunners do. Therefore, serialization
is not supported.
"""
raise UnsupportedError(
"serialize_init_args is not a supported method for BenchmarkRunners."
)

@classmethod
def deserialize_init_args(
cls,
args: dict[str, Any],
decoder_registry: TDecoderRegistry | None = None,
class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
"""
It is tricky to use SerializationMixin with instances that have Ax
objects as attributes, as BenchmarkRunners do. Therefore, serialization
is not supported.
"""
raise UnsupportedError(
"deserialize_init_args is not a supported method for BenchmarkRunners."
)
39 changes: 0 additions & 39 deletions ax/benchmark/runners/botorch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

import importlib
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
Expand All @@ -17,10 +16,8 @@
from ax.core.types import TParamValue
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
from pyre_extensions import assert_is_instance
from torch import Tensor


Expand Down Expand Up @@ -165,42 +162,6 @@ def get_noise_stds(self) -> None | float | dict[str, float]:

return noise_std_dict

@classmethod
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
"""Serialize the properties needed to initialize the runner.
Used for storage.
"""
runner = assert_is_instance(obj, cls)

return {
"test_problem_module": runner._test_problem_class.__module__,
"test_problem_class_name": runner._test_problem_class.__name__,
"test_problem_kwargs": runner._test_problem_kwargs,
"outcome_names": runner.outcome_names,
"modified_bounds": runner._modified_bounds,
}

@classmethod
def deserialize_init_args(
cls,
args: dict[str, Any],
decoder_registry: TDecoderRegistry | None = None,
class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
"""Given a dictionary, deserialize the properties needed to initialize the
runner. Used for storage.
"""

module = importlib.import_module(args["test_problem_module"])

return {
"test_problem_class": getattr(module, args["test_problem_class_name"]),
"test_problem_kwargs": args["test_problem_kwargs"],
"outcome_names": args["outcome_names"],
"modified_bounds": args["modified_bounds"],
}


class BotorchTestProblemRunner(SyntheticProblemRunner):
"""
Expand Down
29 changes: 0 additions & 29 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

import warnings
from collections.abc import Callable, Mapping
from typing import Any

Expand All @@ -17,7 +16,6 @@
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
Expand Down Expand Up @@ -132,33 +130,6 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
run_metadata["outcome_names"] = self.outcome_names
return run_metadata

@classmethod
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
"""Serialize the properties needed to initialize the runner.
Used for storage.

WARNING: Because of issues with consistently saving and loading BoTorch and
GPyTorch modules the SurrogateRunner cannot be serialized at this time.
At load time the runner will be replaced with a SyntheticRunner.
"""
warnings.warn(
"Because of issues with consistently saving and loading BoTorch and "
f"GPyTorch modules, {cls.__name__} cannot be serialized at this time. "
"At load time the runner will be replaced with a SyntheticRunner.",
stacklevel=3,
)
return {}

@classmethod
def deserialize_init_args(
cls,
args: dict[str, Any],
decoder_registry: TDecoderRegistry | None = None,
class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
return {}

@property
def is_noiseless(self) -> bool:
if self.noise_stds is None:
Expand Down
33 changes: 9 additions & 24 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
Expand Down Expand Up @@ -174,30 +175,14 @@ def test_synthetic_runner(self) -> None:
)

with self.subTest(f"test `serialize_init_args()`, {test_description}"):
serialize_init_args = runner_cls.serialize_init_args(obj=runner)
self.assertEqual(
serialize_init_args,
{
"test_problem_module": runner._test_problem_class.__module__,
"test_problem_class_name": runner._test_problem_class.__name__,
"test_problem_kwargs": runner._test_problem_kwargs,
"outcome_names": runner.outcome_names,
"modified_bounds": runner._modified_bounds,
},
)
# test deserialize args
deserialize_init_args = runner_cls.deserialize_init_args(
serialize_init_args
)
self.assertEqual(
deserialize_init_args,
{
"test_problem_class": test_problem_class,
"test_problem_kwargs": test_problem_kwargs,
"outcome_names": outcome_names,
"modified_bounds": modified_bounds,
},
)
with self.assertRaisesRegex(
UnsupportedError, "serialize_init_args is not a supported method"
):
runner_cls.serialize_init_args(obj=runner)
with self.assertRaisesRegex(
UnsupportedError, "deserialize_init_args is not a supported method"
):
runner_cls.deserialize_init_args({})

def test_botorch_test_problem_runner_heterogeneous_noise(self) -> None:
runner = BotorchTestProblemRunner(
Expand Down
15 changes: 0 additions & 15 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,7 @@
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem
from ax.benchmark.runners.botorch_test import (
BotorchTestProblemRunner,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
Expand Down Expand Up @@ -190,7 +183,6 @@
BatchTrial: batch_to_dict,
BenchmarkMetric: metric_to_dict,
BoTorchModel: botorch_model_to_dict,
BotorchTestProblemRunner: runner_to_dict,
BraninMetric: metric_to_dict,
BraninTimestampMapMetric: metric_to_dict,
ChainedInputTransform: botorch_component_to_dict,
Expand Down Expand Up @@ -236,7 +228,6 @@
OrEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
OrderConstraint: order_parameter_constraint_to_dict,
OutcomeConstraint: outcome_constraint_to_dict,
ParamBasedTestProblemRunner: runner_to_dict,
ParameterConstraint: parameter_constraint_to_dict,
ParameterDistribution: parameter_distribution_to_dict,
pathlib.Path: pathlib_to_dict,
Expand All @@ -257,7 +248,6 @@
SobolQMCNormalSampler: botorch_component_to_dict,
SumConstraint: sum_parameter_constraint_to_dict,
Surrogate: surrogate_to_dict,
SurrogateRunner: runner_to_dict,
SyntheticRunner: runner_to_dict,
ThresholdEarlyStoppingStrategy: threshold_early_stopping_strategy_to_dict,
Trial: trial_to_dict,
Expand Down Expand Up @@ -296,10 +286,8 @@
"BatchTrial": BatchTrial,
"BenchmarkMethod": BenchmarkMethod,
"BenchmarkMetric": BenchmarkMetric,
"BenchmarkProblem": BenchmarkProblem,
"BenchmarkResult": BenchmarkResult,
"BoTorchModel": BoTorchModel,
"BotorchTestProblemRunner": BotorchTestProblemRunner,
"BraninMetric": BraninMetric,
"BraninTimestampMapMetric": BraninTimestampMapMetric,
"ChainedInputTransform": ChainedInputTransform,
Expand Down Expand Up @@ -344,7 +332,6 @@
"ModelRegistryBase": ModelRegistryBase,
"ModelSpec": ModelSpec,
"MultiObjective": MultiObjective,
"MultiObjectiveBenchmarkProblem": BenchmarkProblem, # backward compatibility
"MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig,
"MultiTypeExperiment": MultiTypeExperiment,
"NegativeBraninMetric": NegativeBraninMetric,
Expand All @@ -357,7 +344,6 @@
"OrEarlyStoppingStrategy": OrEarlyStoppingStrategy,
"OrderConstraint": OrderConstraint,
"OutcomeConstraint": OutcomeConstraint,
"ParamBasedTestProblemRunner": ParamBasedTestProblemRunner,
"ParameterConstraint": ParameterConstraint,
"ParameterConstraintType": ParameterConstraintType,
"ParameterDistribution": ParameterDistribution,
Expand All @@ -369,7 +355,6 @@
"PurePosixPath": pathlib_from_json,
"PureWindowsPath": pathlib_from_json,
"PercentileEarlyStoppingStrategy": PercentileEarlyStoppingStrategy,
"PyTorchCNNTorchvisionParamBasedProblem": PyTorchCNNTorchvisionParamBasedProblem,
"RangeParameter": RangeParameter,
"ReductionCriterion": ReductionCriterion,
"RiskMeasure": RiskMeasure,
Expand Down
33 changes: 0 additions & 33 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@
import os
import tempfile
from functools import partial
from unittest.mock import patch

import numpy as np
import torch
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.runner import Runner
Expand Down Expand Up @@ -49,10 +46,7 @@
from ax.utils.testing.benchmark_stubs import (
get_aggregated_benchmark_result,
get_benchmark_result,
get_multi_objective_benchmark_problem,
get_single_objective_benchmark_problem,
get_sobol_gpei_benchmark_method,
TestDataset,
)
from ax.utils.testing.core_stubs import (
get_abandoned_arm,
Expand Down Expand Up @@ -143,15 +137,10 @@
("AuxiliaryExperiment", get_auxiliary_experiment),
("BatchTrial", get_batch_trial),
("BenchmarkMethod", get_sobol_gpei_benchmark_method),
("BenchmarkProblem", get_single_objective_benchmark_problem),
("BenchmarkResult", get_benchmark_result),
("BoTorchModel", get_botorch_model),
("BoTorchModel", get_botorch_model_with_default_acquisition_class),
("BoTorchModel", get_botorch_model_with_surrogate_specs),
(
"BoTorchTestProblemRunner",
lambda: get_single_objective_benchmark_problem().runner,
),
("BraninMetric", get_branin_metric),
("ChainedInputTransform", get_chained_input_transform),
("ChoiceParameter", get_choice_parameter),
Expand Down Expand Up @@ -234,7 +223,6 @@
("MapKeyInfo", get_map_key_info),
("Metric", get_metric),
("MultiObjective", get_multi_objective),
("MultiObjectiveBenchmarkProblem", get_multi_objective_benchmark_problem),
("MultiObjectiveOptimizationConfig", get_multi_objective_optimization_config),
("MultiTypeExperiment", get_multi_type_experiment),
("ObservationFeatures", get_observation_features),
Expand All @@ -245,13 +233,11 @@
("OrderConstraint", get_order_constraint),
("OutcomeConstraint", get_outcome_constraint),
("Path", get_pathlib_path),
("Jenatton", get_jenatton_benchmark_problem),
("PercentileEarlyStoppingStrategy", get_percentile_early_stopping_strategy),
(
"PercentileEarlyStoppingStrategy",
get_percentile_early_stopping_strategy_with_non_objective_metric_name,
),
("ParamBasedTestProblemRunner", lambda: get_jenatton_benchmark_problem().runner),
("ParameterConstraint", get_parameter_constraint),
("ParameterDistribution", get_parameter_distribution),
("RangeParameter", get_range_parameter),
Expand All @@ -261,7 +247,6 @@
("SchedulerOptions", get_default_scheduler_options),
("SchedulerOptions", get_scheduler_options_batch_trial),
("SearchSpace", get_search_space),
("SingleObjectiveBenchmarkProblem", get_single_objective_benchmark_problem),
("SumConstraint", get_sum_constraint1),
("SumConstraint", get_sum_constraint2),
("Surrogate", get_surrogate),
Expand Down Expand Up @@ -442,24 +427,6 @@ def __post_init__(self, doesnt_serialize: None) -> None:
self.assertEqual(recovered.not_a_field, 1)
self.assertEqual(obj, recovered)

def test_EncodeDecode_torchvision_problem(self) -> None:
registry_path = "ax.benchmark.problems.hpo.torchvision._REGISTRY"
mock_registry = {"MNIST": TestDataset}
with patch.dict(registry_path, mock_registry):
test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST")

self.assertIsNotNone(test_problem.train_loader)
self.assertIsNotNone(test_problem.test_loader)

as_json = object_to_json(obj=test_problem)
self.assertNotIn("train_loader", as_json)

with patch.dict(registry_path, mock_registry):
recovered = object_from_json(as_json)

self.assertIsNotNone(recovered.train_loader)
self.assertEqual(test_problem, recovered)

def test_EncodeDecodeTorchTensor(self) -> None:
x = torch.tensor(
[[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu")
Expand Down
1 change: 0 additions & 1 deletion ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def get_benchmark_result() -> BenchmarkResult:
name="test_benchmarking_experiment",
search_space=problem.search_space,
optimization_config=problem.optimization_config,
runner=problem.runner,
is_test=True,
),
inference_trace=np.ones(4),
Expand Down