diff --git a/ax/benchmark/runners/surrogate.py b/ax/benchmark/runners/surrogate.py
index 42c39a1bbbf..3a8bf83ba4c 100644
--- a/ax/benchmark/runners/surrogate.py
+++ b/ax/benchmark/runners/surrogate.py
@@ -11,9 +11,11 @@
 import torch
 from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
 from ax.core.base_trial import BaseTrial, TrialStatus
 from ax.core.observation import ObservationFeatures
 from ax.core.search_space import SearchSpaceDigest
+from ax.core.types import TParamValue
 from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.base import Base
 from ax.utils.common.equality import equality_typechecker
@@ -22,6 +24,84 @@
 from torch import Tensor


+@dataclass(kw_only=True)
+class SurrogateTestFunction(ParamBasedTestProblem):
+    """
+    Data-generating function for surrogate benchmark problems.
+
+    Args:
+        name: The name of the test function.
+        outcome_names: Names of outcomes to return in `evaluate_true`, if the
+            surrogate produces more outcomes than are needed.
+        _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
+            for generating observations. If `None`, `get_surrogate_and_datasets`
+            must not be None and will be used to generate the surrogate when it
+            is needed.
+        _datasets: Either `None`, or the `SupervisedDataset`s used to fit
+            the surrogate model. If `None`, `get_surrogate_and_datasets` must
+            not be None and will be used to generate the datasets when they are
+            needed.
+        get_surrogate_and_datasets: Function that returns the surrogate and
+            datasets, to allow for lazy construction. If
+            `get_surrogate_and_datasets` is not provided, `_surrogate` and
+            `_datasets` must be provided, and vice versa.
+    """
+
+    name: str
+    outcome_names: list[str]
+    _surrogate: TorchModelBridge | None = None
+    _datasets: list[SupervisedDataset] | None = None
+    get_surrogate_and_datasets: (
+        None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
+    ) = None
+
+    def __post_init__(self) -> None:
+        if self.get_surrogate_and_datasets is None and (
+            self._surrogate is None or self._datasets is None
+        ):
+            raise ValueError(
+                "If `get_surrogate_and_datasets` is None, `_surrogate` "
+                "and `_datasets` must not be None, and vice versa."
+            )
+
+    def set_surrogate_and_datasets(self) -> None:
+        self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
+
+    @property
+    def surrogate(self) -> TorchModelBridge:
+        if self._surrogate is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._surrogate)
+
+    @property
+    def datasets(self) -> list[SupervisedDataset]:
+        if self._datasets is None:
+            self.set_surrogate_and_datasets()
+        return none_throws(self._datasets)
+
+    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
+        # We're ignoring the uncertainty predictions of the surrogate model here
+        # and using the mean predictions as the outcomes (before potentially
+        # adding noise).
+        means, _ = self.surrogate.predict(
+            # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
+            observation_features=[ObservationFeatures(params)]
+        )
+        means = [means[name][0] for name in self.outcome_names]
+        return torch.tensor(
+            means,
+            device=self.surrogate.device,
+            dtype=self.surrogate.dtype,
+        )
+
+    @equality_typechecker
+    def __eq__(self, other: Base) -> bool:
+        if type(other) is not type(self):
+            return False
+
+        # Don't check surrogate, datasets, or callable
+        return self.name == other.name
+
+
 @dataclass
 class SurrogateRunner(BenchmarkRunner):
     """Runner for surrogate benchmark problems.
diff --git a/ax/benchmark/tests/runners/test_botorch_test_problem.py b/ax/benchmark/tests/runners/test_botorch_test_problem.py
index 8f89f2105c1..02d183722a8 100644
--- a/ax/benchmark/tests/runners/test_botorch_test_problem.py
+++ b/ax/benchmark/tests/runners/test_botorch_test_problem.py
@@ -7,9 +7,10 @@

 # pyre-strict

+from contextlib import nullcontext
 from dataclasses import replace
 from itertools import product
-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 import numpy as np

@@ -19,13 +20,17 @@
     BoTorchTestProblem,
     ParamBasedTestProblemRunner,
 )
+from ax.benchmark.runners.surrogate import SurrogateTestFunction
 from ax.core.arm import Arm
 from ax.core.base_trial import TrialStatus
 from ax.core.trial import Trial
 from ax.exceptions.core import UnsupportedError
 from ax.utils.common.testutils import TestCase
 from ax.utils.common.typeutils import checked_cast
-from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
+from ax.utils.testing.benchmark_stubs import (
+    get_soo_surrogate_test_function,
+    TestParamBasedTestProblem,
+)
 from botorch.test_functions.multi_objective import BraninCurrin
 from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
 from botorch.utils.transforms import normalize
@@ -130,34 +135,35 @@ def test_synthetic_runner(self) -> None:
             for num_outcomes in (1, 2)
             for noise_std in (0.0, [float(i) for i in range(num_outcomes)])
         ]
-        for test_problem, noise_std, num_outcomes in botorch_cases + param_based_cases:
-            is_constrained = isinstance(
-                test_problem, BoTorchTestProblem
-            ) and isinstance(test_problem.botorch_problem, ConstrainedHartmann)
-            num_constraints = 1 if is_constrained else 0
-            outcome_names = [
-                f"objective_{i}" for i in range(num_outcomes - num_constraints)
-            ] + ["constraint"] * num_constraints
+        surrogate_cases = [
+            (get_soo_surrogate_test_function(lazy=False), noise_std, 1)
+            for noise_std in (0.0, 1.0, [0.0], [1.0])
+        ]
+        for test_problem, noise_std, num_outcomes in (
+            botorch_cases + param_based_cases + surrogate_cases
+        ):
+            # Set up outcome names
+            if isinstance(test_problem, BoTorchTestProblem):
+                if isinstance(test_problem.botorch_problem, ConstrainedHartmann):
+                    outcome_names = ["objective_0", "constraint"]
+                else:
+                    outcome_names = ["objective_0"]
+            elif isinstance(test_problem, TestParamBasedTestProblem):
+                outcome_names = [f"objective_{i}" for i in range(num_outcomes)]
+            elif isinstance(test_problem, SurrogateTestFunction):
+                outcome_names = ["branin"]

+            # Set up runner
             runner = ParamBasedTestProblemRunner(
                 test_problem=test_problem,
                 outcome_names=outcome_names,
                 noise_std=noise_std,
             )
-            modified_bounds = (
-                test_problem.modified_bounds
-                if isinstance(test_problem, BoTorchTestProblem)
-                else None
-            )
-
-            test_description: str = (
-                f"test problem: {test_problem.__class__.__name__}, "
-                f"modified_bounds: {modified_bounds}, "
-                f"noise_std: {noise_std}."
-            )
-            is_botorch = isinstance(test_problem, BoTorchTestProblem)
-            with self.subTest(f"Test basic construction, {test_description}"):
+            test_description = f"{test_problem=}, {noise_std=}"
+            with self.subTest(
+                f"Test basic construction, {test_problem=}, {noise_std=}"
+            ):
                 self.assertIs(runner.test_problem, test_problem)
                 self.assertEqual(runner.outcome_names, outcome_names)
                 if isinstance(noise_std, list):
@@ -183,6 +189,7 @@ def test_synthetic_runner(self) -> None:
                         test_problem.botorch_problem.bounds.dtype, torch.double
                     )

+            is_botorch = isinstance(test_problem, BoTorchTestProblem)
             with self.subTest(f"test `get_Y_true()`, {test_description}"):
                 dim = 6 if is_botorch else 9
                 X = torch.rand(1, dim, dtype=torch.double)
@@ -195,7 +202,20 @@ def test_synthetic_runner(self) -> None:
                 )
                 params = dict(zip(param_names, (x.item() for x in X.unbind(-1))))

-                Y = runner.get_Y_true(params=params)
+                with (
+                    nullcontext()
+                    if not isinstance(test_problem, SurrogateTestFunction)
+                    else patch.object(
+                        # pyre-fixme: `ParamBasedTestProblem` has no attribute
+                        #  `_surrogate`.
+                        runner.test_problem._surrogate,
+                        "predict",
+                        return_value=({"branin": [4.2]}, None),
+                    )
+                ):
+                    Y = runner.get_Y_true(params=params)
+                    oracle = runner.evaluate_oracle(parameters=params)
+
                 if (
                     isinstance(test_problem, BoTorchTestProblem)
                     and test_problem.modified_bounds is not None
@@ -221,12 +241,13 @@ def test_synthetic_runner(self) -> None:
                         )
                     else:
                         expected_Y = obj
+                elif isinstance(test_problem, SurrogateTestFunction):
+                    expected_Y = torch.tensor([4.2], dtype=torch.double)
                 else:
                     expected_Y = torch.full(
                         torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
                     )
                 self.assertTrue(torch.allclose(Y, expected_Y))
-                oracle = runner.evaluate_oracle(parameters=params)
                 self.assertTrue(np.equal(Y.numpy(), oracle).all())

             with self.subTest(f"test `run()`, {test_description}"):
@@ -237,7 +258,19 @@ def test_synthetic_runner(self) -> None:
                 trial.arms = [arm]
                 trial.arm = arm
                 trial.index = 0
-                res = runner.run(trial=trial)
+
+                with (
+                    nullcontext()
+                    if not isinstance(test_problem, SurrogateTestFunction)
+                    else patch.object(
+                        # pyre-fixme: `ParamBasedTestProblem` has no attribute
+                        #  `_surrogate`.
+                        runner.test_problem._surrogate,
+                        "predict",
+                        return_value=({"branin": [4.2]}, None),
+                    )
+                ):
+                    res = runner.run(trial=trial)

                 self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys())
                 self.assertEqual({"0_0"}, res["Ys"].keys())
diff --git a/ax/benchmark/tests/runners/test_surrogate_runner.py b/ax/benchmark/tests/runners/test_surrogate_runner.py
index e0046c32433..099d09f2d5e 100644
--- a/ax/benchmark/tests/runners/test_surrogate_runner.py
+++ b/ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -8,12 +8,82 @@
 from unittest.mock import MagicMock, patch

 import torch
-from ax.benchmark.runners.surrogate import SurrogateRunner
+from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
 from ax.core.parameter import ParameterType, RangeParameter
 from ax.core.search_space import SearchSpace
 from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.testutils import TestCase
-from ax.utils.testing.benchmark_stubs import get_soo_surrogate
+from ax.utils.testing.benchmark_stubs import (
+    get_soo_surrogate_legacy,
+    get_soo_surrogate_test_function,
+)
+
+
+class TestSurrogateTestFunction(TestCase):
+    def test_surrogate_test_function(self) -> None:
+        # Construct a test function around a mocked surrogate.
+        for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
+            with self.subTest(noise_std=noise_std):
+                surrogate = MagicMock()
+                mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
+                surrogate.predict = MagicMock(return_value=(mock_mean, 0))
+                surrogate.device = torch.device("cpu")
+                surrogate.dtype = torch.double
+                test_function = SurrogateTestFunction(
+                    name="test test function",
+                    outcome_names=["dummy metric"],
+                    _surrogate=surrogate,
+                    _datasets=[],
+                )
+                self.assertEqual(test_function.name, "test test function")
+                self.assertIs(test_function.surrogate, surrogate)
+
+    def test_lazy_instantiation(self) -> None:
+        test_function = get_soo_surrogate_test_function()
+
+        self.assertIsNone(test_function._surrogate)
+        self.assertIsNone(test_function._datasets)
+
+        # Accessing `surrogate` sets datasets and surrogate
+        self.assertIsInstance(test_function.surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._datasets, list)
+
+        # Accessing `datasets` also sets datasets and surrogate
+        test_function = get_soo_surrogate_test_function()
+        self.assertIsInstance(test_function.datasets, list)
+        self.assertIsInstance(test_function._surrogate, TorchModelBridge)
+        self.assertIsInstance(test_function._datasets, list)
+
+        # Accessing `surrogate` again should not rebuild it.
+        with patch.object(
+            test_function,
+            "get_surrogate_and_datasets",
+            wraps=test_function.get_surrogate_and_datasets,
+        ) as mock_get_surrogate_and_datasets:
+            test_function.surrogate
+            mock_get_surrogate_and_datasets.assert_not_called()
+
+    def test_instantiation_raises_with_missing_args(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
+        ):
+            SurrogateTestFunction(name="test runner", outcome_names=[])
+
+    def test_equality(self) -> None:
+        def _construct_test_function(name: str) -> SurrogateTestFunction:
+            return SurrogateTestFunction(
+                name=name,
+                _surrogate=MagicMock(),
+                _datasets=[],
+                outcome_names=["dummy_metric"],
+            )
+
+        runner_1 = _construct_test_function("test 1")
+        runner_2 = _construct_test_function("test 2")
+        runner_1a = _construct_test_function("test 1")
+        self.assertEqual(runner_1, runner_1a)
+        self.assertNotEqual(runner_1, runner_2)
+        self.assertNotEqual(runner_1, 1)


 class TestSurrogateRunner(TestCase):
@@ -49,7 +119,7 @@ def test_surrogate_runner(self) -> None:
         self.assertEqual(runner.noise_stds, noise_std)

     def test_lazy_instantiation(self) -> None:
-        runner = get_soo_surrogate().runner
+        runner = get_soo_surrogate_legacy().runner

         self.assertIsNone(runner._surrogate)
         self.assertIsNone(runner._datasets)
@@ -60,7 +130,7 @@ def test_lazy_instantiation(self) -> None:
         self.assertIsInstance(runner._datasets, list)

         # Accessing `datasets` also sets datasets and surrogate
-        runner = get_soo_surrogate().runner
+        runner = get_soo_surrogate_legacy().runner
         self.assertIsInstance(runner.datasets, list)
         self.assertIsInstance(runner._surrogate, TorchModelBridge)
         self.assertIsInstance(runner._datasets, list)
diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py
index 0c1a37c5a47..60244511b6b 100644
--- a/ax/utils/testing/benchmark_stubs.py
+++ b/ax/utils/testing/benchmark_stubs.py
@@ -16,8 +16,11 @@
 from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
 from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
-from ax.benchmark.runners.surrogate import SurrogateRunner
+from ax.benchmark.runners.botorch_test import (
+    ParamBasedTestProblem,
+    ParamBasedTestProblemRunner,
+)
+from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
 from ax.core.experiment import Experiment
 from ax.core.objective import MultiObjective, Objective
 from ax.core.optimization_config import (
@@ -75,7 +78,58 @@ def get_multi_objective_benchmark_problem(
     )


-def get_soo_surrogate() -> SurrogateBenchmarkProblem:
+def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction:
+    experiment = get_branin_experiment(with_completed_trial=True)
+    surrogate = TorchModelBridge(
+        experiment=experiment,
+        search_space=experiment.search_space,
+        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
+        data=experiment.lookup_data(),
+        transforms=[],
+    )
+    if lazy:
+        test_function = SurrogateTestFunction(
+            outcome_names=["branin"],
+            name="test",
+            get_surrogate_and_datasets=lambda: (surrogate, []),
+        )
+    else:
+        test_function = SurrogateTestFunction(
+            outcome_names=["branin"],
+            name="test",
+            _surrogate=surrogate,
+            _datasets=[],
+        )
+    return test_function
+
+
+def get_soo_surrogate() -> BenchmarkProblem:
+    experiment = get_branin_experiment(with_completed_trial=True)
+    test_function = get_soo_surrogate_test_function()
+    runner = ParamBasedTestProblemRunner(
+        test_problem=test_function, outcome_names=["branin"]
+    )
+
+    observe_noise_sd = True
+    objective = Objective(
+        metric=BenchmarkMetric(
+            name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd
+        ),
+    )
+    optimization_config = OptimizationConfig(objective=objective)
+
+    return BenchmarkProblem(
+        name="test",
+        search_space=experiment.search_space,
+        optimization_config=optimization_config,
+        num_trials=6,
+        observe_noise_stds=observe_noise_sd,
+        optimal_value=0.0,
+        runner=runner,
+    )
+
+
+def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem:
     experiment = get_branin_experiment(with_completed_trial=True)
     surrogate = TorchModelBridge(
         experiment=experiment,
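
A minimal usage sketch assembled from the stubs in this diff: fit a surrogate to a Branin experiment, wrap it in a lazily constructed SurrogateTestFunction, and hand it to ParamBasedTestProblemRunner, mirroring get_soo_surrogate_test_function / get_soo_surrogate above. Import paths for symbols not visible in the hunks shown here (BoTorchModel, Surrogate, SingleTaskGP, get_branin_experiment) are assumptions based on the surrounding codebase.

    # Sketch only; mirrors the stubs added in this diff. Import paths for symbols
    # not shown in the hunks above are assumed.
    from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
    from ax.benchmark.runners.surrogate import SurrogateTestFunction
    from ax.modelbridge.torch import TorchModelBridge
    from ax.models.torch.botorch_modular.model import BoTorchModel
    from ax.models.torch.botorch_modular.surrogate import Surrogate
    from ax.utils.testing.core_stubs import get_branin_experiment
    from botorch.models.gp_regression import SingleTaskGP

    # Fit a GP-backed model bridge to a completed Branin experiment; it serves
    # as the data-generating surrogate.
    experiment = get_branin_experiment(with_completed_trial=True)
    surrogate = TorchModelBridge(
        experiment=experiment,
        search_space=experiment.search_space,
        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
        data=experiment.lookup_data(),
        transforms=[],
    )

    # Lazy construction: the surrogate is only attached when first accessed.
    test_function = SurrogateTestFunction(
        name="branin_surrogate",
        outcome_names=["branin"],
        get_surrogate_and_datasets=lambda: (surrogate, []),
    )
    runner = ParamBasedTestProblemRunner(
        test_problem=test_function, outcome_names=["branin"]
    )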