Skip to content

Commit

Permalink
Fix docstring in Ax SyntheticFunction._f, Pyre fix (#2329)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2329

- Fix docstring in Ax SyntheticFunction._f and address some Pyre issues. This was raised in #1545 , which is pretty clear
- Add some type annotations
- Add override for safety in refactoring

Reviewed By: Balandat

Differential Revision: D55821434

fbshipit-source-id: ebe7650181d5a6adc4b3e5b566e2f4861499548c
  • Loading branch information
esantorella authored and facebook-github-bot committed Apr 6, 2024
1 parent 60330a8 commit b7e1fe2
Showing 1 changed file with 34 additions and 32 deletions.
66 changes: 34 additions & 32 deletions ax/utils/measurement/synthetic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from ax.utils.common.docutils import copy_doc
from ax.utils.common.typeutils import checked_cast, not_none
from botorch.test_functions import synthetic as botorch_synthetic
from pyre_extensions import override


# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
def informative_failure_on_none(func: Callable) -> Any:
def informative_failure_on_none(func: Callable) -> Callable:
# pyre-fixme[3]: Return annotation cannot be `Any`.
# pyre-fixme[2]: Parameter must be annotated.
def function_wrapper(*args, **kwargs) -> Any:
Expand All @@ -34,8 +34,7 @@ def function_wrapper(*args, **kwargs) -> Any:

class SyntheticFunction(ABC):

# pyre-fixme[4]: Attribute must be annotated.
_required_dimensionality = None
_required_dimensionality: Optional[int] = None
# pyre-fixme[4]: Attribute must be annotated.
_domain = None
# pyre-fixme[4]: Attribute must be annotated.
Expand Down Expand Up @@ -165,17 +164,16 @@ def fmax(self) -> float:
"""Value at global minimum(s)."""
return self._fmax

@classmethod
@abstractmethod
def _f(self, X: np.ndarray) -> float:
"""Implementation of the synthetic function. Must be implemented in subclass.
Args:
X (numpy.ndarray): an n by d array, where n represents the number
of observations and d is the dimensionality of the inputs.
X: A one-dimensional array with `d` elements, where d is the
dimensionality of the inputs.
Returns:
numpy.ndarray: an n-dimensional array.
float: Function value.
"""
...

Expand All @@ -185,17 +183,18 @@ def __init__(
self, botorch_synthetic_function: botorch_synthetic.SyntheticTestFunction
) -> None:
self._botorch_function = botorch_synthetic_function
# pyre-fixme[4]: Attribute must be annotated.
self._required_dimensionality = self._botorch_function.dim
# pyre-fixme[4]: Attribute must be annotated.
self._domain = self._botorch_function._bounds
# pyre-fixme[4]: Attribute must be annotated.
self._fmin = self._botorch_function._optimal_value
self._required_dimensionality: int = self._botorch_function.dim
self._domain: Optional[List[Tuple[float, float]]] = (
self._botorch_function._bounds
)
self._fmin: float = self._botorch_function._optimal_value

@override
@property
def name(self) -> str:
return f"{self.__class__.__name__}_{self._botorch_function.__class__.__name__}"

@override
def _f(self, X: np.ndarray) -> float:
# TODO: support batch evaluation
return float(self._botorch_function(X=torch.from_numpy(X)).item())
Expand All @@ -212,25 +211,20 @@ class Hartmann6(SyntheticFunction):
"""Hartmann6 function (6-dimensional with 1 global minimum)."""

_required_dimensionality = 6
# pyre-fixme[4]: Attribute must be annotated.
_domain = [(0, 1) for i in range(6)]
_domain: List[Tuple[int, int]] = [(0, 1) for i in range(6)]
_minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573)]
# pyre-fixme[4]: Attribute must be annotated.
_fmin = -3.32237
_fmin: float = -3.32237
_fmax = 0.0
# pyre-fixme[4]: Attribute must be annotated.
_alpha = np.array([1.0, 1.2, 3.0, 3.2])
# pyre-fixme[4]: Attribute must be annotated.
_A = np.array(
_alpha: np.ndarray = np.array([1.0, 1.2, 3.0, 3.2])
_A: np.ndarray = np.array(
[
[10, 3, 17, 3.5, 1.7, 8],
[0.05, 10, 17, 0.1, 8, 14],
[3, 3.5, 1.7, 10, 17, 8],
[17, 8, 0.05, 10, 0.1, 14],
]
)
# pyre-fixme[4]: Attribute must be annotated.
_P = 10 ** (-4) * np.array(
_P: np.ndarray = 10 ** (-4) * np.array(
[
[1312, 1696, 5569, 124, 8283, 5886],
[2329, 4135, 8307, 3736, 1004, 9991],
Expand All @@ -239,6 +233,7 @@ class Hartmann6(SyntheticFunction):
]
)

@override
@copy_doc(SyntheticFunction._f)
def _f(self, X: np.ndarray) -> float:
y = 0.0
Expand All @@ -254,15 +249,14 @@ class Aug_Hartmann6(Hartmann6):
"""Augmented Hartmann6 function (7-dimensional with 1 global minimum)."""

_required_dimensionality = 7
# pyre-fixme[4]: Attribute must be annotated.
_domain = [(0, 1) for i in range(7)]
_domain: List[Tuple[int, int]] = [(0, 1) for i in range(7)]
# pyre-fixme[15]: `_minimums` overrides attribute defined in `Hartmann6`
# inconsistently.
_minimums = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573, 1.0)]
# pyre-fixme[4]: Attribute must be annotated.
_fmin = -3.32237
_fmin: float = -3.32237
_fmax = 0.0

@override
@copy_doc(SyntheticFunction._f)
def _f(self, X: np.ndarray) -> float:
y = 0.0
Expand All @@ -283,11 +277,15 @@ class Branin(SyntheticFunction):

_required_dimensionality = 2
_domain = [(-5, 10), (0, 15)]
# pyre-fixme[4]: Attribute must be annotated.
_minimums = [(-np.pi, 12.275), (np.pi, 2.275), (9.42478, 2.475)]
_minimums: List[Tuple[float, float]] = [
(-np.pi, 12.275),
(np.pi, 2.275),
(9.42478, 2.475),
]
_fmin = 0.397887
_fmax = 308.129

@override
@copy_doc(SyntheticFunction._f)
def _f(self, X: np.ndarray) -> float:
x_1 = X[0]
Expand All @@ -304,11 +302,15 @@ class Aug_Branin(SyntheticFunction):

_required_dimensionality = 3
_domain = [(-5, 10), (0, 15), (0, 1)]
# pyre-fixme[4]: Attribute must be annotated.
_minimums = [(-np.pi, 12.275, 1), (np.pi, 2.275, 1), (9.42478, 2.475, 1)]
_minimums: List[Tuple[float, float, int]] = [
(-np.pi, 12.275, 1),
(np.pi, 2.275, 1),
(9.42478, 2.475, 1),
]
_fmin = 0.397887
_fmax = 308.129

@override
@copy_doc(SyntheticFunction._f)
def _f(self, X: np.ndarray) -> float:
x_1 = X[0]
Expand Down

0 comments on commit b7e1fe2

Please sign in to comment.