From 027cfbadabe6ac74b0946b62150352ad26027286 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Fri, 25 Oct 2024 13:13:04 -0700 Subject: [PATCH] Get rid of SurrogateRunner and SurrogateBenchmarkProblem (#2954) Summary: **Context**: This diff cuts over over uses of `SurrogateRunner` to use `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`, and the following diff after that will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Currently, SurrogateRunner had its own logic for tracking when trials are completed, which would make it difficult to work in with asynchronicity. **Note on naming**: Some names have become non-intuitive in the process of benchmarking. To accord with some future changes I hope to make, I called a new class SurrogateTestFunction, whereas SurrogateParamBasedTestProblem would be more in line with the old naming. The name changes I hope to make: * ParamBasedTestProblemRunner -> nothing, absorbed into BenchmarkRunner * ParamBasedTestProblem -> TestFunction, to emphasize that all it does is generate data (rather than more generally specify the problem we are solving) and that it is deterministic, and to differentiate it from BenchmarkProblem. BenchmarkTestFunction would also be a candidate. * BoTorchTestProblem -> BoTorchTestFunction **Changes in this diff**: * Introduces SurrogateTestFunction, a ParamBasedTestProblem for surrogates * Removes SurrogateRunner. There is some loss of (unused) functionality here; dict-valued 'noise_std' is no longer supported. I expect we will eventually want that, but it doesn't need to be done now. * Changed `predict_for_tensor` utility function to use a `SurrogateTestFunction` rather than a `SurrogateRunner` * Removed `SurrogateRunner` from the json_store registry. There is no point in keeping it around for backward compatibility since it was just deserializing to a dummy `SyntheticRunner` anyway. Differential Revision: D64855010 --- ax/benchmark/problems/surrogate.py | 33 ----- ax/benchmark/runners/surrogate.py | 134 ------------------ .../tests/runners/test_surrogate_runner.py | 90 +----------- ax/storage/json_store/registry.py | 2 - ax/utils/testing/benchmark_stubs.py | 49 ++----- 5 files changed, 11 insertions(+), 297 deletions(-) delete mode 100644 ax/benchmark/problems/surrogate.py diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py deleted file mode 100644 index 6c188540801..00000000000 --- a/ax/benchmark/problems/surrogate.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict -""" -Benchmark problem based on surrogate. - -This problem class might appear to function identically to its non-surrogate -counterpart, `BenchmarkProblem`, aside from the restriction that its runners is -of type `SurrogateRunner`. However, it is treated specially within JSON storage -because surrogates cannot be easily serialized. -""" - -from dataclasses import dataclass, field - -from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.benchmark.runners.surrogate import SurrogateRunner - - -@dataclass(kw_only=True) -class SurrogateBenchmarkProblem(BenchmarkProblem): - """ - Benchmark problem whose `runner` is a `SurrogateRunner`. - - `SurrogateRunner` allows for the surrogate to be constructed lazily and for - datasets to be downloaded lazily. - - For argument descriptions, see `BenchmarkProblem`. - """ - - runner: SurrogateRunner = field(repr=False) diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index 3a8bf83ba4c..c7b27171f00 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -100,137 +100,3 @@ def __eq__(self, other: Base) -> bool: # Don't check surrogate, datasets, or callable return self.name == other.name - - -@dataclass -class SurrogateRunner(BenchmarkRunner): - """Runner for surrogate benchmark problems. - - Args: - name: The name of the runner. - outcome_names: The names of the outcomes of the Surrogate. - _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use - for generating observations. If `None`, `get_surrogate_and_datasets` - must not be None and will be used to generate the surrogate when it - is needed. - _datasets: Either `None`, or the `SupervisedDataset`s used to fit - the surrogate model. If `None`, `get_surrogate_and_datasets` must - not be None and will be used to generate the datasets when they are - needed. - noise_stds: Noise standard deviations to add to the surrogate output(s). - If a single float is provided, noise with that standard deviation - is added to all outputs. Alternatively, a dictionary mapping outcome - names to noise standard deviations can be provided to specify different - noise levels for different outputs. - get_surrogate_and_datasets: Function that returns the surrogate and - datasets, to allow for lazy construction. If - `get_surrogate_and_datasets` is not provided, `surrogate` and - `datasets` must be provided, and vice versa. - search_space_digest: Used to get the target task and fidelity at - which the oracle is evaluated. - """ - - name: str - _surrogate: TorchModelBridge | None = None - _datasets: list[SupervisedDataset] | None = None - noise_stds: float | dict[str, float] = 0.0 - get_surrogate_and_datasets: ( - None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]] - ) = None - statuses: dict[int, TrialStatus] = field(default_factory=dict) - - def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None: - super().__post_init__(search_space_digest=search_space_digest) - if self.get_surrogate_and_datasets is None and ( - self._surrogate is None or self._datasets is None - ): - raise ValueError( - "If `get_surrogate_and_datasets` is None, `_surrogate` " - "and `_datasets` must not be None, and vice versa." - ) - - def set_surrogate_and_datasets(self) -> None: - self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)() - - @property - def surrogate(self) -> TorchModelBridge: - if self._surrogate is None: - self.set_surrogate_and_datasets() - return none_throws(self._surrogate) - - @property - def datasets(self) -> list[SupervisedDataset]: - if self._datasets is None: - self.set_surrogate_and_datasets() - return none_throws(self._datasets) - - def get_noise_stds(self) -> dict[str, float]: - noise_std = self.noise_stds - if isinstance(noise_std, float): - return {name: noise_std for name in self.outcome_names} - return noise_std - - # pyre-fixme[14]: Inconsistent override - def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor: - # We're ignoring the uncertainty predictions of the surrogate model here and - # use the mean predictions as the outcomes (before potentially adding noise) - means, _ = self.surrogate.predict( - # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict - observation_features=[ObservationFeatures(params)] - ) - means = [means[name][0] for name in self.outcome_names] - return torch.tensor( - means, - device=self.surrogate.device, - dtype=self.surrogate.dtype, - ) - - def run(self, trial: BaseTrial) -> dict[str, Any]: - """Run the trial by evaluating its parameterization(s) on the surrogate model. - - Note: This also sets the status of the trial to COMPLETED. - - Args: - trial: The trial to evaluate. - - Returns: - A dictionary with the following keys: - - outcome_names: The names of the metrics being evaluated. - - Ys: A dict mapping arm names to lists of corresponding outcomes, - where the order of the outcomes is the same as in `outcome_names`. - - Ystds: A dict mapping arm names to lists of corresponding outcome - noise standard deviations (possibly nan if the noise level is - unobserved), where the order of the outcomes is the same as in - `outcome_names`. - - Ys_true: A dict mapping arm names to lists of corresponding ground - truth outcomes, where the order of the outcomes is the same as - in `outcome_names`. - """ - self.statuses[trial.index] = TrialStatus.COMPLETED - run_metadata = super().run(trial=trial) - run_metadata["outcome_names"] = self.outcome_names - return run_metadata - - @property - def is_noiseless(self) -> bool: - if self.noise_stds is None: - return True - if isinstance(self.noise_stds, float): - return self.noise_stds == 0.0 - return all( - std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values() - ) - - @equality_typechecker - def __eq__(self, other: Base) -> bool: - if type(other) is not type(self): - return False - - # Don't check surrogate, datasets, or callable - return ( - (self.name == other.name) - and (self.outcome_names == other.outcome_names) - and (self.noise_stds == other.noise_stds) - # pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`. - and (self.search_space_digest == other.search_space_digest) - ) diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py index 099d09f2d5e..9ad256f7f90 100644 --- a/ax/benchmark/tests/runners/test_surrogate_runner.py +++ b/ax/benchmark/tests/runners/test_surrogate_runner.py @@ -8,15 +8,10 @@ from unittest.mock import MagicMock, patch import torch -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction -from ax.core.parameter import ParameterType, RangeParameter -from ax.core.search_space import SearchSpace +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.testutils import TestCase -from ax.utils.testing.benchmark_stubs import ( - get_soo_surrogate_legacy, - get_soo_surrogate_test_function, -) +from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function class TestSurrogateTestFunction(TestCase): @@ -84,84 +79,3 @@ def _construct_test_function(name: str) -> SurrogateTestFunction: self.assertEqual(runner_1, runner_1a) self.assertNotEqual(runner_1, runner_2) self.assertNotEqual(runner_1, 1) - - -class TestSurrogateRunner(TestCase): - def setUp(self) -> None: - super().setUp() - self.search_space = SearchSpace( - parameters=[ - RangeParameter("x", ParameterType.FLOAT, 0.0, 5.0), - RangeParameter("y", ParameterType.FLOAT, 1.0, 10.0, log_scale=True), - RangeParameter("z", ParameterType.INT, 1.0, 5.0, log_scale=True), - ] - ) - - def test_surrogate_runner(self) -> None: - # Construct a search space with log-scale parameters. - for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}): - with self.subTest(noise_std=noise_std): - surrogate = MagicMock() - mock_mean = torch.tensor([[0.1234]], dtype=torch.double) - surrogate.predict = MagicMock(return_value=(mock_mean, 0)) - surrogate.device = torch.device("cpu") - surrogate.dtype = torch.double - runner = SurrogateRunner( - name="test runner", - _surrogate=surrogate, - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=noise_std, - ) - self.assertEqual(runner.name, "test runner") - self.assertIs(runner.surrogate, surrogate) - self.assertEqual(runner.outcome_names, ["dummy_metric"]) - self.assertEqual(runner.noise_stds, noise_std) - - def test_lazy_instantiation(self) -> None: - runner = get_soo_surrogate_legacy().runner - - self.assertIsNone(runner._surrogate) - self.assertIsNone(runner._datasets) - - # Accessing `surrogate` sets datasets and surrogate - self.assertIsInstance(runner.surrogate, TorchModelBridge) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - # Accessing `datasets` also sets datasets and surrogate - runner = get_soo_surrogate_legacy().runner - self.assertIsInstance(runner.datasets, list) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - with patch.object( - runner, - "get_surrogate_and_datasets", - wraps=runner.get_surrogate_and_datasets, - ) as mock_get_surrogate_and_datasets: - runner.surrogate - mock_get_surrogate_and_datasets.assert_not_called() - - def test_instantiation_raises_with_missing_args(self) -> None: - with self.assertRaisesRegex( - ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and " - ): - SurrogateRunner(name="test runner", outcome_names=[], noise_stds=0.0) - - def test_equality(self) -> None: - def _construct_runner(name: str) -> SurrogateRunner: - return SurrogateRunner( - name=name, - _surrogate=MagicMock(), - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=0.0, - ) - - runner_1 = _construct_runner("test 1") - runner_2 = _construct_runner("test 2") - runner_1a = _construct_runner("test 1") - self.assertEqual(runner_1, runner_1a) - self.assertNotEqual(runner_1, runner_2) - self.assertNotEqual(runner_1, 1) diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 0cdb49c32b6..a798c30b1fe 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -370,8 +370,6 @@ "SumConstraint": SumConstraint, "Surrogate": Surrogate, "SurrogateMetric": BenchmarkMetric, # backward-compatiblity - # NOTE: SurrogateRunners -> SyntheticRunner on load due to complications - "SurrogateRunner": SyntheticRunner, "SobolQMCNormalSampler": SobolQMCNormalSampler, "SyntheticRunner": SyntheticRunner, "SurrogateSpec": SurrogateSpec, diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 60244511b6b..f5a7a72b717 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -15,12 +15,11 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult -from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem from ax.benchmark.runners.botorch_test import ( ParamBasedTestProblem, ParamBasedTestProblemRunner, ) -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.core.experiment import Experiment from ax.core.objective import MultiObjective, Objective from ax.core.optimization_config import ( @@ -129,41 +128,7 @@ def get_soo_surrogate() -> BenchmarkProblem: ) -def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem: - experiment = get_branin_experiment(with_completed_trial=True) - surrogate = TorchModelBridge( - experiment=experiment, - search_space=experiment.search_space, - model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), - data=experiment.lookup_data(), - transforms=[], - ) - runner = SurrogateRunner( - name="test", - outcome_names=["branin"], - get_surrogate_and_datasets=lambda: (surrogate, []), - ) - - observe_noise_sd = True - objective = Objective( - metric=BenchmarkMetric( - name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd - ), - ) - optimization_config = OptimizationConfig(objective=objective) - - return SurrogateBenchmarkProblem( - name="test", - search_space=experiment.search_space, - optimization_config=optimization_config, - num_trials=6, - observe_noise_stds=observe_noise_sd, - optimal_value=0.0, - runner=runner, - ) - - -def get_moo_surrogate() -> SurrogateBenchmarkProblem: +def get_moo_surrogate() -> BenchmarkProblem: experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True) surrogate = TorchModelBridge( experiment=experiment, @@ -173,11 +138,15 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: transforms=[], ) - runner = SurrogateRunner( + outcome_names = ["branin_a", "branin_b"] + test_function = SurrogateTestFunction( name="test", - outcome_names=["branin_a", "branin_b"], + outcome_names=outcome_names, get_surrogate_and_datasets=lambda: (surrogate, []), ) + runner = ParamBasedTestProblemRunner( + test_problem=test_function, outcome_names=outcome_names + ) observe_noise_sd = True optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective( @@ -199,7 +168,7 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: ], ) ) - return SurrogateBenchmarkProblem( + return BenchmarkProblem( name="test", search_space=experiment.search_space, optimization_config=optimization_config,