Get rid of ParamBasedTestProblemRunner (#2964)
Summary:

See previous diff for context.

**NOTE TO REVIEWERS**: The only interesting changes are in `runners/base.py` and `runners/botorch_test.py`, plus the registry updates. It is basically a cut-and-paste job, with `BenchmarkRunner` now taking on the functionality of the old `ParamBasedTestProblemRunner`. Sorry this is so large! It can't be easily broken up without creating a circular import issue. I can find a way to make smaller diffs if need be, but it would require more refactoring.

This diff:
* Adds functionality from `runners.botorch_test.ParamBasedTestProblemRunner` into `runners.base.BenchmarkRunner`
* Removes `ParamBasedTestProblemRunner`
* Updates call sites (see the before/after sketch below)
* Removes some `assert_is_instance` checks
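
The call-site change is mechanical. Here is a minimal before/after sketch, assembled from the diffs in this commit (`Branin` is used purely as an illustrative test function):

```python
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from botorch.test_functions.synthetic import Branin

# Before: call sites constructed the now-removed subclass.
# runner = ParamBasedTestProblemRunner(
#     test_problem=BoTorchTestProblem(botorch_problem=Branin()),
#     outcome_names=["branin"],
#     noise_std=0.1,
# )

# After: BenchmarkRunner accepts the same arguments directly.
runner = BenchmarkRunner(
    test_problem=BoTorchTestProblem(botorch_problem=Branin()),
    outcome_names=["branin"],
    noise_std=0.1,
)
```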

Reviewed By: Balandat

Differential Revision: D64865707
esantorella authored and facebook-github-bot committed Oct 29, 2024
1 parent a38e814 commit cef39ef
Showing 12 changed files with 59 additions and 120 deletions.
7 changes: 2 additions & 5 deletions ax/benchmark/benchmark_problem.py
@@ -13,10 +13,7 @@

from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.runners.base import BenchmarkRunner
-from ax.benchmark.runners.botorch_test import (
-    BoTorchTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, Objective
@@ -380,7 +377,7 @@ def create_problem_from_botorch(
name=name,
search_space=search_space,
optimization_config=optimization_config,
-runner=ParamBasedTestProblemRunner(
+runner=BenchmarkRunner(
test_problem=BoTorchTestProblem(botorch_problem=test_problem),
outcome_names=outcome_names,
search_space_digest=extract_search_space_digest(
8 changes: 3 additions & 5 deletions ax/benchmark/problems/hpo/torchvision.py
@@ -14,10 +14,8 @@
BenchmarkProblem,
get_soo_config_and_outcome_names,
)
-from ax.benchmark.runners.botorch_test import (
-    ParamBasedTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
@@ -217,7 +215,7 @@ def get_pytorch_cnn_torchvision_benchmark_problem(
observe_noise_sd=False,
objective_name="accuracy",
)
-runner = ParamBasedTestProblemRunner(
+runner = BenchmarkRunner(
test_problem=PyTorchCNNTorchvisionParamBasedProblem(name=name),
outcome_names=outcome_names,
)
8 changes: 3 additions & 5 deletions ax/benchmark/problems/synthetic/discretized/mixed_integer.py
@@ -21,10 +21,8 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric

from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.runners.botorch_test import (
-    BoTorchTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
@@ -104,7 +102,7 @@ def _get_problem_from_common_inputs(
test_problem = test_problem_class(dim=dim)
else:
test_problem = test_problem_class(dim=dim, bounds=test_problem_bounds)
-runner = ParamBasedTestProblemRunner(
+runner = BenchmarkRunner(
test_problem=BoTorchTestProblem(
botorch_problem=test_problem, modified_bounds=bounds
),
8 changes: 3 additions & 5 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
@@ -11,10 +11,8 @@
import torch
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.runners.botorch_test import (
-    ParamBasedTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
@@ -120,7 +118,7 @@ def get_jenatton_benchmark_problem(
name=name,
search_space=search_space,
optimization_config=optimization_config,
-runner=ParamBasedTestProblemRunner(
+runner=BenchmarkRunner(
test_problem=Jenatton(), outcome_names=[name], noise_std=noise_std
),
num_trials=num_trials,
40 changes: 28 additions & 12 deletions ax/benchmark/runners/base.py
@@ -5,13 +5,13 @@

# pyre-strict

-from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field, InitVar
from math import sqrt
from typing import Any

import torch
+from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.runner import Runner
@@ -27,7 +27,7 @@


@dataclass(kw_only=True)
-class BenchmarkRunner(Runner, ABC):
+class BenchmarkRunner(Runner):
"""
A Runner that produces both observed and ground-truth values.
@@ -44,9 +44,19 @@ class BenchmarkRunner(Runner, ABC):
- If they are not deterministc, they are not supported. It is not
conceptually clear how to benchmark such problems, so we decided to
not over-engineer for that before such a use case arrives.
+Args:
+    outcome_names: The names of the outcomes returned by the problem.
+    test_problem: A ``ParamBasedTestProblem`` from which to generate
+        deterministic data before adding noise.
+    noise_std: The standard deviation of the noise added to the data. Can be
+        a list or dict to be per-metric.
+    search_space_digest: Used to extract target fidelity and task.
"""

outcome_names: list[str]
+test_problem: ParamBasedTestProblem
+noise_std: float | list[float] | dict[str, float] = 0.0
# pyre-fixme[16]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
target_fidelity_and_task: Mapping[str, float | int] = field(init=False)
@@ -61,12 +71,12 @@ def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
self.target_fidelity_and_task = {}

def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
Return the ground truth values for a given arm.
"""Evaluates the test problem.
Synthetic noise is added as part of the Runner's `run()` method.
Returns:
An `m`-dim tensor of ground truth (noiseless) evaluations.
"""
-...
+return torch.atleast_1d(self.test_problem.evaluate_true(params=params))

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
@@ -83,13 +93,19 @@ def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
params = {**parameters, **self.target_fidelity_and_task}
return self.get_Y_true(params=params).numpy()

-@abstractmethod
def get_noise_stds(self) -> dict[str, float]:
"""
Return the standard errors for the synthetic noise to be applied to the
observed values.
"""
-...
+noise_std = self.noise_std
+if isinstance(noise_std, float):
+    return {name: noise_std for name in self.outcome_names}
+elif isinstance(noise_std, dict):
+    if not set(noise_std.keys()) == set(self.outcome_names):
+        raise ValueError(
+            "Noise std must have keys equal to outcome names if given as "
+            "a dict."
+        )
+    return noise_std
+# list of floats
+return dict(zip(self.outcome_names, noise_std, strict=True))

def run(self, trial: BaseTrial) -> dict[str, Any]:
"""Run the trial by evaluating its parameterization(s).
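The consolidated runner keeps the old `noise_std` contract documented above: a single float applies to every outcome, a list is matched positionally to `outcome_names`, and a dict must have keys equal to `outcome_names` (otherwise `get_noise_stds` raises `ValueError`). A quick sketch of the three forms, echoing the heterogeneous-noise test later in this diff (illustrative only, not part of the commit):

```python
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from botorch.test_functions.synthetic import ConstrainedHartmann

problem = BoTorchTestProblem(botorch_problem=ConstrainedHartmann(dim=6))
names = ["objective", "constraint"]

# One shared noise level for both outcomes.
shared = BenchmarkRunner(test_problem=problem, outcome_names=names, noise_std=0.1)

# Per-metric noise, matched to outcome_names by position (strict zip).
per_list = BenchmarkRunner(test_problem=problem, outcome_names=names, noise_std=[0.1, 0.05])

# Per-metric noise keyed by outcome name; keys must equal outcome_names.
per_dict = BenchmarkRunner(
    test_problem=problem,
    outcome_names=names,
    noise_std={"objective": 0.1, "constraint": 0.05},
)
```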
46 changes: 0 additions & 46 deletions ax/benchmark/runners/botorch_test.py
@@ -11,7 +11,6 @@
from itertools import islice

import torch
-from ax.benchmark.runners.base import BenchmarkRunner
from ax.core.types import TParamValue
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
@@ -91,48 +90,3 @@ def evaluate_true(self, params: Mapping[str, float | int]) -> torch.Tensor:
constraints = self.botorch_problem.evaluate_slack_true(x).view(-1)
return torch.cat([objectives, constraints], dim=-1)
return objectives


-@dataclass(kw_only=True)
-class ParamBasedTestProblemRunner(BenchmarkRunner):
-    """
-    A Runner for evaluating `ParamBasedTestProblem`s.
-
-    Given a trial, the Runner will use its `test_problem` to evaluate the
-    problem noiselessly for each arm in the trial, and then add noise as
-    specified by the `noise_std`. It will return
-    metadata including the outcome names and values of metrics.
-
-    Args:
-        outcome_names: The names of the outcomes returned by the problem.
-        search_space_digest: Used to extract target fidelity and task.
-        test_problem: A ``ParamBasedTestProblem`` from which to generate
-            deterministic data before adding noise.
-        noise_std: The standard deviation of the noise added to the data. Can be
-            a list to be per-metric.
-    """
-
-    test_problem: ParamBasedTestProblem
-    noise_std: float | list[float] | dict[str, float] = 0.0
-
-    def get_noise_stds(self) -> dict[str, float]:
-        noise_std = self.noise_std
-        if isinstance(noise_std, float):
-            return {name: noise_std for name in self.outcome_names}
-        elif isinstance(noise_std, dict):
-            if not set(noise_std.keys()) == set(self.outcome_names):
-                raise ValueError(
-                    "Noise std must have keys equal to outcome names if given as "
-                    "a dict."
-                )
-            return noise_std
-        # list of floats
-        return dict(zip(self.outcome_names, noise_std, strict=True))
-
-    def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
-        """Evaluates the test problem.
-
-        Returns:
-            An `m`-dim tensor of ground truth (noiseless) evaluations.
-        """
-        return torch.atleast_1d(self.test_problem.evaluate_true(params=params))
11 changes: 2 additions & 9 deletions ax/benchmark/tests/problems/synthetic/hss/test_jenatton.py
@@ -14,7 +14,6 @@
get_jenatton_benchmark_problem,
jenatton_test_function,
)
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
from ax.core.arm import Arm
from ax.core.data import Data
from ax.core.experiment import Experiment
@@ -106,10 +105,7 @@ def test_create_problem(self) -> None:
self.assertEqual(metric.name, "Jenatton")
self.assertTrue(objective.minimize)
self.assertTrue(metric.lower_is_better)
-self.assertEqual(
-    assert_is_instance(problem.runner, ParamBasedTestProblemRunner).noise_std,
-    0.0,
-)
+self.assertEqual(problem.runner.noise_std, 0.0)
self.assertFalse(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd)

problem = get_jenatton_benchmark_problem(
Expand All @@ -118,10 +114,7 @@ def test_create_problem(self) -> None:
objective = problem.optimization_config.objective
metric = objective.metric
self.assertTrue(metric.lower_is_better)
-self.assertEqual(
-    assert_is_instance(problem.runner, ParamBasedTestProblemRunner).noise_std,
-    0.1,
-)
+self.assertEqual(problem.runner.noise_std, 0.1)
self.assertTrue(assert_is_instance(metric, BenchmarkMetric).observe_noise_sd)

def test_fetch_trial_data(self) -> None:
9 changes: 3 additions & 6 deletions ax/benchmark/tests/problems/test_mixed_integer_problems.py
@@ -15,10 +15,7 @@
get_discrete_hartmann,
get_discrete_rosenbrock,
)
-from ax.benchmark.runners.botorch_test import (
-    BoTorchTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.arm import Arm
from ax.core.parameter import ParameterType
from ax.core.trial import Trial
@@ -37,7 +34,7 @@ def test_problems(self) -> None:
name = problem_cls.__name__
problem = constructor()
self.assertEqual(f"Discrete {name}", problem.name)
-runner = assert_is_instance(problem.runner, ParamBasedTestProblemRunner)
+runner = problem.runner
test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem)
botorch_problem = test_problem.botorch_problem
self.assertIsInstance(botorch_problem, problem_cls)
@@ -99,7 +96,7 @@ def test_problems(self) -> None:
]

for problem, params, expected_arg in cases:
-runner = assert_is_instance(problem.runner, ParamBasedTestProblemRunner)
+runner = problem.runner
test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem)
trial = Trial(experiment=MagicMock())
# pyre-fixme: Incompatible parameter type [6]: In call
14 changes: 6 additions & 8 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
@@ -16,10 +16,8 @@

import torch
from ax.benchmark.problems.synthetic.hss.jenatton import get_jenatton_benchmark_problem
-from ax.benchmark.runners.botorch_test import (
-    BoTorchTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.benchmark.runners.surrogate import SurrogateTestFunction
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
@@ -154,7 +152,7 @@ def test_synthetic_runner(self) -> None:
outcome_names = ["branin"]

# Set up runner
-runner = ParamBasedTestProblemRunner(
+runner = BenchmarkRunner(
test_problem=test_problem,
outcome_names=outcome_names,
noise_std=noise_std,
@@ -293,15 +291,15 @@ def test_synthetic_runner(self) -> None:
with self.assertRaisesRegex(
UnsupportedError, "serialize_init_args is not a supported method"
):
-ParamBasedTestProblemRunner.serialize_init_args(obj=runner)
+BenchmarkRunner.serialize_init_args(obj=runner)
with self.assertRaisesRegex(
UnsupportedError, "deserialize_init_args is not a supported method"
):
-ParamBasedTestProblemRunner.deserialize_init_args({})
+BenchmarkRunner.deserialize_init_args({})

def test_botorch_test_problem_runner_heterogeneous_noise(self) -> None:
for noise_std in [[0.1, 0.05], {"objective": 0.1, "constraint": 0.05}]:
-runner = ParamBasedTestProblemRunner(
+runner = BenchmarkRunner(
test_problem=BoTorchTestProblem(
botorch_problem=ConstrainedHartmann(dim=6)
),
10 changes: 4 additions & 6 deletions ax/benchmark/tests/test_benchmark_problem.py
@@ -14,10 +14,8 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric

from ax.benchmark.benchmark_problem import BenchmarkProblem, create_problem_from_botorch
-from ax.benchmark.runners.botorch_test import (
-    BoTorchTestProblem,
-    ParamBasedTestProblemRunner,
-)
+from ax.benchmark.runners.base import BenchmarkRunner
+from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.objective import MultiObjective, Objective
from ax.core.optimization_config import (
MultiObjectiveOptimizationConfig,
@@ -54,7 +52,7 @@ def test_inference_value_not_implemented(self) -> None:
for name in ["Branin", "Currin"]
]
optimization_config = OptimizationConfig(objective=objectives[0])
-runner = ParamBasedTestProblemRunner(
+runner = BenchmarkRunner(
test_problem=BoTorchTestProblem(botorch_problem=Branin()),
outcome_names=["foo"],
)
@@ -215,7 +213,7 @@ def _test_constrained_from_botorch(
observe_noise_sd=observe_noise_sd,
noise_std=noise_std,
)
-runner = assert_is_instance(ax_problem.runner, ParamBasedTestProblemRunner)
+runner = ax_problem.runner
test_problem = assert_is_instance(runner.test_problem, BoTorchTestProblem)
botorch_problem = assert_is_instance(
test_problem.botorch_problem, ConstrainedBaseTestProblem
4 changes: 1 addition & 3 deletions ax/storage/tests/test_registry_bundle.py
@@ -6,7 +6,6 @@
# pyre-strict

from ax.benchmark.benchmark_metric import BenchmarkMetric
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
from ax.metrics.branin import BraninMetric
from ax.runners.synthetic import SyntheticRunner
from ax.storage.registry_bundle import RegistryBundle
@@ -26,7 +25,7 @@ def test_from_registry_bundles(self) -> None:

right = RegistryBundle(
metric_clss={BenchmarkMetric: None},
-runner_clss={ParamBasedTestProblemRunner: None},
+runner_clss={SyntheticRunner: None},
json_encoder_registry={},
json_class_encoder_registry={},
json_decoder_registry={},
@@ -41,4 +40,3 @@
self.assertIn(BraninMetric, combined.encoder_registry)
self.assertIn(SyntheticRunner, combined.encoder_registry)
self.assertIn(BenchmarkMetric, combined.encoder_registry)
-self.assertIn(ParamBasedTestProblemRunner, combined.encoder_registry)