Reorganize runner tests following previous class renaming #2991

Closed · wants to merge 3 commits
8 changes: 4 additions & 4 deletions ax/benchmark/benchmark_problem.py
@@ -12,8 +12,8 @@
import pandas as pd

from ax.benchmark.benchmark_metric import BenchmarkMetric
-from ax.benchmark.runners.base import BenchmarkRunner
-from ax.benchmark.runners.botorch_test import BoTorchTestProblem
+from ax.benchmark.benchmark_runner import BenchmarkRunner
+from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.objective import MultiObjective, Objective
@@ -309,7 +309,7 @@ def create_problem_from_botorch(
Create a `BenchmarkProblem` from a BoTorch `BaseTestProblem`.

Uses specialized Metrics and Runners for benchmarking. The test problem's
-result will be computed by the Runner, `BoTorchTestProblemRunner`, and
+result will be computed by the Runner, `BenchmarkRunner`, and
retrieved by the Metric(s), which are `BenchmarkMetric`s.

Args:
@@ -378,7 +378,7 @@ def create_problem_from_botorch(
search_space=search_space,
optimization_config=optimization_config,
runner=BenchmarkRunner(
-test_problem=BoTorchTestProblem(botorch_problem=test_problem),
+test_function=BoTorchTestFunction(botorch_problem=test_problem),
outcome_names=outcome_names,
search_space_digest=extract_search_space_digest(
search_space=search_space,
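For orientation, a minimal sketch of the renamed wiring this hunk produces. The `Hartmann` problem and the outcome name are illustrative assumptions, not part of the diff:

```python
# Minimal sketch of the renamed classes working together, mirroring the
# hunk above. Hartmann and the outcome name are illustrative assumptions.
from ax.benchmark.benchmark_runner import BenchmarkRunner
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from botorch.test_functions.synthetic import Hartmann

runner = BenchmarkRunner(
    test_function=BoTorchTestFunction(botorch_problem=Hartmann(dim=6)),
    outcome_names=["hartmann6"],
)
```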
ax/benchmark/benchmark_runner.py
@@ -13,7 +13,7 @@
import numpy.typing as npt

import torch
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.runner import Runner
@@ -48,15 +48,15 @@ class BenchmarkRunner(Runner):

Args:
outcome_names: The names of the outcomes returned by the problem.
-test_problem: A ``ParamBasedTestProblem`` from which to generate
+test_function: A ``BenchmarkTestFunction`` from which to generate
deterministic data before adding noise.
noise_std: The standard deviation of the noise added to the data. Can be
a list or dict to be per-metric.
search_space_digest: Used to extract target fidelity and task.
"""

outcome_names: list[str]
-test_problem: ParamBasedTestProblem
+test_function: BenchmarkTestFunction
noise_std: float | list[float] | dict[str, float] = 0.0
# pyre-fixme[16]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
@@ -77,7 +77,7 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
Returns:
An `m`-dim tensor of ground truth (noiseless) evaluations.
"""
-return torch.atleast_1d(self.test_problem.evaluate_true(params=params))
+return torch.atleast_1d(self.test_function.evaluate_true(params=params))

def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray:
"""
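With the rename, ground-truth evaluation flows through `test_function`. A hedged sketch of the two entry points visible in this file's diff, reusing the `runner` sketched above (the parameter values are illustrative):

```python
# Hedged sketch: both evaluation paths shown in this file's diff.
# `runner` is the BenchmarkRunner instance from the earlier sketch.
params = {f"x{i}": 0.5 for i in range(6)}

y_true = runner.get_Y_true(params=params)  # m-dim tensor, noiseless
y_oracle = runner.evaluate_oracle(parameters=params)  # npt.NDArray per signature
```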
32 changes: 32 additions & 0 deletions ax/benchmark/benchmark_test_function.py
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass

from ax.core.types import TParamValue
from torch import Tensor


@dataclass(kw_only=True)
class BenchmarkTestFunction(ABC):
"""
The basic Ax class for generating deterministic data to benchmark against.

(Noise - if desired - is added by the runner.)
"""

@abstractmethod
def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
Evaluate noiselessly.

Returns:
1d tensor of shape (num_outcomes,).
"""
...
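The ABC above is all a test function needs. A minimal concrete subclass as a sketch; the quadratic objective and the numeric-parameter assumption are illustrative only:

```python
# Hedged sketch: a concrete BenchmarkTestFunction. The quadratic and the
# assumption that all parameters are numeric are illustrative only.
from collections.abc import Mapping
from dataclasses import dataclass

import torch
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.types import TParamValue
from torch import Tensor


@dataclass(kw_only=True)
class Quadratic(BenchmarkTestFunction):
    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
        # Deterministic single-outcome value; noise, if desired, is added
        # by the BenchmarkRunner, per the docstring above.
        return torch.tensor([sum(float(v) ** 2 for v in params.values())])
```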
ax/benchmark/benchmark_test_functions/botorch_test.py
@@ -5,39 +5,18 @@

# pyre-strict

-from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from itertools import islice

import torch
-from ax.core.types import TParamValue
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from botorch.test_functions.synthetic import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.utils.transforms import normalize, unnormalize
from torch import Tensor


-@dataclass(kw_only=True)
-class ParamBasedTestProblem(ABC):
-    """
-    The basic Ax class for generating deterministic data to benchmark against.
-
-    (Noise - if desired - is added by the runner.)
-    """
-
-    @abstractmethod
-    def evaluate_true(self, params: Mapping[str, TParamValue]) -> Tensor:
-        """
-        Evaluate noiselessly.
-
-        Returns:
-            1d tensor of shape (num_outcomes,).
-        """
-        ...


@dataclass(kw_only=True)
-class BoTorchTestProblem(ParamBasedTestProblem):
+class BoTorchTestFunction(BenchmarkTestFunction):
"""
Class for generating data from a BoTorch ``BaseTestProblem``.

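Standalone usage of the renamed class, as exercised by the tests added later in this PR; a hedged sketch (Hartmann and the evaluation point are illustrative):

```python
# Hedged sketch: evaluate a wrapped BoTorch problem directly, without a
# runner. Mirrors the tests added below; the point 0.5 is arbitrary.
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from botorch.test_functions.synthetic import Hartmann

test_function = BoTorchTestFunction(botorch_problem=Hartmann(dim=6))
value = test_function.evaluate_true({f"x{i}": 0.5 for i in range(6)})
# -> 1d tensor of shape (num_outcomes,), here (1,)
```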
ax/benchmark/benchmark_test_functions/surrogate.py
@@ -9,7 +9,7 @@
from dataclasses import dataclass

import torch
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.observation import ObservationFeatures
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
@@ -21,7 +21,7 @@


@dataclass(kw_only=True)
-class SurrogateTestFunction(ParamBasedTestProblem):
+class SurrogateTestFunction(BenchmarkTestFunction):
"""
Data-generating function for surrogate benchmark problems.

14 changes: 7 additions & 7 deletions ax/benchmark/problems/hpo/torchvision.py
@@ -14,8 +14,8 @@
BenchmarkProblem,
get_soo_config_and_outcome_names,
)
-from ax.benchmark.runners.base import BenchmarkRunner
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
+from ax.benchmark.benchmark_runner import BenchmarkRunner
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
@@ -113,7 +113,7 @@ def train_and_evaluate(


@dataclass(kw_only=True)
-class PyTorchCNNTorchvisionParamBasedProblem(ParamBasedTestProblem):
+class PyTorchCNNTorchvisionBenchmarkTestFunction(BenchmarkTestFunction):
name: str # The name of the dataset to load -- MNIST or FashionMNIST
device: torch.device = field(
default_factory=lambda: torch.device(
@@ -151,7 +151,7 @@ def __post_init__(self, train_loader: None, test_loader: None) -> None:
transform=transforms.ToTensor(),
)
# pyre-fixme: Undefined attribute [16]:
-# `PyTorchCNNTorchvisionParamBasedProblem` has no attribute
+# `PyTorchCNNTorchvisionBenchmarkTestFunction` has no attribute
# `train_loader`.
self.train_loader = DataLoader(train_set, num_workers=1)
# pyre-fixme
@@ -163,10 +163,10 @@ def evaluate_true(self, params: Mapping[str, int | float]) -> Tensor:
frac_correct = train_and_evaluate(
**params,
device=self.device,
-# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
+# pyre-fixme[16]: `PyTorchCNNTorchvisionBenchmarkTestFunction` has no
# attribute `train_loader`.
train_loader=self.train_loader,
-# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
+# pyre-fixme[16]: `PyTorchCNNTorchvisionBenchmarkTestFunction` has no
# attribute `test_loader`.
test_loader=self.test_loader,
)
@@ -215,7 +215,7 @@ def get_pytorch_cnn_torchvision_benchmark_problem(
objective_name="accuracy",
)
runner = BenchmarkRunner(
-test_problem=PyTorchCNNTorchvisionParamBasedProblem(name=name),
+test_function=PyTorchCNNTorchvisionBenchmarkTestFunction(name=name),
outcome_names=outcome_names,
)
return BenchmarkProblem(
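A hedged usage sketch: per the field comment above, `name` must be "MNIST" or "FashionMNIST"; any further keyword arguments are not visible in this diff.

```python
# Hedged sketch: only `name` is documented in this diff; other arguments
# of get_pytorch_cnn_torchvision_benchmark_problem are not shown here.
problem = get_pytorch_cnn_torchvision_benchmark_problem(name="MNIST")
```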
8 changes: 4 additions & 4 deletions ax/benchmark/problems/synthetic/discretized/mixed_integer.py
@@ -21,8 +21,8 @@
from ax.benchmark.benchmark_metric import BenchmarkMetric

from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.runners.base import BenchmarkRunner
-from ax.benchmark.runners.botorch_test import BoTorchTestProblem
+from ax.benchmark.benchmark_runner import BenchmarkRunner
+from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ParameterType, RangeParameter
@@ -47,7 +47,7 @@ def _get_problem_from_common_inputs(

Args:
bounds: The parameter bounds. These will be passed to
-`BotorchTestProblemRunner` as `modified_bounds`, and the parameters
+`BoTorchTestFunction` as `modified_bounds`, and the parameters
will be renormalized from these bounds to the bounds of the original
problem. For example, if `bounds` are [(0, 3)] and the test
problem's original bounds are [(0, 2)], then the original problem
@@ -103,7 +103,7 @@ def _get_problem_from_common_inputs(
else:
test_problem = test_problem_class(dim=dim, bounds=test_problem_bounds)
runner = BenchmarkRunner(
-test_problem=BoTorchTestProblem(
+test_function=BoTorchTestFunction(
botorch_problem=test_problem, modified_bounds=bounds
),
outcome_names=[metric_name],
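To make the `modified_bounds` renormalization concrete: with `bounds` of [(0, 3)] and original problem bounds of [(0, 2)], a raw value of 1.5 lands at 1.0. A hedged sketch using BoTorch's `normalize`/`unnormalize` transforms, the same ones `BoTorchTestFunction`'s module imports:

```python
# Hedged sketch of the renormalization described in the docstring above:
# raw values in the modified bounds are mapped to the unit cube, then onto
# the BoTorch problem's original bounds.
import torch
from botorch.utils.transforms import normalize, unnormalize

modified_bounds = torch.tensor([[0.0], [3.0]])  # passed to BoTorchTestFunction
original_bounds = torch.tensor([[0.0], [2.0]])  # the problem's own bounds

raw = torch.tensor([[1.5]])
unit = normalize(raw, modified_bounds)                 # tensor([[0.5]])
evaluation_point = unnormalize(unit, original_bounds)  # tensor([[1.0]])
```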
8 changes: 4 additions & 4 deletions ax/benchmark/problems/synthetic/hss/jenatton.py
@@ -11,8 +11,8 @@
import torch
from ax.benchmark.benchmark_metric import BenchmarkMetric
from ax.benchmark.benchmark_problem import BenchmarkProblem
-from ax.benchmark.runners.base import BenchmarkRunner
-from ax.benchmark.runners.botorch_test import ParamBasedTestProblem
+from ax.benchmark.benchmark_runner import BenchmarkRunner
+from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
@@ -50,7 +50,7 @@ def jenatton_test_function(


@dataclass(kw_only=True)
-class Jenatton(ParamBasedTestProblem):
+class Jenatton(BenchmarkTestFunction):
"""Jenatton test function for hierarchical search spaces."""

# pyre-fixme[14]: Inconsistent override
@@ -119,7 +119,7 @@ def get_jenatton_benchmark_problem(
search_space=search_space,
optimization_config=optimization_config,
runner=BenchmarkRunner(
-test_problem=Jenatton(), outcome_names=[name], noise_std=noise_std
+test_function=Jenatton(), outcome_names=[name], noise_std=noise_std
),
num_trials=num_trials,
observe_noise_stds=observe_noise_sd,
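A hedged call sketch; the argument names follow this diff (`num_trials`, `noise_std`, `observe_noise_sd`), while the values and any defaults are illustrative:

```python
# Hedged sketch: build the Jenatton problem with explicit noise settings.
# Values are illustrative; defaults are not visible in this diff.
problem = get_jenatton_benchmark_problem(
    num_trials=50,
    noise_std=0.1,
    observe_noise_sd=False,
)
```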
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


import torch
from ax.benchmark.benchmark_test_functions.botorch_test import BoTorchTestFunction
from ax.utils.common.testutils import TestCase
from botorch.test_functions.multi_objective import BraninCurrin
from botorch.test_functions.synthetic import ConstrainedHartmann, Hartmann


class TestBoTorchTestFunction(TestCase):
def setUp(self) -> None:
super().setUp()
botorch_base_test_functions = {
"base Hartmann": Hartmann(dim=6),
"negated Hartmann": Hartmann(dim=6, negate=True),
"constrained Hartmann": ConstrainedHartmann(dim=6),
"negated constrained Hartmann": ConstrainedHartmann(dim=6, negate=True),
}
self.botorch_test_problems = {
k: BoTorchTestFunction(botorch_problem=v)
for k, v in botorch_base_test_functions.items()
}

def test_negation(self) -> None:
params = {f"x{i}": 0.5 for i in range(6)}
evaluate_true_results = {
k: v.evaluate_true(params) for k, v in self.botorch_test_problems.items()
}
self.assertEqual(
evaluate_true_results["base Hartmann"],
evaluate_true_results["constrained Hartmann"][0],
)
self.assertEqual(
evaluate_true_results["base Hartmann"],
-evaluate_true_results["negated Hartmann"],
)
self.assertEqual(
evaluate_true_results["negated Hartmann"],
evaluate_true_results["negated constrained Hartmann"][0],
)
self.assertEqual(
evaluate_true_results["constrained Hartmann"][1],
evaluate_true_results["negated constrained Hartmann"][1],
)

def test_raises_for_botorch_attrs(self) -> None:
msg = "noise should be set on the `BenchmarkRunner`, not the test function."
with self.assertRaisesRegex(ValueError, msg):
BoTorchTestFunction(botorch_problem=Hartmann(dim=6, noise_std=0.1))
with self.assertRaisesRegex(ValueError, msg):
BoTorchTestFunction(
botorch_problem=ConstrainedHartmann(dim=6, constraint_noise_std=0.1)
)

def test_tensor_shapes(self) -> None:
params = {f"x{i}": 0.5 for i in range(6)}
evaluate_true_results = {
k: v.evaluate_true(params) for k, v in self.botorch_test_problems.items()
}
evaluate_true_results["BraninCurrin"] = BoTorchTestFunction(
botorch_problem=BraninCurrin()
).evaluate_true(params)
expected_len = {
"base Hartmann": 1,
"constrained Hartmann": 2,
"negated Hartmann": 1,
"negated constrained Hartmann": 2,
"BraninCurrin": 2,
}
for name, result in evaluate_true_results.items():
with self.subTest(name=name):
self.assertEqual(result.dtype, torch.double)
self.assertEqual(result.shape, torch.Size([expected_len[name]]))
@@ -8,7 +8,7 @@
from unittest.mock import MagicMock, patch

import torch
-from ax.benchmark.runners.surrogate import SurrogateTestFunction
+from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction
from ax.modelbridge.torch import TorchModelBridge
from ax.utils.common.testutils import TestCase
from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function