diff --git a/ax/benchmark/benchmark_problem.py b/ax/benchmark/benchmark_problem.py index 17e118446d2..234fe69a154 100644 --- a/ax/benchmark/benchmark_problem.py +++ b/ax/benchmark/benchmark_problem.py @@ -13,10 +13,7 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.runners.base import BenchmarkRunner -from ax.benchmark.runners.botorch_test import ( - BoTorchTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.botorch_test import BoTorchTestProblem from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.objective import MultiObjective, Objective @@ -380,7 +377,7 @@ def create_problem_from_botorch( name=name, search_space=search_space, optimization_config=optimization_config, - runner=ParamBasedTestProblemRunner( + runner=BenchmarkRunner( test_problem=BoTorchTestProblem(botorch_problem=test_problem), outcome_names=outcome_names, search_space_digest=extract_search_space_digest( diff --git a/ax/benchmark/problems/hpo/torchvision.py b/ax/benchmark/problems/hpo/torchvision.py index 1a6e424a5f2..3f7f7008537 100644 --- a/ax/benchmark/problems/hpo/torchvision.py +++ b/ax/benchmark/problems/hpo/torchvision.py @@ -14,10 +14,8 @@ BenchmarkProblem, get_soo_config_and_outcome_names, ) -from ax.benchmark.runners.botorch_test import ( - ParamBasedTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import ParamBasedTestProblem from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError @@ -216,7 +214,7 @@ def get_pytorch_cnn_torchvision_benchmark_problem( observe_noise_sd=False, objective_name="accuracy", ) - runner = ParamBasedTestProblemRunner( + runner = BenchmarkRunner( test_problem=PyTorchCNNTorchvisionParamBasedProblem(name=name), outcome_names=outcome_names, ) diff --git a/ax/benchmark/problems/surrogate.py b/ax/benchmark/problems/surrogate.py deleted file mode 100644 index 6c188540801..00000000000 --- a/ax/benchmark/problems/surrogate.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -# pyre-strict -""" -Benchmark problem based on surrogate. - -This problem class might appear to function identically to its non-surrogate -counterpart, `BenchmarkProblem`, aside from the restriction that its runners is -of type `SurrogateRunner`. However, it is treated specially within JSON storage -because surrogates cannot be easily serialized. -""" - -from dataclasses import dataclass, field - -from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.benchmark.runners.surrogate import SurrogateRunner - - -@dataclass(kw_only=True) -class SurrogateBenchmarkProblem(BenchmarkProblem): - """ - Benchmark problem whose `runner` is a `SurrogateRunner`. - - `SurrogateRunner` allows for the surrogate to be constructed lazily and for - datasets to be downloaded lazily. - - For argument descriptions, see `BenchmarkProblem`. - """ - - runner: SurrogateRunner = field(repr=False) diff --git a/ax/benchmark/problems/synthetic/discretized/mixed_integer.py b/ax/benchmark/problems/synthetic/discretized/mixed_integer.py index 36359e9076e..b3d66ef062b 100644 --- a/ax/benchmark/problems/synthetic/discretized/mixed_integer.py +++ b/ax/benchmark/problems/synthetic/discretized/mixed_integer.py @@ -21,10 +21,8 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.benchmark.runners.botorch_test import ( - BoTorchTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import BoTorchTestProblem from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ParameterType, RangeParameter @@ -104,7 +102,7 @@ def _get_problem_from_common_inputs( test_problem = test_problem_class(dim=dim) else: test_problem = test_problem_class(dim=dim, bounds=test_problem_bounds) - runner = ParamBasedTestProblemRunner( + runner = BenchmarkRunner( test_problem=BoTorchTestProblem( botorch_problem=test_problem, modified_bounds=bounds ), diff --git a/ax/benchmark/problems/synthetic/hss/jenatton.py b/ax/benchmark/problems/synthetic/hss/jenatton.py index fc2b40e6698..69cd824759d 100644 --- a/ax/benchmark/problems/synthetic/hss/jenatton.py +++ b/ax/benchmark/problems/synthetic/hss/jenatton.py @@ -11,10 +11,8 @@ import torch from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem -from ax.benchmark.runners.botorch_test import ( - ParamBasedTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import ParamBasedTestProblem from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter @@ -120,7 +118,7 @@ def get_jenatton_benchmark_problem( name=name, search_space=search_space, optimization_config=optimization_config, - runner=ParamBasedTestProblemRunner( + runner=BenchmarkRunner( test_problem=Jenatton(), outcome_names=[name], noise_std=noise_std ), num_trials=num_trials, diff --git a/ax/benchmark/runners/base.py b/ax/benchmark/runners/base.py index 2415893be7d..3471231ebf2 100644 --- a/ax/benchmark/runners/base.py +++ b/ax/benchmark/runners/base.py @@ -5,7 +5,6 @@ # pyre-strict -from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping from dataclasses import dataclass, field, InitVar from math import sqrt @@ -14,6 +13,7 @@ import numpy.typing as npt import torch +from ax.benchmark.runners.botorch_test import ParamBasedTestProblem from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.batch_trial import BatchTrial from ax.core.runner import Runner @@ -28,7 +28,7 @@ @dataclass(kw_only=True) -class BenchmarkRunner(Runner, ABC): +class BenchmarkRunner(Runner): """ A Runner that produces both observed and ground-truth values. @@ -45,9 +45,19 @@ class BenchmarkRunner(Runner, ABC): - If they are not deterministc, they are not supported. It is not conceptually clear how to benchmark such problems, so we decided to not over-engineer for that before such a use case arrives. + + Args: + outcome_names: The names of the outcomes returned by the problem. + test_problem: A ``ParamBasedTestProblem`` from which to generate + deterministic data before adding noise. + noise_std: The standard deviation of the noise added to the data. Can be + a list or dict to be per-metric. + search_space_digest: Used to extract target fidelity and task. """ outcome_names: list[str] + test_problem: ParamBasedTestProblem + noise_std: float | list[float] | dict[str, float] = 0.0 # pyre-fixme[16]: Pyre doesn't understand InitVars search_space_digest: InitVar[SearchSpaceDigest | None] = None target_fidelity_and_task: Mapping[str, float | int] = field(init=False) @@ -62,12 +72,12 @@ def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None: self.target_fidelity_and_task = {} def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor: - """ - Return the ground truth values for a given arm. + """Evaluates the test problem. - Synthetic noise is added as part of the Runner's `run()` method. + Returns: + An `m`-dim tensor of ground truth (noiseless) evaluations. """ - ... + return torch.atleast_1d(self.test_problem.evaluate_true(params=params)) def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray: """ @@ -83,13 +93,19 @@ def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray: params = {**parameters, **self.target_fidelity_and_task} return self.get_Y_true(params=params).numpy() - @abstractmethod def get_noise_stds(self) -> dict[str, float]: - """ - Return the standard errors for the synthetic noise to be applied to the - observed values. - """ - ... + noise_std = self.noise_std + if isinstance(noise_std, float): + return {name: noise_std for name in self.outcome_names} + elif isinstance(noise_std, dict): + if not set(noise_std.keys()) == set(self.outcome_names): + raise ValueError( + "Noise std must have keys equal to outcome names if given as " + "a dict." + ) + return noise_std + # list of floats + return dict(zip(self.outcome_names, noise_std, strict=True)) def run(self, trial: BaseTrial) -> dict[str, Any]: """Run the trial by evaluating its parameterization(s). diff --git a/ax/benchmark/runners/botorch_test.py b/ax/benchmark/runners/botorch_test.py index de18d01b6eb..321675afab3 100644 --- a/ax/benchmark/runners/botorch_test.py +++ b/ax/benchmark/runners/botorch_test.py @@ -11,7 +11,6 @@ from itertools import islice import torch -from ax.benchmark.runners.base import BenchmarkRunner from ax.core.types import TParamValue from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem from botorch.utils.transforms import normalize, unnormalize @@ -91,48 +90,3 @@ def evaluate_true(self, params: Mapping[str, float | int]) -> torch.Tensor: constraints = self.botorch_problem.evaluate_slack_true(x).view(-1) return torch.cat([objectives, constraints], dim=-1) return objectives - - -@dataclass(kw_only=True) -class ParamBasedTestProblemRunner(BenchmarkRunner): - """ - A Runner for evaluating `ParamBasedTestProblem`s. - - Given a trial, the Runner will use its `test_problem` to evaluate the - problem noiselessly for each arm in the trial, and then add noise as - specified by the `noise_std`. It will return - metadata including the outcome names and values of metrics. - - Args: - outcome_names: The names of the outcomes returned by the problem. - search_space_digest: Used to extract target fidelity and task. - test_problem: A ``ParamBasedTestProblem`` from which to generate - deterministic data before adding noise. - noise_std: The standard deviation of the noise added to the data. Can be - a list to be per-metric. - """ - - test_problem: ParamBasedTestProblem - noise_std: float | list[float] | dict[str, float] = 0.0 - - def get_noise_stds(self) -> dict[str, float]: - noise_std = self.noise_std - if isinstance(noise_std, float): - return {name: noise_std for name in self.outcome_names} - elif isinstance(noise_std, dict): - if not set(noise_std.keys()) == set(self.outcome_names): - raise ValueError( - "Noise std must have keys equal to outcome names if given as " - "a dict." - ) - return noise_std - # list of floats - return dict(zip(self.outcome_names, noise_std, strict=True)) - - def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor: - """Evaluates the test problem. - - Returns: - An `m`-dim tensor of ground truth (noiseless) evaluations. - """ - return torch.atleast_1d(self.test_problem.evaluate_true(params=params)) diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py index 3a8bf83ba4c..ac562506d5d 100644 --- a/ax/benchmark/runners/surrogate.py +++ b/ax/benchmark/runners/surrogate.py @@ -6,21 +6,17 @@ # pyre-strict from collections.abc import Callable, Mapping -from dataclasses import dataclass, field -from typing import Any +from dataclasses import dataclass import torch -from ax.benchmark.runners.base import BenchmarkRunner from ax.benchmark.runners.botorch_test import ParamBasedTestProblem -from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.observation import ObservationFeatures -from ax.core.search_space import SearchSpaceDigest from ax.core.types import TParamValue from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker from botorch.utils.datasets import SupervisedDataset -from pyre_extensions import assert_is_instance, none_throws +from pyre_extensions import none_throws from torch import Tensor @@ -100,137 +96,3 @@ def __eq__(self, other: Base) -> bool: # Don't check surrogate, datasets, or callable return self.name == other.name - - -@dataclass -class SurrogateRunner(BenchmarkRunner): - """Runner for surrogate benchmark problems. - - Args: - name: The name of the runner. - outcome_names: The names of the outcomes of the Surrogate. - _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use - for generating observations. If `None`, `get_surrogate_and_datasets` - must not be None and will be used to generate the surrogate when it - is needed. - _datasets: Either `None`, or the `SupervisedDataset`s used to fit - the surrogate model. If `None`, `get_surrogate_and_datasets` must - not be None and will be used to generate the datasets when they are - needed. - noise_stds: Noise standard deviations to add to the surrogate output(s). - If a single float is provided, noise with that standard deviation - is added to all outputs. Alternatively, a dictionary mapping outcome - names to noise standard deviations can be provided to specify different - noise levels for different outputs. - get_surrogate_and_datasets: Function that returns the surrogate and - datasets, to allow for lazy construction. If - `get_surrogate_and_datasets` is not provided, `surrogate` and - `datasets` must be provided, and vice versa. - search_space_digest: Used to get the target task and fidelity at - which the oracle is evaluated. - """ - - name: str - _surrogate: TorchModelBridge | None = None - _datasets: list[SupervisedDataset] | None = None - noise_stds: float | dict[str, float] = 0.0 - get_surrogate_and_datasets: ( - None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]] - ) = None - statuses: dict[int, TrialStatus] = field(default_factory=dict) - - def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None: - super().__post_init__(search_space_digest=search_space_digest) - if self.get_surrogate_and_datasets is None and ( - self._surrogate is None or self._datasets is None - ): - raise ValueError( - "If `get_surrogate_and_datasets` is None, `_surrogate` " - "and `_datasets` must not be None, and vice versa." - ) - - def set_surrogate_and_datasets(self) -> None: - self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)() - - @property - def surrogate(self) -> TorchModelBridge: - if self._surrogate is None: - self.set_surrogate_and_datasets() - return none_throws(self._surrogate) - - @property - def datasets(self) -> list[SupervisedDataset]: - if self._datasets is None: - self.set_surrogate_and_datasets() - return none_throws(self._datasets) - - def get_noise_stds(self) -> dict[str, float]: - noise_std = self.noise_stds - if isinstance(noise_std, float): - return {name: noise_std for name in self.outcome_names} - return noise_std - - # pyre-fixme[14]: Inconsistent override - def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor: - # We're ignoring the uncertainty predictions of the surrogate model here and - # use the mean predictions as the outcomes (before potentially adding noise) - means, _ = self.surrogate.predict( - # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict - observation_features=[ObservationFeatures(params)] - ) - means = [means[name][0] for name in self.outcome_names] - return torch.tensor( - means, - device=self.surrogate.device, - dtype=self.surrogate.dtype, - ) - - def run(self, trial: BaseTrial) -> dict[str, Any]: - """Run the trial by evaluating its parameterization(s) on the surrogate model. - - Note: This also sets the status of the trial to COMPLETED. - - Args: - trial: The trial to evaluate. - - Returns: - A dictionary with the following keys: - - outcome_names: The names of the metrics being evaluated. - - Ys: A dict mapping arm names to lists of corresponding outcomes, - where the order of the outcomes is the same as in `outcome_names`. - - Ystds: A dict mapping arm names to lists of corresponding outcome - noise standard deviations (possibly nan if the noise level is - unobserved), where the order of the outcomes is the same as in - `outcome_names`. - - Ys_true: A dict mapping arm names to lists of corresponding ground - truth outcomes, where the order of the outcomes is the same as - in `outcome_names`. - """ - self.statuses[trial.index] = TrialStatus.COMPLETED - run_metadata = super().run(trial=trial) - run_metadata["outcome_names"] = self.outcome_names - return run_metadata - - @property - def is_noiseless(self) -> bool: - if self.noise_stds is None: - return True - if isinstance(self.noise_stds, float): - return self.noise_stds == 0.0 - return all( - std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values() - ) - - @equality_typechecker - def __eq__(self, other: Base) -> bool: - if type(other) is not type(self): - return False - - # Don't check surrogate, datasets, or callable - return ( - (self.name == other.name) - and (self.outcome_names == other.outcome_names) - and (self.noise_stds == other.noise_stds) - # pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`. - and (self.search_space_digest == other.search_space_digest) - ) diff --git a/ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py b/ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py index 953ea7be3aa..23a2f2737c5 100644 --- a/ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py +++ b/ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py @@ -14,7 +14,6 @@ get_jenatton_benchmark_problem, jenatton_test_function, ) -from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner from ax.core.arm import Arm from ax.core.data import Data from ax.core.experiment import Experiment @@ -106,10 +105,7 @@ def test_create_problem(self) -> None: self.assertEqual(metric.name, "Jenatton") self.assertTrue(objective.minimize) self.assertTrue(metric.lower_is_better) - self.assertEqual( - assert_is_instance(problem.runner, ParamBasedTestProblemRunner).noise_std, - 0.0, - ) + self.assertEqual(problem.runner.noise_std, 0.0) self.assertFalse(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd) problem = get_jenatton_benchmark_problem( @@ -118,10 +114,7 @@ def test_create_problem(self) -> None: objective = problem.optimization_config.objective metric = objective.metric self.assertTrue(metric.lower_is_better) - self.assertEqual( - assert_is_instance(problem.runner, ParamBasedTestProblemRunner).noise_std, - 0.1, - ) + self.assertEqual(problem.runner.noise_std, 0.1) self.assertTrue(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd) def test_fetch_trial_data(self) -> None: diff --git a/ax/benchmark/tests/problems/test_mixed_integer_problems.py b/ax/benchmark/tests/problems/test_mixed_integer_problems.py index 694dfd63a66..c744c50f516 100644 --- a/ax/benchmark/tests/problems/test_mixed_integer_problems.py +++ b/ax/benchmark/tests/problems/test_mixed_integer_problems.py @@ -15,10 +15,7 @@ get_discrete_hartmann, get_discrete_rosenbrock, ) -from ax.benchmark.runners.botorch_test import ( - BoTorchTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.botorch_test import BoTorchTestProblem from ax.core.arm import Arm from ax.core.parameter import ParameterType from ax.core.trial import Trial @@ -37,7 +34,7 @@ def test_problems(self) -> None: name = problem_cls.__name__ problem = constructor() self.assertEqual(f"Discrete {name}", problem.name) - runner = assert_is_instance(problem.runner, ParamBasedTestProblemRunner) + runner = problem.runner test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem) botorch_problem = test_problem.botorch_problem self.assertIsInstance(botorch_problem, problem_cls) @@ -99,7 +96,7 @@ def test_problems(self) -> None: ] for problem, params, expected_arg in cases: - runner = assert_is_instance(problem.runner, ParamBasedTestProblemRunner) + runner = problem.runner test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem) trial = Trial(experiment=MagicMock()) # pyre-fixme: Incompatible parameter type [6]: In call diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py index 9dcba5c4872..67ed96f4095 100644 --- a/ax/benchmark/tests/runners/test_botorch_test_problem.py +++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py @@ -16,10 +16,8 @@ import torch from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem -from ax.benchmark.runners.botorch_test import ( - BoTorchTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import BoTorchTestProblem from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.core.arm import Arm from ax.core.base_trial import TrialStatus @@ -154,7 +152,7 @@ def test_synthetic_runner(self) -> None: outcome_names = ["branin"] # Set up runner - runner = ParamBasedTestProblemRunner( + runner = BenchmarkRunner( test_problem=test_problem, outcome_names=outcome_names, noise_std=noise_std, @@ -291,15 +289,15 @@ def test_synthetic_runner(self) -> None: with self.assertRaisesRegex( UnsupportedError, "serialize_init_args is not a supported method" ): - ParamBasedTestProblemRunner.serialize_init_args(obj=runner) + BenchmarkRunner.serialize_init_args(obj=runner) with self.assertRaisesRegex( UnsupportedError, "deserialize_init_args is not a supported method" ): - ParamBasedTestProblemRunner.deserialize_init_args({}) + BenchmarkRunner.deserialize_init_args({}) def test_botorch_test_problem_runner_heterogeneous_noise(self) -> None: for noise_std in [[0.1, 0.05], {"objective": 0.1, "constraint": 0.05}]: - runner = ParamBasedTestProblemRunner( + runner = BenchmarkRunner( test_problem=BoTorchTestProblem( botorch_problem=ConstrainedHartmann(dim=6) ), diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py index 099d09f2d5e..9ad256f7f90 100644 --- a/ax/benchmark/tests/runners/test_surrogate_runner.py +++ b/ax/benchmark/tests/runners/test_surrogate_runner.py @@ -8,15 +8,10 @@ from unittest.mock import MagicMock, patch import torch -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction -from ax.core.parameter import ParameterType, RangeParameter -from ax.core.search_space import SearchSpace +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.modelbridge.torch import TorchModelBridge from ax.utils.common.testutils import TestCase -from ax.utils.testing.benchmark_stubs import ( - get_soo_surrogate_legacy, - get_soo_surrogate_test_function, -) +from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function class TestSurrogateTestFunction(TestCase): @@ -84,84 +79,3 @@ def _construct_test_function(name: str) -> SurrogateTestFunction: self.assertEqual(runner_1, runner_1a) self.assertNotEqual(runner_1, runner_2) self.assertNotEqual(runner_1, 1) - - -class TestSurrogateRunner(TestCase): - def setUp(self) -> None: - super().setUp() - self.search_space = SearchSpace( - parameters=[ - RangeParameter("x", ParameterType.FLOAT, 0.0, 5.0), - RangeParameter("y", ParameterType.FLOAT, 1.0, 10.0, log_scale=True), - RangeParameter("z", ParameterType.INT, 1.0, 5.0, log_scale=True), - ] - ) - - def test_surrogate_runner(self) -> None: - # Construct a search space with log-scale parameters. - for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}): - with self.subTest(noise_std=noise_std): - surrogate = MagicMock() - mock_mean = torch.tensor([[0.1234]], dtype=torch.double) - surrogate.predict = MagicMock(return_value=(mock_mean, 0)) - surrogate.device = torch.device("cpu") - surrogate.dtype = torch.double - runner = SurrogateRunner( - name="test runner", - _surrogate=surrogate, - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=noise_std, - ) - self.assertEqual(runner.name, "test runner") - self.assertIs(runner.surrogate, surrogate) - self.assertEqual(runner.outcome_names, ["dummy_metric"]) - self.assertEqual(runner.noise_stds, noise_std) - - def test_lazy_instantiation(self) -> None: - runner = get_soo_surrogate_legacy().runner - - self.assertIsNone(runner._surrogate) - self.assertIsNone(runner._datasets) - - # Accessing `surrogate` sets datasets and surrogate - self.assertIsInstance(runner.surrogate, TorchModelBridge) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - # Accessing `datasets` also sets datasets and surrogate - runner = get_soo_surrogate_legacy().runner - self.assertIsInstance(runner.datasets, list) - self.assertIsInstance(runner._surrogate, TorchModelBridge) - self.assertIsInstance(runner._datasets, list) - - with patch.object( - runner, - "get_surrogate_and_datasets", - wraps=runner.get_surrogate_and_datasets, - ) as mock_get_surrogate_and_datasets: - runner.surrogate - mock_get_surrogate_and_datasets.assert_not_called() - - def test_instantiation_raises_with_missing_args(self) -> None: - with self.assertRaisesRegex( - ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and " - ): - SurrogateRunner(name="test runner", outcome_names=[], noise_stds=0.0) - - def test_equality(self) -> None: - def _construct_runner(name: str) -> SurrogateRunner: - return SurrogateRunner( - name=name, - _surrogate=MagicMock(), - _datasets=[], - outcome_names=["dummy_metric"], - noise_stds=0.0, - ) - - runner_1 = _construct_runner("test 1") - runner_2 = _construct_runner("test 2") - runner_1a = _construct_runner("test 1") - self.assertEqual(runner_1, runner_1a) - self.assertNotEqual(runner_1, runner_2) - self.assertNotEqual(runner_1, 1) diff --git a/ax/benchmark/tests/test_benchmark_problem.py b/ax/benchmark/tests/test_benchmark_problem.py index 886d9f85cb8..b8f948a1607 100644 --- a/ax/benchmark/tests/test_benchmark_problem.py +++ b/ax/benchmark/tests/test_benchmark_problem.py @@ -14,10 +14,8 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch -from ax.benchmark.runners.botorch_test import ( - BoTorchTestProblem, - ParamBasedTestProblemRunner, -) +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import BoTorchTestProblem from ax.core.objective import MultiObjective, Objective from ax.core.optimization_config import ( MultiObjectiveOptimizationConfig, @@ -54,7 +52,7 @@ def test_inference_value_not_implemented(self) -> None: for name in ["Branin", "Currin"] ] optimization_config = OptimizationConfig(objective=objectives[0]) - runner = ParamBasedTestProblemRunner( + runner = BenchmarkRunner( test_problem=BoTorchTestProblem(botorch_problem=Branin()), outcome_names=["foo"], ) @@ -215,7 +213,7 @@ def _test_constrained_from_botorch( observe_noise_sd=observe_noise_sd, noise_std=noise_std, ) - runner = assert_is_instance(ax_problem.runner, ParamBasedTestProblemRunner) + runner = ax_problem.runner test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem) botorch_problem = assert_is_instance( test_problem.botorch_problem, ConstrainedBaseTestProblem diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 0cdb49c32b6..a798c30b1fe 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -370,8 +370,6 @@ "SumConstraint": SumConstraint, "Surrogate": Surrogate, "SurrogateMetric": BenchmarkMetric, # backward-compatiblity - # NOTE: SurrogateRunners -> SyntheticRunner on load due to complications - "SurrogateRunner": SyntheticRunner, "SobolQMCNormalSampler": SobolQMCNormalSampler, "SyntheticRunner": SyntheticRunner, "SurrogateSpec": SurrogateSpec, diff --git a/ax/storage/tests/test_registry_bundle.py b/ax/storage/tests/test_registry_bundle.py index 4add5eb266b..1c775db8dac 100644 --- a/ax/storage/tests/test_registry_bundle.py +++ b/ax/storage/tests/test_registry_bundle.py @@ -6,7 +6,6 @@ # pyre-strict from ax.benchmark.benchmark_metric import BenchmarkMetric -from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner from ax.metrics.branin import BraninMetric from ax.runners.synthetic import SyntheticRunner from ax.storage.registry_bundle import RegistryBundle @@ -26,7 +25,7 @@ def test_from_registry_bundles(self) -> None: right = RegistryBundle( metric_clss={BenchmarkMetric: None}, - runner_clss={ParamBasedTestProblemRunner: None}, + runner_clss={SyntheticRunner: None}, json_encoder_registry={}, json_class_encoder_registry={}, json_decoder_registry={}, @@ -41,4 +40,3 @@ def test_from_registry_bundles(self) -> None: self.assertIn(BraninMetric, combined.encoder_registry) self.assertIn(SyntheticRunner, combined.encoder_registry) self.assertIn(BenchmarkMetric, combined.encoder_registry) - self.assertIn(ParamBasedTestProblemRunner, combined.encoder_registry) diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index 60244511b6b..0c462ac32eb 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -15,12 +15,9 @@ from ax.benchmark.benchmark_metric import BenchmarkMetric from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult -from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem -from ax.benchmark.runners.botorch_test import ( - ParamBasedTestProblem, - ParamBasedTestProblemRunner, -) -from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction +from ax.benchmark.runners.base import BenchmarkRunner +from ax.benchmark.runners.botorch_test import ParamBasedTestProblem +from ax.benchmark.runners.surrogate import SurrogateTestFunction from ax.core.experiment import Experiment from ax.core.objective import MultiObjective, Objective from ax.core.optimization_config import ( @@ -106,9 +103,7 @@ def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction: def get_soo_surrogate() -> BenchmarkProblem: experiment = get_branin_experiment(with_completed_trial=True) test_function = get_soo_surrogate_test_function() - runner = ParamBasedTestProblemRunner( - test_problem=test_function, outcome_names=["branin"] - ) + runner = BenchmarkRunner(test_problem=test_function, outcome_names=["branin"]) observe_noise_sd = True objective = Objective( @@ -129,41 +124,7 @@ def get_soo_surrogate() -> BenchmarkProblem: ) -def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem: - experiment = get_branin_experiment(with_completed_trial=True) - surrogate = TorchModelBridge( - experiment=experiment, - search_space=experiment.search_space, - model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), - data=experiment.lookup_data(), - transforms=[], - ) - runner = SurrogateRunner( - name="test", - outcome_names=["branin"], - get_surrogate_and_datasets=lambda: (surrogate, []), - ) - - observe_noise_sd = True - objective = Objective( - metric=BenchmarkMetric( - name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd - ), - ) - optimization_config = OptimizationConfig(objective=objective) - - return SurrogateBenchmarkProblem( - name="test", - search_space=experiment.search_space, - optimization_config=optimization_config, - num_trials=6, - observe_noise_stds=observe_noise_sd, - optimal_value=0.0, - runner=runner, - ) - - -def get_moo_surrogate() -> SurrogateBenchmarkProblem: +def get_moo_surrogate() -> BenchmarkProblem: experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True) surrogate = TorchModelBridge( experiment=experiment, @@ -173,11 +134,13 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: transforms=[], ) - runner = SurrogateRunner( + outcome_names = ["branin_a", "branin_b"] + test_function = SurrogateTestFunction( name="test", - outcome_names=["branin_a", "branin_b"], + outcome_names=outcome_names, get_surrogate_and_datasets=lambda: (surrogate, []), ) + runner = BenchmarkRunner(test_problem=test_function, outcome_names=outcome_names) observe_noise_sd = True optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective( @@ -199,7 +162,7 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem: ], ) ) - return SurrogateBenchmarkProblem( + return BenchmarkProblem( name="test", search_space=experiment.search_space, optimization_config=optimization_config,