Get rid of SurrogateRunner and SurrogateBenchmarkProblem (facebook#2954)
Summary:

**Context**: This diff cuts over uses of `SurrogateRunner` to `ParamBasedTestProblemRunner` with a `test_problem` that is the newly introduced `SurrogateTestFunction`. The following diff will bring us down to only one runner class for benchmarking by merging `ParamBasedTestProblemRunner` into `BenchmarkRunner`. Having only one runner will make it easier to enable asynchronous benchmarks. Previously, `SurrogateRunner` had its own logic for tracking when trials complete, which would have been difficult to reconcile with asynchronicity.
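To illustrate the cutover, here is a minimal sketch of the old and new construction, using the names and signatures that appear in the `benchmark_stubs.py` changes below (`surrogate` is assumed to be an already-built `TorchModelBridge`):

```python
# Before: a dedicated runner class owned the surrogate logic.
runner = SurrogateRunner(
    name="test",
    outcome_names=["branin"],
    get_surrogate_and_datasets=lambda: (surrogate, []),
)

# After: the surrogate logic lives in a test function,
# and a generic runner wraps it.
test_function = SurrogateTestFunction(
    name="test",
    outcome_names=["branin"],
    get_surrogate_and_datasets=lambda: (surrogate, []),
)
runner = ParamBasedTestProblemRunner(
    test_problem=test_function, outcome_names=["branin"]
)
```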

**Note on naming**: Some names have become non-intuitive in the course of refactoring the benchmarking code. To accord with future changes I hope to make, I called the new class `SurrogateTestFunction`, whereas `SurrogateParamBasedTestProblem` would be more in line with the old naming.

The name changes I hope to make:
* `ParamBasedTestProblemRunner` -> nothing; it will be absorbed into `BenchmarkRunner`
* `ParamBasedTestProblem` -> `TestFunction`, to emphasize that all it does is generate data (rather than specify, more generally, the problem we are solving) and that it is deterministic, and to differentiate it from `BenchmarkProblem`. `BenchmarkTestFunction` would also be a candidate.
* `BoTorchTestProblem` -> `BoTorchTestFunction`

**Changes in this diff**:
* Introduces `SurrogateTestFunction`, a `ParamBasedTestProblem` for surrogates (see the sketch after this list)
* Removes `SurrogateRunner`. There is some loss of (unused) functionality here: dict-valued `noise_std` is no longer supported. I expect we will eventually want that, but it doesn't need to be done now.
* Changes the `predict_for_tensor` utility function to use a `SurrogateTestFunction` rather than a `SurrogateRunner`
* Removes `SurrogateRunner` from the `json_store` registry. There is no point in keeping it around for backward compatibility, since it was just deserializing to a dummy `SyntheticRunner` anyway.
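The full definition of `SurrogateTestFunction` is not shown in this diff. As a rough sketch, its core logic mirrors what the deleted `SurrogateRunner` (below) did, minus noise handling and trial-status tracking, which now belong to the runner. The `ParamBasedTestProblem` base class, the `kw_only` dataclass, and the `evaluate_true` method name are assumptions, not confirmed by this diff:

```python
from collections.abc import Callable, Mapping
from dataclasses import dataclass

import torch
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.observation import ObservationFeatures
from ax.modelbridge.torch import TorchModelBridge
from botorch.utils.datasets import SupervisedDataset
from pyre_extensions import none_throws
from torch import Tensor


@dataclass(kw_only=True)
class SurrogateTestFunction(ParamBasedTestProblem):
    """Generate deterministic outcomes from a surrogate's mean predictions.

    A sketch, not the actual definition from this commit.
    """

    name: str
    outcome_names: list[str]
    _surrogate: TorchModelBridge | None = None
    _datasets: list[SupervisedDataset] | None = None
    get_surrogate_and_datasets: (
        None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
    ) = None

    @property
    def surrogate(self) -> TorchModelBridge:
        # Lazily construct the surrogate, as the deleted SurrogateRunner did.
        if self._surrogate is None:
            self._surrogate, self._datasets = none_throws(
                self.get_surrogate_and_datasets
            )()
        return none_throws(self._surrogate)

    def evaluate_true(self, params: Mapping[str, float | int]) -> Tensor:
        # Use only the mean predictions; noise, if any, is added by the runner.
        means, _ = self.surrogate.predict(
            observation_features=[ObservationFeatures(parameters=dict(params))]
        )
        means = [means[name][0] for name in self.outcome_names]
        return torch.tensor(
            means, device=self.surrogate.device, dtype=self.surrogate.dtype
        )
```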

Differential Revision: D64855010
esantorella authored and facebook-github-bot committed Oct 25, 2024
1 parent fdfcb10 commit 027cfba
Showing 5 changed files with 11 additions and 297 deletions.
33 changes: 0 additions & 33 deletions ax/benchmark/problems/surrogate.py

This file was deleted.

134 changes: 0 additions & 134 deletions ax/benchmark/runners/surrogate.py
@@ -100,137 +100,3 @@ def __eq__(self, other: Base) -> bool:

         # Don't check surrogate, datasets, or callable
         return self.name == other.name
-
-
-@dataclass
-class SurrogateRunner(BenchmarkRunner):
-    """Runner for surrogate benchmark problems.
-
-    Args:
-        name: The name of the runner.
-        outcome_names: The names of the outcomes of the Surrogate.
-        _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
-            for generating observations. If `None`, `get_surrogate_and_datasets`
-            must not be None and will be used to generate the surrogate when it
-            is needed.
-        _datasets: Either `None`, or the `SupervisedDataset`s used to fit
-            the surrogate model. If `None`, `get_surrogate_and_datasets` must
-            not be None and will be used to generate the datasets when they are
-            needed.
-        noise_stds: Noise standard deviations to add to the surrogate output(s).
-            If a single float is provided, noise with that standard deviation
-            is added to all outputs. Alternatively, a dictionary mapping outcome
-            names to noise standard deviations can be provided to specify
-            different noise levels for different outputs.
-        get_surrogate_and_datasets: Function that returns the surrogate and
-            datasets, to allow for lazy construction. If
-            `get_surrogate_and_datasets` is not provided, `surrogate` and
-            `datasets` must be provided, and vice versa.
-        search_space_digest: Used to get the target task and fidelity at
-            which the oracle is evaluated.
-    """
-
-    name: str
-    _surrogate: TorchModelBridge | None = None
-    _datasets: list[SupervisedDataset] | None = None
-    noise_stds: float | dict[str, float] = 0.0
-    get_surrogate_and_datasets: (
-        None | Callable[[], tuple[TorchModelBridge, list[SupervisedDataset]]]
-    ) = None
-    statuses: dict[int, TrialStatus] = field(default_factory=dict)
-
-    def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
-        super().__post_init__(search_space_digest=search_space_digest)
-        if self.get_surrogate_and_datasets is None and (
-            self._surrogate is None or self._datasets is None
-        ):
-            raise ValueError(
-                "If `get_surrogate_and_datasets` is None, `_surrogate` "
-                "and `_datasets` must not be None, and vice versa."
-            )
-
-    def set_surrogate_and_datasets(self) -> None:
-        self._surrogate, self._datasets = none_throws(self.get_surrogate_and_datasets)()
-
-    @property
-    def surrogate(self) -> TorchModelBridge:
-        if self._surrogate is None:
-            self.set_surrogate_and_datasets()
-        return none_throws(self._surrogate)
-
-    @property
-    def datasets(self) -> list[SupervisedDataset]:
-        if self._datasets is None:
-            self.set_surrogate_and_datasets()
-        return none_throws(self._datasets)
-
-    def get_noise_stds(self) -> dict[str, float]:
-        noise_std = self.noise_stds
-        if isinstance(noise_std, float):
-            return {name: noise_std for name in self.outcome_names}
-        return noise_std
-
-    # pyre-fixme[14]: Inconsistent override
-    def get_Y_true(self, params: Mapping[str, float | int]) -> Tensor:
-        # We're ignoring the uncertainty predictions of the surrogate model here and
-        # use the mean predictions as the outcomes (before potentially adding noise)
-        means, _ = self.surrogate.predict(
-            # pyre-fixme[6]: params is a Mapping, but ObservationFeatures expects a Dict
-            observation_features=[ObservationFeatures(params)]
-        )
-        means = [means[name][0] for name in self.outcome_names]
-        return torch.tensor(
-            means,
-            device=self.surrogate.device,
-            dtype=self.surrogate.dtype,
-        )
-
-    def run(self, trial: BaseTrial) -> dict[str, Any]:
-        """Run the trial by evaluating its parameterization(s) on the surrogate model.
-
-        Note: This also sets the status of the trial to COMPLETED.
-
-        Args:
-            trial: The trial to evaluate.
-
-        Returns:
-            A dictionary with the following keys:
-                - outcome_names: The names of the metrics being evaluated.
-                - Ys: A dict mapping arm names to lists of corresponding outcomes,
-                  where the order of the outcomes is the same as in `outcome_names`.
-                - Ystds: A dict mapping arm names to lists of corresponding outcome
-                  noise standard deviations (possibly nan if the noise level is
-                  unobserved), where the order of the outcomes is the same as in
-                  `outcome_names`.
-                - Ys_true: A dict mapping arm names to lists of corresponding ground
-                  truth outcomes, where the order of the outcomes is the same as
-                  in `outcome_names`.
-        """
-        self.statuses[trial.index] = TrialStatus.COMPLETED
-        run_metadata = super().run(trial=trial)
-        run_metadata["outcome_names"] = self.outcome_names
-        return run_metadata
-
-    @property
-    def is_noiseless(self) -> bool:
-        if self.noise_stds is None:
-            return True
-        if isinstance(self.noise_stds, float):
-            return self.noise_stds == 0.0
-        return all(
-            std == 0.0 for std in assert_is_instance(self.noise_stds, dict).values()
-        )
-
-    @equality_typechecker
-    def __eq__(self, other: Base) -> bool:
-        if type(other) is not type(self):
-            return False
-
-        # Don't check surrogate, datasets, or callable
-        return (
-            (self.name == other.name)
-            and (self.outcome_names == other.outcome_names)
-            and (self.noise_stds == other.noise_stds)
-            # pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`.
-            and (self.search_space_digest == other.search_space_digest)
-        )
90 changes: 2 additions & 88 deletions ax/benchmark/tests/runners/test_surrogate_runner.py
@@ -8,15 +8,10 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
-from ax.core.parameter import ParameterType, RangeParameter
-from ax.core.search_space import SearchSpace
+from ax.benchmark.runners.surrogate import SurrogateTestFunction
 from ax.modelbridge.torch import TorchModelBridge
 from ax.utils.common.testutils import TestCase
-from ax.utils.testing.benchmark_stubs import (
-    get_soo_surrogate_legacy,
-    get_soo_surrogate_test_function,
-)
+from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function
 
 
 class TestSurrogateTestFunction(TestCase):
Expand Down Expand Up @@ -84,84 +79,3 @@ def _construct_test_function(name: str) -> SurrogateTestFunction:
         self.assertEqual(runner_1, runner_1a)
         self.assertNotEqual(runner_1, runner_2)
         self.assertNotEqual(runner_1, 1)
-
-
-class TestSurrogateRunner(TestCase):
-    def setUp(self) -> None:
-        super().setUp()
-        self.search_space = SearchSpace(
-            parameters=[
-                RangeParameter("x", ParameterType.FLOAT, 0.0, 5.0),
-                RangeParameter("y", ParameterType.FLOAT, 1.0, 10.0, log_scale=True),
-                RangeParameter("z", ParameterType.INT, 1.0, 5.0, log_scale=True),
-            ]
-        )
-
-    def test_surrogate_runner(self) -> None:
-        # Construct a search space with log-scale parameters.
-        for noise_std in (0.0, 0.1, {"dummy_metric": 0.2}):
-            with self.subTest(noise_std=noise_std):
-                surrogate = MagicMock()
-                mock_mean = torch.tensor([[0.1234]], dtype=torch.double)
-                surrogate.predict = MagicMock(return_value=(mock_mean, 0))
-                surrogate.device = torch.device("cpu")
-                surrogate.dtype = torch.double
-                runner = SurrogateRunner(
-                    name="test runner",
-                    _surrogate=surrogate,
-                    _datasets=[],
-                    outcome_names=["dummy_metric"],
-                    noise_stds=noise_std,
-                )
-                self.assertEqual(runner.name, "test runner")
-                self.assertIs(runner.surrogate, surrogate)
-                self.assertEqual(runner.outcome_names, ["dummy_metric"])
-                self.assertEqual(runner.noise_stds, noise_std)
-
-    def test_lazy_instantiation(self) -> None:
-        runner = get_soo_surrogate_legacy().runner
-
-        self.assertIsNone(runner._surrogate)
-        self.assertIsNone(runner._datasets)
-
-        # Accessing `surrogate` sets datasets and surrogate
-        self.assertIsInstance(runner.surrogate, TorchModelBridge)
-        self.assertIsInstance(runner._surrogate, TorchModelBridge)
-        self.assertIsInstance(runner._datasets, list)
-
-        # Accessing `datasets` also sets datasets and surrogate
-        runner = get_soo_surrogate_legacy().runner
-        self.assertIsInstance(runner.datasets, list)
-        self.assertIsInstance(runner._surrogate, TorchModelBridge)
-        self.assertIsInstance(runner._datasets, list)
-
-        with patch.object(
-            runner,
-            "get_surrogate_and_datasets",
-            wraps=runner.get_surrogate_and_datasets,
-        ) as mock_get_surrogate_and_datasets:
-            runner.surrogate
-            mock_get_surrogate_and_datasets.assert_not_called()
-
-    def test_instantiation_raises_with_missing_args(self) -> None:
-        with self.assertRaisesRegex(
-            ValueError, "If `get_surrogate_and_datasets` is None, `_surrogate` and "
-        ):
-            SurrogateRunner(name="test runner", outcome_names=[], noise_stds=0.0)
-
-    def test_equality(self) -> None:
-        def _construct_runner(name: str) -> SurrogateRunner:
-            return SurrogateRunner(
-                name=name,
-                _surrogate=MagicMock(),
-                _datasets=[],
-                outcome_names=["dummy_metric"],
-                noise_stds=0.0,
-            )
-
-        runner_1 = _construct_runner("test 1")
-        runner_2 = _construct_runner("test 2")
-        runner_1a = _construct_runner("test 1")
-        self.assertEqual(runner_1, runner_1a)
-        self.assertNotEqual(runner_1, runner_2)
-        self.assertNotEqual(runner_1, 1)
2 changes: 0 additions & 2 deletions ax/storage/json_store/registry.py
@@ -370,8 +370,6 @@
"SumConstraint": SumConstraint,
"Surrogate": Surrogate,
"SurrogateMetric": BenchmarkMetric, # backward-compatiblity
# NOTE: SurrogateRunners -> SyntheticRunner on load due to complications
"SurrogateRunner": SyntheticRunner,
"SobolQMCNormalSampler": SobolQMCNormalSampler,
"SyntheticRunner": SyntheticRunner,
"SurrogateSpec": SurrogateSpec,
49 changes: 9 additions & 40 deletions ax/utils/testing/benchmark_stubs.py
@@ -15,12 +15,11 @@
 from ax.benchmark.benchmark_metric import BenchmarkMetric
 from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
 from ax.benchmark.benchmark_result import AggregatedBenchmarkResult, BenchmarkResult
-from ax.benchmark.problems.surrogate import SurrogateBenchmarkProblem
 from ax.benchmark.runners.botorch_test import (
     ParamBasedTestProblem,
     ParamBasedTestProblemRunner,
 )
-from ax.benchmark.runners.surrogate import SurrogateRunner, SurrogateTestFunction
+from ax.benchmark.runners.surrogate import SurrogateTestFunction
 from ax.core.experiment import Experiment
 from ax.core.objective import MultiObjective, Objective
 from ax.core.optimization_config import (
@@ -129,41 +128,7 @@ def get_soo_surrogate() -> BenchmarkProblem:
     )
 
 
-def get_soo_surrogate_legacy() -> SurrogateBenchmarkProblem:
-    experiment = get_branin_experiment(with_completed_trial=True)
-    surrogate = TorchModelBridge(
-        experiment=experiment,
-        search_space=experiment.search_space,
-        model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)),
-        data=experiment.lookup_data(),
-        transforms=[],
-    )
-    runner = SurrogateRunner(
-        name="test",
-        outcome_names=["branin"],
-        get_surrogate_and_datasets=lambda: (surrogate, []),
-    )
-
-    observe_noise_sd = True
-    objective = Objective(
-        metric=BenchmarkMetric(
-            name="branin", lower_is_better=True, observe_noise_sd=observe_noise_sd
-        ),
-    )
-    optimization_config = OptimizationConfig(objective=objective)
-
-    return SurrogateBenchmarkProblem(
-        name="test",
-        search_space=experiment.search_space,
-        optimization_config=optimization_config,
-        num_trials=6,
-        observe_noise_stds=observe_noise_sd,
-        optimal_value=0.0,
-        runner=runner,
-    )
-
-
-def get_moo_surrogate() -> SurrogateBenchmarkProblem:
+def get_moo_surrogate() -> BenchmarkProblem:
     experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True)
     surrogate = TorchModelBridge(
         experiment=experiment,
@@ -173,11 +138,15 @@ def get_moo_surrogate() -> SurrogateBenchmarkProblem:
         transforms=[],
     )
 
-    runner = SurrogateRunner(
+    outcome_names = ["branin_a", "branin_b"]
+    test_function = SurrogateTestFunction(
         name="test",
-        outcome_names=["branin_a", "branin_b"],
+        outcome_names=outcome_names,
         get_surrogate_and_datasets=lambda: (surrogate, []),
     )
+    runner = ParamBasedTestProblemRunner(
+        test_problem=test_function, outcome_names=outcome_names
+    )
     observe_noise_sd = True
     optimization_config = MultiObjectiveOptimizationConfig(
         objective=MultiObjective(
@@ -199,7 +168,7 @@
             ],
         )
     )
-    return SurrogateBenchmarkProblem(
+    return BenchmarkProblem(
         name="test",
         search_space=experiment.search_space,
         optimization_config=optimization_config,
