Get rid of ParamBasedTestProblemRunner #2964

Status: Closed · wants to merge 2 commits
7 changes: 2 additions & 5 deletions ax/benchmark/benchmark_problem.py
@@ -13,10 +13,7 @@

from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import (
BoTorchTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, Objective
@@ -380,7 +377,7 @@ def create_problem_from_botorch(
name=name,
search_space=search_space,
optimization_config=optimization_config,
runner=ParamBasedTestProblemRunner(
runner=BenchmarkRunner(
test_problem=BoTorchTestProblem(botorch_problem=test_problem),
outcome_names=outcome_names,
search_space_digest=extract_search_space_digest(
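The updated call site builds the consolidated runner in one step. A minimal standalone sketch of the same construction, assuming only the classes shown in this diff; the Branin problem and the outcome name are illustrative placeholders, not taken from this PR:

```python
from botorch.test_functions.synthetic import Branin

from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem

# Wrap a BoTorch test function and hand it to the consolidated runner;
# search_space_digest defaults to None and is omitted here.
runner = BenchmarkRunner(
    test_problem=BoTorchTestProblem(botorch_problem=Branin()),
    outcome_names=["branin"],
)
```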
8 changes: 3 additions & 5 deletions ax/benchmark/problems/hpo/torchvision.py
@@ -14,10 +14,8 @@
BenchmarkProblem,
get_soo_config_and_outcome_names,
)
from ax.benchmark.runners.botorch_test import (
ParamBasedTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
@@ -216,7 +214,7 @@ def get_pytorch_cnn_torchvision_benchmark_problem(
observe_noise_sd=False,
objective_name="accuracy",
)
runner = ParamBasedTestProblemRunner(
runner = BenchmarkRunner(
test_problem=PyTorchCNNTorchvisionParamBasedProblem(name=name),
outcome_names=outcome_names,
)
33 changes: 0 additions & 33 deletions ax/benchmark/problems/surrogate.py

This file was deleted.

8 changes: 3 additions & 5 deletions ax/benchmark/problems/synthetic/discretized/mixed_integer.py
@@ -21,10 +21,8 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric

from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.runners.botorch_test import (
BoTorchTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import BoTorchTestProblem
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
@@ -104,7 +102,7 @@ def _get_problem_from_common_inputs(
test_problem = test_problem_class(dim=dim)
else:
test_problem = test_problem_class(dim=dim, bounds=test_problem_bounds)
runner = ParamBasedTestProblemRunner(
runner = BenchmarkRunner(
test_problem=BoTorchTestProblem(
botorch_problem=test_problem, modified_bounds=bounds
),
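The `modified_bounds` argument is untouched by this PR; for readers unfamiliar with it, the sketch below illustrates the remapping it implies, under the assumption that `BoTorchTestProblem` normalizes points from the modified (search-space) bounds and unnormalizes them into the problem's native bounds via `botorch.utils.transforms`. The bound values are illustrative:

```python
import torch
from botorch.utils.transforms import normalize, unnormalize

# Search-space bounds used by Ax (shape 2 x d: lower row, upper row).
modified_bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]])
# Native bounds of the underlying test problem (e.g. Branin).
problem_bounds = torch.tensor([[-5.0, 0.0], [10.0, 15.0]])

x = torch.tensor([0.5, 0.5])                  # a point in the modified bounds
unit_x = normalize(x, modified_bounds)        # map to the unit cube
x_eval = unnormalize(unit_x, problem_bounds)  # map into the problem's bounds
print(x_eval)  # tensor([2.5000, 7.5000])
```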
8 changes: 3 additions & 5 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
@@ -11,10 +11,8 @@
import torch
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
from ax.benchmark.runners.botorch_test import (
ParamBasedTestProblem,
ParamBasedTestProblemRunner,
)
from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
@@ -120,7 +118,7 @@ def get_jenatton_benchmark_problem(
name=name,
search_space=search_space,
optimization_config=optimization_config,
runner=ParamBasedTestProblemRunner(
runner=BenchmarkRunner(
test_problem=Jenatton(), outcome_names=[name], noise_std=noise_std
),
num_trials=num_trials,
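`Jenatton` above follows the pattern this PR keeps intact: implement a `ParamBasedTestProblem` and pass it to the consolidated `BenchmarkRunner`. A minimal sketch with a hypothetical problem, assuming `evaluate_true` is the only member a subclass must provide (any additional required fields on `ParamBasedTestProblem` would need to be supplied as well):

```python
import torch

from ax.benchmark.runners.base import BenchmarkRunner
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem


class SumOfSquares(ParamBasedTestProblem):
    """Hypothetical single-objective problem used only for illustration."""

    def evaluate_true(self, params):
        # Deterministic ground-truth value; noise is added by the runner.
        return torch.tensor([sum(float(v) ** 2 for v in params.values())])


runner = BenchmarkRunner(
    test_problem=SumOfSquares(),
    outcome_names=["objective"],
    noise_std=0.1,  # illustrative observation noise
)
```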
40 changes: 28 additions & 12 deletions ax/benchmark/runners/base.py
@@ -5,7 +5,6 @@

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field, InitVar
from math import sqrt
@@ -14,6 +13,7 @@
import numpy.typing as npt

import torch
from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.runner import Runner
@@ -28,7 +28,7 @@


@dataclass(kw_only=True)
class BenchmarkRunner(Runner, ABC):
class BenchmarkRunner(Runner):
"""
A Runner that produces both observed and ground-truth values.

@@ -45,9 +45,19 @@ class BenchmarkRunner(Runner, ABC):
- If they are not deterministic, they are not supported. It is not
conceptually clear how to benchmark such problems, so we decided to
not over-engineer for that before such a use case arrives.

Args:
outcome_names: The names of the outcomes returned by the problem.
test_problem: A ``ParamBasedTestProblem`` from which to generate
deterministic data before adding noise.
noise_std: The standard deviation of the noise added to the data. Can be
a list or dict to be per-metric.
search_space_digest: Used to extract target fidelity and task.
"""

outcome_names: list[str]
test_problem: ParamBasedTestProblem
noise_std: float | list[float] | dict[str, float] = 0.0
# pyre-fixme[16]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
target_fidelity_and_task: Mapping[str, float | int] = field(init=False)
@@ -62,12 +72,12 @@ def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
self.target_fidelity_and_task = {}

def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
Return the ground truth values for a given arm.
"""Evaluates the test problem.

Synthetic noise is added as part of the Runner's `run()` method.
Returns:
An `m`-dim tensor of ground truth (noiseless) evaluations.
"""
...
return torch.atleast_1d(self.test_problem.evaluate_true(params=params))

def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray:
"""
@@ -83,13 +93,19 @@ def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray:
params = {**parameters, **self.target_fidelity_and_task}
return self.get_Y_true(params=params).numpy()

@abstractmethod
def get_noise_stds(self) -> dict[str, float]:
"""
Return the standard errors for the synthetic noise to be applied to the
observed values.
"""
...
noise_std = self.noise_std
if isinstance(noise_std, float):
return {name: noise_std for name in self.outcome_names}
elif isinstance(noise_std, dict):
if not set(noise_std.keys()) == set(self.outcome_names):
raise ValueError(
"Noise std must have keys equal to outcome names if given as "
"a dict."
)
return noise_std
# list of floats
return dict(zip(self.outcome_names, noise_std, strict=True))

def run(self, trial: BaseTrial) -> dict[str, Any]:
"""Run the trial by evaluating its parameterization(s).
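The `get_noise_stds` implementation folded into `BenchmarkRunner` above accepts a float, a per-outcome list, or a per-outcome dict. A standalone restatement of how those forms resolve, with illustrative outcome names:

```python
def resolve_noise_stds(
    noise_std: float | list[float] | dict[str, float],
    outcome_names: list[str],
) -> dict[str, float]:
    """Mirror of the noise-std resolution merged into BenchmarkRunner."""
    if isinstance(noise_std, float):
        # One shared standard deviation for every outcome.
        return {name: noise_std for name in outcome_names}
    if isinstance(noise_std, dict):
        # A per-outcome dict must cover exactly the outcome names.
        if set(noise_std.keys()) != set(outcome_names):
            raise ValueError(
                "Noise std must have keys equal to outcome names if given as a dict."
            )
        return noise_std
    # List of floats, one per outcome, in order.
    return dict(zip(outcome_names, noise_std, strict=True))


assert resolve_noise_stds(0.1, ["f", "c"]) == {"f": 0.1, "c": 0.1}
assert resolve_noise_stds([0.1, 0.0], ["f", "c"]) == {"f": 0.1, "c": 0.0}
assert resolve_noise_stds({"f": 0.1, "c": 0.0}, ["f", "c"]) == {"f": 0.1, "c": 0.0}
```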
46 changes: 0 additions & 46 deletions ax/benchmark/runners/botorch_test.py
@@ -11,7 +11,6 @@
from itertools import islice

import torch
from ax.benchmark.runners.base import BenchmarkRunner
from ax.core.types import TParamValue
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
@@ -91,48 +90,3 @@ def evaluate_true(self, params: Mapping[str, float | int]) -> torch.Tensor:
constraints = self.botorch_problem.evaluate_slack_true(x).view(-1)
return torch.cat([objectives, constraints], dim=-1)
return objectives


@dataclass(kw_only=True)
class ParamBasedTestProblemRunner(BenchmarkRunner):
"""
A Runner for evaluating `ParamBasedTestProblem`s.

Given a trial, the Runner will use its `test_problem` to evaluate the
problem noiselessly for each arm in the trial, and then add noise as
specified by the `noise_std`. It will return
metadata including the outcome names and values of metrics.

Args:
outcome_names: The names of the outcomes returned by the problem.
search_space_digest: Used to extract target fidelity and task.
test_problem: A ``ParamBasedTestProblem`` from which to generate
deterministic data before adding noise.
noise_std: The standard deviation of the noise added to the data. Can be
a list to be per-metric.
"""

test_problem: ParamBasedTestProblem
noise_std: float | list[float] | dict[str, float] = 0.0

def get_noise_stds(self) -> dict[str, float]:
noise_std = self.noise_std
if isinstance(noise_std, float):
return {name: noise_std for name in self.outcome_names}
elif isinstance(noise_std, dict):
if not set(noise_std.keys()) == set(self.outcome_names):
raise ValueError(
"Noise std must have keys equal to outcome names if given as "
"a dict."
)
return noise_std
# list of floats
return dict(zip(self.outcome_names, noise_std, strict=True))

def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""Evaluates the test problem.

Returns:
An `m`-dim tensor of ground truth (noiseless) evaluations.
"""
return torch.atleast_1d(self.test_problem.evaluate_true(params=params))
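The class above is removed outright, and no alias is added in this diff, so any external call sites need a one-line migration to the consolidated class. A sketch reusing the `Jenatton` problem from this PR; the outcome name is illustrative:

```python
from ax.benchmark.problems.synthetic.hss.jenatton import Jenatton
from ax.benchmark.runners.base import BenchmarkRunner

# Before this PR (class removed):
#   from ax.benchmark.runners.botorch_test import ParamBasedTestProblemRunner
#   runner = ParamBasedTestProblemRunner(test_problem=Jenatton(), outcome_names=["objective"])

# After this PR: identical arguments, consolidated class.
runner = BenchmarkRunner(test_problem=Jenatton(), outcome_names=["objective"])
```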