diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py deleted file mode 100644 index 6c188540801..00000000000 --- a/ax/benchmark/problems/surrogate.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict -""" -Benchmark problem based on surrogate. - -This problem class might appear to function identically to its non-surrogate -counterpart, `BenchmarkProblem`, aside from the restriction that its runners is -of type `SurrogateRunner`. However, it is treated specially within JSON storage -because surrogates cannot be easily serialized. -""" - -from dataclasses import dataclass, field - -from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.benchmark.runners.surrogate import SurrogateRunner - - -@dataclass(kw_only=True) -class SurrogateBenchmarkProblem(BenchmarkProblem): - """ - Benchmark problem whose `runner` is a `SurrogateRunner`. - - `SurrogateRunner` allows for the surrogate to be constructed lazily and for - datasets to be downloaded lazily. - - For argument descriptions, see `BenchmarkProblem`. - """ - - runner: SurrogateRunner = field(repr=False) diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index 3a8bf83ba4c..ac562506d5d 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -6,21 +6,17 @@ # pyre-strict from collections.abc import Callable, Mapping -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass import torch -from ax.benchmark.runners.base import BenchmarkRunner from ax.benchmark.runners.botorch_test import ParamBasedTestProblem -from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.observation import ObservationFeatures -from ax.core.search_space import SearchSpaceDigest from ax.core.types import TParamValue from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker from botorch.utils.datasets import SupervisedDataset -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws from torch import Tensor @@ -100,137 +96,3 @@ def __eq__(self, other: Base) -> bool: # Don't check surrogate, datasets, or callable return self.name == other.name - - -@dataclass -class SurrogateRunner(BenchmarkRunner): - """Runner for surrogate benchmark problems. - - Args: - name: The name of the runner. - outcome_names: The names of the outcomes of the Surrogate. - _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use - for generating observations. If `None`, `get_surrogate_and_datasets` - must not be None and will be used to generate the surrogate when it - is needed. - _datasets: Either `None`, or the `SupervisedDataset`s used to fit - the surrogate model. If `None`, `get_surrogate_and_datasets` must - not be None and will be used to generate the datasets when they are - needed. - noise_stds: Noise standard deviations to add to the surrogate output(s). - If a single float is provided, noise with that standard deviation - is added to all outputs. Alternatively, a dictionary mapping outcome - names to noise standard deviations can be provided to specify different - noise levels for different outputs. - get_surrogate_and_datasets: Function that returns the surrogate and - datasets, to allow for lazy construction. If - `get_surrogate_and_datasets` is not provided, `surrogate` and - `datasets` must be provided, and vice versa. - search_space_digest: Used to get the target task and fidelity at - which the oracle is evaluated. - """ - - name: str - _surrogate: TorchModelBridge | None = None - _datasets: list[SupervisedDataset] | None = None - noise_stds: float | dict[str, float] = 0.0 - get_surrogate_and_datasets: ( - None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]] - ) = None - statuses: dict[int, TrialStatus] = field(default_factory=dict) - - def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None: - super().__post_init__(search_space_digest=search_space_digest) - if self.get_surrogate_and_datasets is None and ( - self._surrogate is None or self._datasets is None - ): - raise ValueError( - "If `get_surrogate_and_datasets` is None, `_surrogate` " - "and `_datasets` must not be None, and vice versa." - ) - - def set_surrogate_and_datasets(self) -> None: - self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)() - - @property - def surrogate(self) -> TorchModelBridge: - if self._surrogate is None: - self.set_surrogate_and_datasets() - return none_throws(self._surrogate) - - @property - def datasets(self) -> list[SupervisedDataset]: - if self._datasets is None: - self.set_surrogate_and_datasets() - return none_throws(self._datasets) - - def get_noise_stds(self) -> dict[str, float]: - noise_std = self.noise_stds - if isinstance(noise_std, float): - return {name: noise_std for name in self.outcome_names} - return noise_std - - # pyre-fixme[14]: Inconsistent override - def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor: - # We're ignoring the uncertainty predictions of the surrogate model here and - # use the mean predictions as the outcomes (before potentially adding noise) - means, _ = self.surrogate.predict( - # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict - observation_features=[ObservationFeatures(params)] - ) - means = [means[name][0] for name in self.outcome_names] - return torch.tensor( - means, - device=self.surrogate.device, - dtype=self.surrogate.dtype, - ) - - def run(self, trial: BaseTrial) -> dict[str, Any]: - """Run the trial by evaluating its parameterization(s) on the surrogate model. - - Note: This also sets the status of the trial to COMPLETED. - - Args: - trial: The trial to evaluate. - - Returns: - A dictionary with the following keys: - - outcome_names: The names of the metrics being evaluated. - - Ys: A dict mapping arm names to lists of corresponding outcomes, - where the order of the outcomes is the same as in `outcome_names`. - - Ystds: A dict mapping arm names to lists of corresponding outcome - noise standard deviations (possibly nan if the noise level is - unobserved), where the order of the outcomes is the same as in - `outcome_names`. - - Ys_true: A dict mapping arm names to lists of corresponding ground - truth outcomes, where the order of the outcomes is the same as - in `outcome_names`. - """ - self.statuses[trial.index] = TrialStatus.COMPLETED - run_metadata = super().run(trial=trial) - run_metadata["outcome_names"] = self.outcome_names - return run_metadata - - @property - def is_noiseless(self) -> bool: - if self.noise_stds is None: - return True - if isinstance(self.noise_stds, float): - return self.noise_stds == 0.0 - return all( - std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values() - ) - - @equality_typechecker - def __eq__(self, other: Base) -> bool: - if type(other) is not type(self): - return False - - # Don't check surrogate, datasets, or callable - return ( - (self.name == other.name) - and (self.outcome_names == other.outcome_names) - and (self.noise_stds == other.noise_stds) - # pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`. - and (self.search_space_digest == other.search_space_digest) - ) diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py index 099d09f2d5e..9ad256f7f90 100644 --- a/ax/benchmark/tests/runners/test_surrogate_runner.py +++ b/ax/benchmark/tests/runners/test_surrogate_runner.py @@ -8,15 +8,10 @@ from unittest.mock import MagicMock, patch import torch -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction -from ax.core.parameter import ParameterType, RangeParameter -from ax.core.search_space import SearchSpace +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.testutils import TestCase -from ax.utils.testing.benchmark_stubs import ( - get_soo_surrogate_legacy, - get_soo_surrogate_test_function, -) +from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function class TestSurrogateTestFunction(TestCase): @@ -84,84 +79,3 @@ def _construct_test_function(name: str) -> SurrogateTestFunction: self.assertEqual(runner_1, runner_1a) self.assertNotEqual(runner_1, runner_2) self.assertNotEqual(runner_1, 1) - - -class TestSurrogateRunner(TestCase): - def setUp(self) -> None: - super().setUp() - self.search_space = SearchSpace( - parameters=[ - RangeParameter("x", ParameterType.FLOAT, 0.0, 5.0), - RangeParameter("y", ParameterType.FLOAT, 1.0, 10.0, log_scale=True), - RangeParameter("z", ParameterType.INT, 1.0, 5.0, log_scale=True), - ] - ) - - def test_surrogate_runner(self) -> None: - # Construct a search space with log-scale parameters. - for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}): - with self.subTest(noise_std=noise_std): - surrogate = MagicMock() - mock_mean = torch.tensor([[0.1234]], dtype=torch.double) - surrogate.predict = MagicMock(return_value=(mock_mean, 0)) - surrogate.device = torch.device("cpu") - surrogate.dtype = torch.double - runner = SurrogateRunner( - name="test runner", - _surrogate=surrogate, - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=noise_std, - ) - self.assertEqual(runner.name, "test runner") - self.assertIs(runner.surrogate, surrogate) - self.assertEqual(runner.outcome_names, ["dummy_metric"]) - self.assertEqual(runner.noise_stds, noise_std) - - def test_lazy_instantiation(self) -> None: - runner = get_soo_surrogate_legacy().runner - - self.assertIsNone(runner._surrogate) - self.assertIsNone(runner._datasets) - - # Accessing `surrogate` sets datasets and surrogate - self.assertIsInstance(runner.surrogate, TorchModelBridge) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - # Accessing `datasets` also sets datasets and surrogate - runner = get_soo_surrogate_legacy().runner - self.assertIsInstance(runner.datasets, list) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - with patch.object( - runner, - "get_surrogate_and_datasets", - wraps=runner.get_surrogate_and_datasets, - ) as mock_get_surrogate_and_datasets: - runner.surrogate - mock_get_surrogate_and_datasets.assert_not_called() - - def test_instantiation_raises_with_missing_args(self) -> None: - with self.assertRaisesRegex( - ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and " - ): - SurrogateRunner(name="test runner", outcome_names=[], noise_stds=0.0) - - def test_equality(self) -> None: - def _construct_runner(name: str) -> SurrogateRunner: - return SurrogateRunner( - name=name, - _surrogate=MagicMock(), - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=0.0, - ) - - runner_1 = _construct_runner("test 1") - runner_2 = _construct_runner("test 2") - runner_1a = _construct_runner("test 1") - self.assertEqual(runner_1, runner_1a) - self.assertNotEqual(runner_1, runner_2) - self.assertNotEqual(runner_1, 1) diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 0cdb49c32b6..a798c30b1fe 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -370,8 +370,6 @@ "SumConstraint": SumConstraint, "Surrogate": Surrogate, "SurrogateMetric": BenchmarkMetric, # backward-compatiblity - # NOTE: SurrogateRunners -> SyntheticRunner on load due to complications - "SurrogateRunner": SyntheticRunner, "SobolQMCNormalSampler": SobolQMCNormalSampler, "SyntheticRunner": SyntheticRunner, "SurrogateSpec": SurrogateSpec, diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 60244511b6b..f5a7a72b717 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -15,12 +15,11 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult -from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem from ax.benchmark.runners.botorch_test import ( ParamBasedTestProblem, ParamBasedTestProblemRunner, ) -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.core.experiment import Experiment from ax.core.objective import MultiObjective, Objective from ax.core.optimization_config import ( @@ -129,41 +128,7 @@ def get_soo_surrogate() -> BenchmarkProblem: ) -def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem: - experiment = get_branin_experiment(with_completed_trial=True) - surrogate = TorchModelBridge( - experiment=experiment, - search_space=experiment.search_space, - model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), - data=experiment.lookup_data(), - transforms=[], - ) - runner = SurrogateRunner( - name="test", - outcome_names=["branin"], - get_surrogate_and_datasets=lambda: (surrogate, []), - ) - - observe_noise_sd = True - objective = Objective( - metric=BenchmarkMetric( - name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd - ), - ) - optimization_config = OptimizationConfig(objective=objective) - - return SurrogateBenchmarkProblem( - name="test", - search_space=experiment.search_space, - optimization_config=optimization_config, - num_trials=6, - observe_noise_stds=observe_noise_sd, - optimal_value=0.0, - runner=runner, - ) - - -def get_moo_surrogate() -> SurrogateBenchmarkProblem: +def get_moo_surrogate() -> BenchmarkProblem: experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True) surrogate = TorchModelBridge( experiment=experiment, @@ -173,11 +138,15 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: transforms=[], ) - runner = SurrogateRunner( + outcome_names = ["branin_a", "branin_b"] + test_function = SurrogateTestFunction( name="test", - outcome_names=["branin_a", "branin_b"], + outcome_names=outcome_names, get_surrogate_and_datasets=lambda: (surrogate, []), ) + runner = ParamBasedTestProblemRunner( + test_problem=test_function, outcome_names=outcome_names + ) observe_noise_sd = True optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective( @@ -199,7 +168,7 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: ], ) ) - return SurrogateBenchmarkProblem( + return BenchmarkProblem( name="test", search_space=experiment.search_space, optimization_config=optimization_config,