Skip to content

Commit

Permalink
Do not support serializing BenchmarkRunners (#2889)
Browse files Browse the repository at this point in the history
Summary:

Context: As of D64263610, serializing `BenchmarkProblem`s and `BenchmarkRunner`s is not needed, and in the future, we will make changes that are difficult to make work well with `SerializationMixin`. Therefore, this diff revokes support for serializing `BenchmarkProblem`s and `BenchmarkRunner`s.

*But what if we want to serialize these?* IMO, we should use the benchmark problem registry to reconstruct these, which will be much more ergonomic. You might worry that it's expensive to reconstruct problems from scratch, but there was never any computational efficiency benefit from serialization rather than rebuilding from scratch, because the most expensive components (surrogates and data) could not be serialized anyway.

This diff:
* Strips the `BenchmarkRunner` from an `Experiment` before saving the experiment -- that probably should not have been saved anyway
* Makes all `BenchmarkRunners` throw exceptions if their `serialize_init_args` or `deserialize_init_args` methods are used
* Removes `SyntheticProblemRunner`'s serialization methods
* Does the same for `SurrogateRunner` (which only had fake serialization anyway)

Reviewed By: Balandat

Differential Revision: D64272347
  • Loading branch information
esantorella authored and facebook-github-bot committed Oct 16, 2024
1 parent dacedfc commit 4fb12b9
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 141 deletions.
3 changes: 3 additions & 0 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def benchmark_replication(
score_trace = np.full(len(optimization_trace), np.nan)

fit_time, gen_time = get_model_times(experiment=experiment)
# Strip runner from experiment before returning, so that the experiment can
# be serialized (the runner can't be)
experiment.runner = None

return BenchmarkResult(
name=scheduler.experiment.name,
Expand Down
30 changes: 30 additions & 0 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from ax.core.search_space import SearchSpaceDigest
from ax.core.trial import Trial
from ax.core.types import TParamValue
from ax.exceptions.core import UnsupportedError
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry

from ax.utils.common.typeutils import checked_cast
from numpy import ndarray
Expand Down Expand Up @@ -165,3 +167,31 @@ def poll_trial_status(
self, trials: Iterable[BaseTrial]
) -> dict[TrialStatus, set[int]]:
return {TrialStatus.COMPLETED: {t.index for t in trials}}

@classmethod
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
    """Refuse to serialize a benchmark runner.

    BenchmarkRunners carry Ax objects as attributes, which is awkward to
    combine with SerializationMixin, so serialization is deliberately
    unsupported.

    Args:
        obj: The object that would have been serialized (unused).

    Raises:
        UnsupportedError: Always.
    """
    message = (
        "serialize_init_args is not a supported method for BenchmarkRunners."
    )
    raise UnsupportedError(message)

@classmethod
def deserialize_init_args(
    cls,
    args: dict[str, Any],
    decoder_registry: TDecoderRegistry | None = None,
    class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
    """Refuse to deserialize a benchmark runner.

    BenchmarkRunners carry Ax objects as attributes, which is awkward to
    combine with SerializationMixin, so deserialization is deliberately
    unsupported.

    Args:
        args: The serialized init args (unused).
        decoder_registry: Unused.
        class_decoder_registry: Unused.

    Raises:
        UnsupportedError: Always.
    """
    message = (
        "deserialize_init_args is not a supported method for BenchmarkRunners."
    )
    raise UnsupportedError(message)
39 changes: 0 additions & 39 deletions ax/benchmark/runners/botorch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

import importlib
from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
Expand All @@ -17,10 +16,8 @@
from ax.core.types import TParamValue
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
from pyre_extensions import assert_is_instance
from torch import Tensor


Expand Down Expand Up @@ -165,42 +162,6 @@ def get_noise_stds(self) -> None | float | dict[str, float]:

return noise_std_dict

@classmethod
# pyre-fixme [2]: Parameter `obj` must have a type other than `Any``
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
    """Collect a runner's constructor arguments as a dictionary for storage.

    The test problem class is recorded by module path and class name rather
    than as the class object itself.

    Args:
        obj: The runner to serialize; must be an instance of ``cls``.

    Returns:
        A dictionary holding the test problem's module and class name, its
        kwargs, the outcome names and the modified bounds.
    """
    runner = assert_is_instance(obj, cls)
    problem_cls = runner._test_problem_class

    init_args: dict[str, Any] = {}
    init_args["test_problem_module"] = problem_cls.__module__
    init_args["test_problem_class_name"] = problem_cls.__name__
    init_args["test_problem_kwargs"] = runner._test_problem_kwargs
    init_args["outcome_names"] = runner.outcome_names
    init_args["modified_bounds"] = runner._modified_bounds
    return init_args

@classmethod
def deserialize_init_args(
    cls,
    args: dict[str, Any],
    decoder_registry: TDecoderRegistry | None = None,
    class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
    """Rebuild a runner's constructor arguments from a stored dictionary.

    The test problem class is looked up by importing the recorded module and
    fetching the recorded class name from it; the remaining entries are
    passed through unchanged.

    Args:
        args: Dictionary with the keys ``test_problem_module``,
            ``test_problem_class_name``, ``test_problem_kwargs``,
            ``outcome_names`` and ``modified_bounds``.
        decoder_registry: Unused.
        class_decoder_registry: Unused.

    Returns:
        A dictionary of init args with the test problem class resolved.
    """
    problem_module = importlib.import_module(args["test_problem_module"])
    problem_class = getattr(problem_module, args["test_problem_class_name"])

    init_args: dict[str, Any] = {"test_problem_class": problem_class}
    # These entries are stored as-is and need no decoding.
    for key in ("test_problem_kwargs", "outcome_names", "modified_bounds"):
        init_args[key] = args[key]
    return init_args


class BotorchTestProblemRunner(SyntheticProblemRunner):
"""
Expand Down
29 changes: 0 additions & 29 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

# pyre-strict

import warnings
from collections.abc import Callable, Mapping
from typing import Any

Expand All @@ -17,7 +16,6 @@
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
Expand Down Expand Up @@ -132,33 +130,6 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
run_metadata["outcome_names"] = self.outcome_names
return run_metadata

@classmethod
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
def serialize_init_args(cls, obj: Any) -> dict[str, Any]:
    """Emit a warning and return no init args.

    WARNING: BoTorch and GPyTorch modules cannot be saved and loaded
    consistently, so this runner is not actually serialized. Per the warning
    text, the runner is replaced with a SyntheticRunner at load time.

    Args:
        obj: The runner that would have been serialized (unused).

    Returns:
        An empty dictionary.
    """
    runner_name = cls.__name__
    message = (
        "Because of issues with consistently saving and loading BoTorch and "
        f"GPyTorch modules, {runner_name} cannot be serialized at this time. "
        "At load time the runner will be replaced with a SyntheticRunner."
    )
    warnings.warn(message, stacklevel=3)

    no_init_args: dict[str, Any] = {}
    return no_init_args

@classmethod
def deserialize_init_args(
    cls,
    args: dict[str, Any],
    decoder_registry: TDecoderRegistry | None = None,
    class_decoder_registry: TClassDecoderRegistry | None = None,
) -> dict[str, Any]:
    """Return an empty set of init args, ignoring every input.

    Args:
        args: The stored init args (unused).
        decoder_registry: Unused.
        class_decoder_registry: Unused.

    Returns:
        An empty dictionary.
    """
    no_init_args: dict[str, Any] = {}
    return no_init_args

@property
def is_noiseless(self) -> bool:
if self.noise_stds is None:
Expand Down
33 changes: 9 additions & 24 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
Expand Down Expand Up @@ -174,30 +175,14 @@ def test_synthetic_runner(self) -> None:
)

with self.subTest(f"test `serialize_init_args()`, {test_description}"):
serialize_init_args = runner_cls.serialize_init_args(obj=runner)
self.assertEqual(
serialize_init_args,
{
"test_problem_module": runner._test_problem_class.__module__,
"test_problem_class_name": runner._test_problem_class.__name__,
"test_problem_kwargs": runner._test_problem_kwargs,
"outcome_names": runner.outcome_names,
"modified_bounds": runner._modified_bounds,
},
)
# test deserialize args
deserialize_init_args = runner_cls.deserialize_init_args(
serialize_init_args
)
self.assertEqual(
deserialize_init_args,
{
"test_problem_class": test_problem_class,
"test_problem_kwargs": test_problem_kwargs,
"outcome_names": outcome_names,
"modified_bounds": modified_bounds,
},
)
with self.assertRaisesRegex(
UnsupportedError, "serialize_init_args is not a supported method"
):
runner_cls.serialize_init_args(obj=runner)
with self.assertRaisesRegex(
UnsupportedError, "deserialize_init_args is not a supported method"
):
runner_cls.deserialize_init_args({})

def test_botorch_test_problem_runner_heterogeneous_noise(self) -> None:
runner = BotorchTestProblemRunner(
Expand Down
15 changes: 0 additions & 15 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,7 @@
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem
from ax.benchmark.runners.botorch_test import (
BotorchTestProblemRunner,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.core import Experiment, ObservationFeatures
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
Expand Down Expand Up @@ -190,7 +183,6 @@
BatchTrial: batch_to_dict,
BenchmarkMetric: metric_to_dict,
BoTorchModel: botorch_model_to_dict,
BotorchTestProblemRunner: runner_to_dict,
BraninMetric: metric_to_dict,
BraninTimestampMapMetric: metric_to_dict,
ChainedInputTransform: botorch_component_to_dict,
Expand Down Expand Up @@ -236,7 +228,6 @@
OrEarlyStoppingStrategy: logical_early_stopping_strategy_to_dict,
OrderConstraint: order_parameter_constraint_to_dict,
OutcomeConstraint: outcome_constraint_to_dict,
ParamBasedTestProblemRunner: runner_to_dict,
ParameterConstraint: parameter_constraint_to_dict,
ParameterDistribution: parameter_distribution_to_dict,
pathlib.Path: pathlib_to_dict,
Expand All @@ -257,7 +248,6 @@
SobolQMCNormalSampler: botorch_component_to_dict,
SumConstraint: sum_parameter_constraint_to_dict,
Surrogate: surrogate_to_dict,
SurrogateRunner: runner_to_dict,
SyntheticRunner: runner_to_dict,
ThresholdEarlyStoppingStrategy: threshold_early_stopping_strategy_to_dict,
Trial: trial_to_dict,
Expand Down Expand Up @@ -296,10 +286,8 @@
"BatchTrial": BatchTrial,
"BenchmarkMethod": BenchmarkMethod,
"BenchmarkMetric": BenchmarkMetric,
"BenchmarkProblem": BenchmarkProblem,
"BenchmarkResult": BenchmarkResult,
"BoTorchModel": BoTorchModel,
"BotorchTestProblemRunner": BotorchTestProblemRunner,
"BraninMetric": BraninMetric,
"BraninTimestampMapMetric": BraninTimestampMapMetric,
"ChainedInputTransform": ChainedInputTransform,
Expand Down Expand Up @@ -344,7 +332,6 @@
"ModelRegistryBase": ModelRegistryBase,
"ModelSpec": ModelSpec,
"MultiObjective": MultiObjective,
"MultiObjectiveBenchmarkProblem": BenchmarkProblem, # backward compatibility
"MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig,
"MultiTypeExperiment": MultiTypeExperiment,
"NegativeBraninMetric": NegativeBraninMetric,
Expand All @@ -357,7 +344,6 @@
"OrEarlyStoppingStrategy": OrEarlyStoppingStrategy,
"OrderConstraint": OrderConstraint,
"OutcomeConstraint": OutcomeConstraint,
"ParamBasedTestProblemRunner": ParamBasedTestProblemRunner,
"ParameterConstraint": ParameterConstraint,
"ParameterConstraintType": ParameterConstraintType,
"ParameterDistribution": ParameterDistribution,
Expand All @@ -369,7 +355,6 @@
"PurePosixPath": pathlib_from_json,
"PureWindowsPath": pathlib_from_json,
"PercentileEarlyStoppingStrategy": PercentileEarlyStoppingStrategy,
"PyTorchCNNTorchvisionParamBasedProblem": PyTorchCNNTorchvisionParamBasedProblem,
"RangeParameter": RangeParameter,
"ReductionCriterion": ReductionCriterion,
"RiskMeasure": RiskMeasure,
Expand Down
33 changes: 0 additions & 33 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,9 @@
import os
import tempfile
from functools import partial
from unittest.mock import patch

import numpy as np
import torch
from ax.benchmark.problems.hpo.torchvision import PyTorchCNNTorchvisionParamBasedProblem
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.runner import Runner
Expand Down Expand Up @@ -49,10 +46,7 @@
from ax.utils.testing.benchmark_stubs import (
get_aggregated_benchmark_result,
get_benchmark_result,
get_multi_objective_benchmark_problem,
get_single_objective_benchmark_problem,
get_sobol_gpei_benchmark_method,
TestDataset,
)
from ax.utils.testing.core_stubs import (
get_abandoned_arm,
Expand Down Expand Up @@ -143,15 +137,10 @@
("AuxiliaryExperiment", get_auxiliary_experiment),
("BatchTrial", get_batch_trial),
("BenchmarkMethod", get_sobol_gpei_benchmark_method),
("BenchmarkProblem", get_single_objective_benchmark_problem),
("BenchmarkResult", get_benchmark_result),
("BoTorchModel", get_botorch_model),
("BoTorchModel", get_botorch_model_with_default_acquisition_class),
("BoTorchModel", get_botorch_model_with_surrogate_specs),
(
"BoTorchTestProblemRunner",
lambda: get_single_objective_benchmark_problem().runner,
),
("BraninMetric", get_branin_metric),
("ChainedInputTransform", get_chained_input_transform),
("ChoiceParameter", get_choice_parameter),
Expand Down Expand Up @@ -234,7 +223,6 @@
("MapKeyInfo", get_map_key_info),
("Metric", get_metric),
("MultiObjective", get_multi_objective),
("MultiObjectiveBenchmarkProblem", get_multi_objective_benchmark_problem),
("MultiObjectiveOptimizationConfig", get_multi_objective_optimization_config),
("MultiTypeExperiment", get_multi_type_experiment),
("ObservationFeatures", get_observation_features),
Expand All @@ -245,13 +233,11 @@
("OrderConstraint", get_order_constraint),
("OutcomeConstraint", get_outcome_constraint),
("Path", get_pathlib_path),
("Jenatton", get_jenatton_benchmark_problem),
("PercentileEarlyStoppingStrategy", get_percentile_early_stopping_strategy),
(
"PercentileEarlyStoppingStrategy",
get_percentile_early_stopping_strategy_with_non_objective_metric_name,
),
("ParamBasedTestProblemRunner", lambda: get_jenatton_benchmark_problem().runner),
("ParameterConstraint", get_parameter_constraint),
("ParameterDistribution", get_parameter_distribution),
("RangeParameter", get_range_parameter),
Expand All @@ -261,7 +247,6 @@
("SchedulerOptions", get_default_scheduler_options),
("SchedulerOptions", get_scheduler_options_batch_trial),
("SearchSpace", get_search_space),
("SingleObjectiveBenchmarkProblem", get_single_objective_benchmark_problem),
("SumConstraint", get_sum_constraint1),
("SumConstraint", get_sum_constraint2),
("Surrogate", get_surrogate),
Expand Down Expand Up @@ -442,24 +427,6 @@ def __post_init__(self, doesnt_serialize: None) -> None:
self.assertEqual(recovered.not_a_field, 1)
self.assertEqual(obj, recovered)

def test_EncodeDecode_torchvision_problem(self) -> None:
# Round-trips a PyTorchCNNTorchvisionParamBasedProblem through
# object_to_json / object_from_json and checks the result equals the original.
# The torchvision dataset registry is patched so "MNIST" maps to TestDataset.
registry_path = "ax.benchmark.problems.hpo.torchvision._REGISTRY"
mock_registry = {"MNIST": TestDataset}
with patch.dict(registry_path, mock_registry):
test_problem = PyTorchCNNTorchvisionParamBasedProblem(name="MNIST")

# The freshly built problem exposes both loaders.
self.assertIsNotNone(test_problem.train_loader)
self.assertIsNotNone(test_problem.test_loader)

# The encoded form must not contain a "train_loader" entry.
as_json = object_to_json(obj=test_problem)
self.assertNotIn("train_loader", as_json)

# Decode under the same patched registry as construction.
with patch.dict(registry_path, mock_registry):
recovered = object_from_json(as_json)

# The decoded problem has a train_loader again and compares equal to the original.
self.assertIsNotNone(recovered.train_loader)
self.assertEqual(test_problem, recovered)

def test_EncodeDecodeTorchTensor(self) -> None:
x = torch.tensor(
[[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64, device=torch.device("cpu")
Expand Down
1 change: 0 additions & 1 deletion ax/utils/testing/benchmark_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ def get_benchmark_result() -> BenchmarkResult:
name="test_benchmarking_experiment",
search_space=problem.search_space,
optimization_config=problem.optimization_config,
runner=problem.runner,
is_test=True,
),
inference_trace=np.ones(4),
Expand Down

0 comments on commit 4fb12b9

Please sign in to comment.