Skip to content

Commit

Permalink
Introduce SurrogateTestFunction
Browse files Browse the repository at this point in the history
Summary:
**Context**: This next diff will cut over uses of `SurrogateRunner` to use `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`, and the following diff after that will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Currently, SurrogateRunner had its own logic for tracking when trials are completed, which would make it difficult to work in with asynchronicity.

**Note on naming**: Some names have become non-intuitive in the process of benchmarking. To accord with some future changes I hope to make, I called a new class SurrogateTestFunction, whereas SurrogateParamBasedTestProblem would be more in line with the old naming.

The name changes I hope to make:
* ParamBasedTestProblemRunner -> nothing, absorbed into BenchmarkRunner
* ParamBasedTestProblem -> TestFunction, to emphasize that all it does is generate data (rather than more generally specify the problem we are solving) and that it is deterministic, and to differentiate it from BenchmarkProblem. BenchmarkTestFunction would also be a candidate.
* BoTorchTestProblem -> BoTorchTestFunction

**Changes in this diff**:
* Introduces SurrogateTestFunction, a ParamBasedTestProblem for surrogates, giving it the surrogate-related logic from SurrogateRunner

Differential Revision: D64899032
  • Loading branch information
esantorella authored and facebook-github-bot committed Oct 24, 2024
1 parent 10ea5ed commit a35c3e0
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 14 deletions.
92 changes: 92 additions & 0 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.observation import ObservationFeatures
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
Expand All @@ -22,6 +24,96 @@
from torch import Tensor


@dataclass(kw_only=True)
class SurrogateTestFunction(ParamBasedTestProblem):
"""
Data-generating function for surrogate benchmark problems.
Args:
num_objectives: The number of objectives. Defaults to 1.
name: The name of the runner.
outcome_names: Names of outcomes to return in `evaluate_true`, if the
surrogate produces more outcomes than are needed. If None, all
outcomes are returned.
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
for generating observations. If `None`, `get_surrogate_and_datasets`
must not be None and will be used to generate the surrogate when it
is needed.
_datasets: Either `None`, or the `SupervisedDataset`s used to fit
the surrogate model. If `None`, `get_surrogate_and_datasets` must
not be None and will be used to generate the datasets when they are
needed.
get_surrogate_and_datasets: Function that returns the surrogate and
datasets, to allow for lazy construction. If
`get_surrogate_and_datasets` is not provided, `surrogate` and
`datasets` must be provided, and vice versa.
"""

name: str
num_objectives: int = 1
outcome_names: list[str] | None = None
_surrogate: TorchModelBridge | None = None
_datasets: list[SupervisedDataset] | None = None
get_surrogate_and_datasets: (
None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
) = None

def __post_init__(self) -> None:
if self.get_surrogate_and_datasets is None and (
self._surrogate is None or self._datasets is None
):
raise ValueError(
"If `get_surrogate_and_datasets` is None, `_surrogate` "
"and `_datasets` must not be None, and vice versa."
)
if (
self.outcome_names is not None
and len(self.outcome_names) != self.num_objectives
):
raise ValueError(
f"Number of outcome names ({len(self.outcome_names)}) must match "
f"number of objectives ({self.num_objectives})."
)

def set_surrogate_and_datasets(self) -> None:
self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()

@property
def surrogate(self) -> TorchModelBridge:
if self._surrogate is None:
self.set_surrogate_and_datasets()
return none_throws(self._surrogate)

@property
def datasets(self) -> list[SupervisedDataset]:
if self._datasets is None:
self.set_surrogate_and_datasets()
return none_throws(self._datasets)

def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
# We're ignoring the uncertainty predictions of the surrogate model here and
# use the mean predictions as the outcomes (before potentially adding noise)
means, _ = self.surrogate.predict(
# pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
observation_features=[ObservationFeatures(params)]
)
names = list(means.keys()) if self.outcome_names is None else self.outcome_names
means = [means[name][0] for name in names]
return torch.tensor(
means,
device=self.surrogate.device,
dtype=self.surrogate.dtype,
)

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if type(other) is not type(self):
return False

# Don't check surrogate, datasets, or callable
return self.name == other.name


@dataclass
class SurrogateRunner(BenchmarkRunner):
"""Runner for surrogate benchmark problems.
Expand Down
56 changes: 49 additions & 7 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
# pyre-strict


from contextlib import nullcontext
from dataclasses import replace
from itertools import product
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np

Expand All @@ -19,13 +20,17 @@
BoTorchTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.surrogate import SurrogateTestFunction
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.trial import Trial
from ax.exceptions.core import UnsupportedError
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import checked_cast
from ax.utils.testing.benchmark_stubs import TestParamBasedTestProblem
from ax.utils.testing.benchmark_stubs import (
get_soo_surrogate_test_function,
TestParamBasedTestProblem,
)
from botorch.test_functions.synthetic import Ackley, ConstrainedHartmann, Hartmann
from botorch.utils.transforms import normalize

Expand Down Expand Up @@ -87,7 +92,13 @@ def test_synthetic_runner(self) -> None:
)
for num_objectives, noise_std in product((1, 2), (None, 0.0, 1.0))
]
for test_problem, noise_std in botorch_cases + param_based_cases:
surrogate_cases = [
(get_soo_surrogate_test_function(lazy=False), noise_std)
for noise_std in (None, 0.0, 1.0)
]
for test_problem, noise_std in (
botorch_cases + param_based_cases + surrogate_cases
):
num_objectives = test_problem.num_objectives

outcome_names = [f"objective_{i}" for i in range(num_objectives)]
Expand Down Expand Up @@ -148,7 +159,20 @@ def test_synthetic_runner(self) -> None:
)
params = dict(zip(param_names, (x.item() for x in X.unbind(-1))))

Y = runner.get_Y_true(params=params)
with (
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"objective_0": [4.2]}, None),
)
):
Y = runner.get_Y_true(params=params)
oracle = runner.evaluate_oracle(parameters=params)

if (
isinstance(test_problem, BoTorchTestProblem)
and test_problem.modified_bounds is not None
Expand Down Expand Up @@ -176,12 +200,13 @@ def test_synthetic_runner(self) -> None:
)
else:
expected_Y = obj
elif isinstance(test_problem, SurrogateTestFunction):
expected_Y = torch.tensor([4.2], dtype=torch.double)
else:
expected_Y = torch.full(
torch.Size([2]), X.pow(2).sum().item(), dtype=torch.double
)
self.assertTrue(torch.allclose(Y, expected_Y))
oracle = runner.evaluate_oracle(parameters=params)
self.assertTrue(np.equal(Y.numpy(), oracle).all())

with self.subTest(f"test `run()`, {test_description}"):
Expand All @@ -192,11 +217,28 @@ def test_synthetic_runner(self) -> None:
trial.arms = [arm]
trial.arm = arm
trial.index = 0
res = runner.run(trial=trial)

with (
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"objective_0": [4.2]}, None),
)
):
res = runner.run(trial=trial)
self.assertEqual({"Ys", "Ystds", "outcome_names"}, res.keys())
self.assertEqual({"0_0"}, res["Ys"].keys())
if noise_std is not None:
if isinstance(noise_std, float):
self.assertEqual(res["Ystds"]["0_0"], [noise_std] * len(Y))
elif isinstance(noise_std, dict):
self.assertEqual(
res["Ystds"]["0_0"],
[noise_std[k] for k in runner.outcome_names],
)
else:
self.assertEqual(res["Ys"]["0_0"], Y.tolist())
self.assertEqual(res["Ystds"]["0_0"], [0.0] * len(Y))
Expand Down
79 changes: 75 additions & 4 deletions ax/benchmark/tests/runners/test_surrogate_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,83 @@
from unittest.mock import MagicMock, patch

import torch
from ax.benchmark.runners.surrogate import SurrogateRunner
from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_soo_surrogate
from ax.utils.testing.benchmark_stubs import (
get_soo_surrogate_legacy,
get_soo_surrogate_test_function,
)


class TestSurrogateTestFunction(TestCase):
def test_surrogate_test_function(self) -> None:
# Construct a search space with log-scale parameters.
for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
with self.subTest(noise_std=noise_std):
surrogate = MagicMock()
mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
surrogate.predict = MagicMock(return_value=(mock_mean, 0))
surrogate.device = torch.device("cpu")
surrogate.dtype = torch.double
test_function = SurrogateTestFunction(
name="test test function",
num_objectives=1,
_surrogate=surrogate,
_datasets=[],
)
self.assertEqual(test_function.name, "test test function")
self.assertIs(test_function.surrogate, surrogate)
self.assertEqual(test_function.num_objectives, 1)

def test_lazy_instantiation(self) -> None:
test_function = get_soo_surrogate_test_function()

self.assertIsNone(test_function._surrogate)
self.assertIsNone(test_function._datasets)

# Accessing `surrogate` sets datasets and surrogate
self.assertIsInstance(test_function.surrogate, TorchModelBridge)
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
self.assertIsInstance(test_function._datasets, list)

# Accessing `datasets` also sets datasets and surrogate
test_function = get_soo_surrogate_test_function()
self.assertIsInstance(test_function.datasets, list)
self.assertIsInstance(test_function._surrogate, TorchModelBridge)
self.assertIsInstance(test_function._datasets, list)

with patch.object(
test_function,
"get_surrogate_and_datasets",
wraps=test_function.get_surrogate_and_datasets,
) as mock_get_surrogate_and_datasets:
test_function.surrogate
mock_get_surrogate_and_datasets.assert_not_called()

def test_instantiation_raises_with_missing_args(self) -> None:
with self.assertRaisesRegex(
ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
):
SurrogateTestFunction(name="test runner", num_objectives=1)

def test_equality(self) -> None:
def _construct_test_function(name: str) -> SurrogateTestFunction:
return SurrogateTestFunction(
name=name,
_surrogate=MagicMock(),
_datasets=[],
num_objectives=1,
)

runner_1 = _construct_test_function("test 1")
runner_2 = _construct_test_function("test 2")
runner_1a = _construct_test_function("test 1")
self.assertEqual(runner_1, runner_1a)
self.assertNotEqual(runner_1, runner_2)
self.assertNotEqual(runner_1, 1)


class TestSurrogateRunner(TestCase):
Expand Down Expand Up @@ -49,7 +120,7 @@ def test_surrogate_runner(self) -> None:
self.assertEqual(runner.noise_stds, noise_std)

def test_lazy_instantiation(self) -> None:
runner = get_soo_surrogate().runner
runner = get_soo_surrogate_legacy().runner

self.assertIsNone(runner._surrogate)
self.assertIsNone(runner._datasets)
Expand All @@ -60,7 +131,7 @@ def test_lazy_instantiation(self) -> None:
self.assertIsInstance(runner._datasets, list)

# Accessing `datasets` also sets datasets and surrogate
runner = get_soo_surrogate().runner
runner = get_soo_surrogate_legacy().runner
self.assertIsInstance(runner.datasets, list)
self.assertIsInstance(runner._surrogate, TorchModelBridge)
self.assertIsInstance(runner._datasets, list)
Expand Down
Loading

0 comments on commit a35c3e0

Please sign in to comment.