Skip to content

Commit

Permalink
Make BenchmarkRunners into dataclasses (#2892)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2892

Context: D64110932 didn't work because of serialization, but now we don't serialize runners anymore.

This diff:
* Largely the same as D64110932, but also makes `SurrogateRunner` into a dataclass and makes its equality check stricter

Reviewed By: saitcakmak

Differential Revision: D64360228

fbshipit-source-id: 24729d042e07cc543db7215b83428321b502be04
  • Loading branch information
esantorella authored and facebook-github-bot committed Oct 17, 2024
1 parent 37e6f54 commit 849a7ec
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 181 deletions.
22 changes: 8 additions & 14 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@

from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field, InitVar
from math import sqrt
from typing import Any

import torch

from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.runner import Runner
Expand All @@ -26,6 +26,7 @@
from torch import Tensor


@dataclass(kw_only=True)
class BenchmarkRunner(Runner, ABC):
"""
A Runner that produces both observed and ground-truth values.
Expand All @@ -45,19 +46,12 @@ class BenchmarkRunner(Runner, ABC):
not over-engineer for that before such a use case arrives.
"""

def __init__(
self,
*,
outcome_names: list[str],
search_space_digest: SearchSpaceDigest | None = None,
) -> None:
"""
Args:
outcome_names: Outcome names, needed for going between tensors and
data in formats used by Ax.
search_space_digest: Used to extract target fidelity and task.
"""
self.outcome_names = outcome_names
outcome_names: list[str]
# pyre-fixme[8]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
target_fidelity_and_task: Mapping[str, float | int] = field(init=False)

def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
if search_space_digest is not None:
self.target_fidelity_and_task: dict[str, float | int] = {
search_space_digest.feature_names[i]: target
Expand Down
173 changes: 68 additions & 105 deletions ax/benchmark/runners/botorch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from abc import ABC, abstractmethod
from collections.abc import Mapping
from dataclasses import dataclass
from dataclasses import asdict, dataclass, field
from typing import Any

import torch
Expand Down Expand Up @@ -53,77 +53,45 @@ def __eq__(self, other: Any) -> bool:
return self.__class__.__name__ == other.__class__.__name__


@dataclass(kw_only=True)
class SyntheticProblemRunner(BenchmarkRunner, ABC):
"""A Runner for evaluating synthetic problems, either BoTorch
`BaseTestProblem`s or Ax benchmarking `ParamBasedTestProblem`s.
Given a trial, the Runner will evaluate the problem noiselessly for each
arm in the trial, as well as return some metadata about the underlying
problem such as the noise_std.
"""
test_problem: BaseTestProblem | ParamBasedTestProblem
_is_constrained: bool
_test_problem_class: type[BaseTestProblem | ParamBasedTestProblem]
_test_problem_kwargs: dict[str, Any] | None

def __init__(
self,
*,
test_problem_class: type[BaseTestProblem | ParamBasedTestProblem],
test_problem_kwargs: dict[str, Any],
outcome_names: list[str],
modified_bounds: list[tuple[float, float]] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> None:
"""Initialize the test problem runner.
Args:
test_problem_class: A BoTorch `BaseTestProblem` class or Ax
`ParamBasedTestProblem` class.
test_problem_kwargs: The keyword arguments used for initializing the
test problem.
outcome_names: The names of the outcomes returned by the problem.
modified_bounds: The bounds that are used by the Ax search space
while optimizing the problem. If different from the bounds of the
test problem, we project the parameters into the test problem
bounds before evaluating the test problem.
For example, if the test problem is defined on [0, 1] but the Ax
search space is integers in [0, 10], an Ax parameter value of
5 will correspond to 0.5 while evaluating the test problem.
If modified bounds are not provided, the test problem will be
evaluated using the raw parameter values.
search_space_digest: Used to extract target fidelity and task.
"""

Args:
test_problem_class: A BoTorch `BaseTestProblem` class or Ax
`ParamBasedTestProblem` class.
test_problem_kwargs: The keyword arguments used for initializing the
test problem.
outcome_names: The names of the outcomes returned by the problem.
modified_bounds: The bounds that are used by the Ax search space
while optimizing the problem. If different from the bounds of the
test problem, we project the parameters into the test problem
bounds before evaluating the test problem.
For example, if the test problem is defined on [0, 1] but the Ax
search space is integers in [0, 10], an Ax parameter value of
5 will correspond to 0.5 while evaluating the test problem.
If modified bounds are not provided, the test problem will be
evaluated using the raw parameter values.
search_space_digest: Used to extract target fidelity and task.
"""
super().__init__(
outcome_names=outcome_names, search_space_digest=search_space_digest
)
self._test_problem_class = test_problem_class
self._test_problem_kwargs = test_problem_kwargs
self.test_problem = (
# pyre-fixme: Invalid class instantiation [45]: Cannot instantiate
# abstract class with abstract method `evaluate_true`.
test_problem_class(**test_problem_kwargs)
)
if isinstance(self.test_problem, BaseTestProblem):
self.test_problem = self.test_problem.to(dtype=torch.double)
# A `ConstrainedBaseTestProblem` is a type of `BaseTestProblem`; a
# `ParamBasedTestProblem` is never constrained.
self._is_constrained: bool = isinstance(
self.test_problem, ConstrainedBaseTestProblem
)
self._is_moo: bool = self.test_problem.num_objectives > 1
self._modified_bounds = modified_bounds
test_problem_class: type[BaseTestProblem | ParamBasedTestProblem]
test_problem_kwargs: dict[str, Any] = field(default_factory=dict)
modified_bounds: list[tuple[float, float]] | None = None
test_problem: BaseTestProblem | ParamBasedTestProblem = field(init=False)

@equality_typechecker
def __eq__(self, other: Base) -> bool:
if not isinstance(other, type(self)):
return False
@property
def _is_moo(self) -> bool:
return self.test_problem.num_objectives > 1

return (
self.test_problem.__class__.__name__
== other.test_problem.__class__.__name__
)
@property
def _is_constrained(self) -> bool:
return issubclass(self.test_problem_class, ConstrainedBaseTestProblem)

def get_noise_stds(self) -> None | float | dict[str, float]:
noise_std = self.test_problem.noise_std
Expand Down Expand Up @@ -163,6 +131,7 @@ def get_noise_stds(self) -> None | float | dict[str, float]:
return noise_std_dict


@dataclass(kw_only=True)
class BotorchTestProblemRunner(SyntheticProblemRunner):
"""
A `SyntheticProblemRunner` for BoTorch `BaseTestProblem`s.
Expand All @@ -184,25 +153,14 @@ class BotorchTestProblemRunner(SyntheticProblemRunner):
search_space_digest: Used to extract target fidelity and task.
"""

def __init__(
self,
*,
test_problem_class: type[BaseTestProblem],
test_problem_kwargs: dict[str, Any],
outcome_names: list[str],
modified_bounds: list[tuple[float, float]] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> None:
super().__init__(
test_problem_class=test_problem_class,
test_problem_kwargs=test_problem_kwargs,
outcome_names=outcome_names,
modified_bounds=modified_bounds,
search_space_digest=search_space_digest,
)
self.test_problem: BaseTestProblem = self.test_problem.to(dtype=torch.double)
self._is_constrained: bool = isinstance(
self.test_problem, ConstrainedBaseTestProblem
test_problem_class: type[BaseTestProblem]
test_problem: BaseTestProblem = field(init=False)

def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
super().__post_init__(search_space_digest=search_space_digest)
# pyre-fixme[45]: Can't instantiate abstract class `BaseTestProblem`.
self.test_problem = self.test_problem_class(**self.test_problem_kwargs).to(
torch.double
)

def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
Expand All @@ -226,10 +184,10 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
dtype=torch.double,
)

if self._modified_bounds is not None:
if self.modified_bounds is not None:
# Normalize from modified bounds to unit cube.
unit_X = normalize(
X, torch.tensor(self._modified_bounds, dtype=torch.double).T
X, torch.tensor(self.modified_bounds, dtype=torch.double).T
)
# Unnormalize from unit cube to original problem bounds.
X = unnormalize(unit_X, self.test_problem.bounds)
Expand All @@ -249,37 +207,42 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:

return Y_true

@equality_typechecker
def __eq__(self, other: Base) -> bool:
"""
Compare equality by comparing dicts, except for `test_problem`.
Dataclasses are compared by comparing the results of calling asdict on
them. However, equality checks don't work as needed with BoTorch test
problems, e.g. Branin() == Branin() is False. To get around that, the
test problem is stripped from the dictionary. This doesn't make the
check less sensitive, as long as the problem has not been modified,
because the test problem class and keyword arguments will still be
compared.
"""
if not isinstance(other, type(self)):
return False
self_as_dict = asdict(self)
other_as_dict = asdict(other)
self_as_dict.pop("test_problem")
other_as_dict.pop("test_problem")
return self_as_dict == other_as_dict


@dataclass(kw_only=True)
class ParamBasedTestProblemRunner(SyntheticProblemRunner):
"""
A `SyntheticProblemRunner` for `ParamBasedTestProblem`s. See
`SyntheticProblemRunner` for more information.
"""

# This could easily be supported, but hasn't been hooked up
_is_constrained: bool = False
test_problem_class: type[ParamBasedTestProblem]
test_problem: ParamBasedTestProblem = field(init=False)

def __init__(
self,
*,
test_problem_class: type[ParamBasedTestProblem],
test_problem_kwargs: dict[str, Any],
outcome_names: list[str],
modified_bounds: list[tuple[float, float]] | None = None,
search_space_digest: SearchSpaceDigest | None = None,
) -> None:
if modified_bounds is not None:
raise NotImplementedError(
f"modified_bounds is not supported for {test_problem_class.__name__}"
)
super().__init__(
test_problem_class=test_problem_class,
test_problem_kwargs=test_problem_kwargs,
outcome_names=outcome_names,
modified_bounds=modified_bounds,
search_space_digest=search_space_digest,
)
self.test_problem: ParamBasedTestProblem = self.test_problem
def __post_init__(self, search_space_digest: SearchSpaceDigest | None) -> None:
super().__post_init__(search_space_digest=search_space_digest)
# pyre-fixme[45]: Can't instantiate abstract class `ParamBasedTestProblem`.
self.test_problem = self.test_problem_class(**self.test_problem_kwargs)

def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""Evaluates the test problem.
Expand Down
Loading

0 comments on commit 849a7ec

Please sign in to comment.