Skip to content

Commit

Permalink
Fix np.ndarray type annotation
Browse files Browse the repository at this point in the history
Differential Revision: D65146743
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Oct 29, 2024
1 parent fed5c55 commit c8a0f38
Show file tree
Hide file tree
Showing 71 changed files with 323 additions and 594 deletions.
7 changes: 3 additions & 4 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from time import monotonic, time

import numpy as np
import numpy.typing as npt

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import BenchmarkProblem
Expand All @@ -41,12 +42,10 @@


def compute_score_trace(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_trace: np.ndarray,
optimization_trace: npt.NDArray,
num_baseline_trials: int,
problem: BenchmarkProblem,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
) -> npt.NDArray:
"""Computes a score trace from the optimization trace."""

# Use the first GenerationStep's best found point as baseline. Sometimes (ex. in
Expand Down
16 changes: 6 additions & 10 deletions ax/benchmark/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt
from ax.core.experiment import Experiment
from ax.utils.common.base import Base
from numpy import nanmean, nanquantile, ndarray
Expand Down Expand Up @@ -78,14 +79,10 @@ class BenchmarkResult(Base):
name: str
seed: int

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
oracle_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
inference_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
optimization_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
score_trace: ndarray
oracle_trace: npt.NDArray
inference_trace: npt.NDArray
optimization_trace: npt.NDArray
score_trace: npt.NDArray

fit_time: float
gen_time: float
Expand Down Expand Up @@ -160,8 +157,7 @@ def from_benchmark_results(


def _get_stats(
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
step_data: Iterable[np.ndarray],
step_data: Iterable[npt.NDArray],
percentiles: list[float],
) -> dict[str, list[float]]:
quantiles = []
Expand Down
1 change: 0 additions & 1 deletion ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def train_and_evaluate(
total += labels.size(0)
correct += (predicted == labels).sum().item()

# pyre-fixme[7]: Expected `Tensor` but got `float`.
return correct / total


Expand Down
5 changes: 3 additions & 2 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from math import sqrt
from typing import Any

import numpy.typing as npt

import torch
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
Expand Down Expand Up @@ -68,8 +70,7 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
...

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray:
"""
Evaluate oracle metric values at a parameterization. In the base class,
oracle values are underlying noiseless function values evaluated at the
Expand Down
2 changes: 0 additions & 2 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,6 @@ def test_synthetic_runner(self) -> None:
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"branin": [4.2]}, None),
Expand Down
1 change: 0 additions & 1 deletion ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,6 @@ def normalized_arm_weights(
weights = np.array(self.weights)
if trunc_digits is not None:
atomic_weight = 10**-trunc_digits
# pyre-fixme[16]: `float` has no attribute `astype`.
int_weights = (
(total / atomic_weight) * (weights / np.sum(weights))
).astype(int)
Expand Down
1 change: 0 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
}


# pyre-fixme[13]: Attribute `_search_space` is never initialized.
class Experiment(Base):
"""Base class for defining an experiment."""

Expand Down
8 changes: 3 additions & 5 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ax.core.experiment as experiment
import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus
Expand Down Expand Up @@ -187,13 +188,10 @@ class ObservationData(Base):
"""

def __init__(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
self,
metric_names: list[str],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
means: np.ndarray,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
covariance: np.ndarray,
means: npt.NDArray,
covariance: npt.NDArray,
) -> None:
k = len(metric_names)
if means.shape != (k,):
Expand Down
1 change: 0 additions & 1 deletion ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def dependents(self) -> dict[TParamValue, list[str]]:

# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
def clone(self) -> Parameter:
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
pass

@property
Expand Down
7 changes: 3 additions & 4 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Sequence

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax import core
from ax.core.arm import Arm
Expand Down Expand Up @@ -1103,10 +1104,8 @@ class RobustSearchSpaceDigest:
Only relevant if paired with a `distribution_sampler`.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_param_perturbations: Callable[[], np.ndarray] | None = None
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_environmental: Callable[[], np.ndarray] | None = None
sample_param_perturbations: Callable[[], npt.NDArray] | None = None
sample_environmental: Callable[[], npt.NDArray] | None = None
environmental_variables: list[str] = field(default_factory=list)
multiplicative: bool = False

Expand Down
8 changes: 3 additions & 5 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import NamedTuple

import numpy as np
import numpy.typing as npt
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
Expand Down Expand Up @@ -129,12 +130,9 @@ def _get_missing_arm_trial_pairs(data: Data, metric_name: str) -> set[TArmTrial]


def best_feasible_objective(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_config: OptimizationConfig,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
values: dict[str, np.ndarray],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
values: dict[str, npt.NDArray],
) -> npt.NDArray:
"""Compute the best feasible objective value found by each iteration.
Args:
Expand Down
10 changes: 4 additions & 6 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from logging import Logger

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
Expand Down Expand Up @@ -54,12 +55,9 @@ class EarlyStoppingTrainingData:
which data come from the same arm.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
X: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Y: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Yvar: np.ndarray
X: npt.NDArray
Y: npt.NDArray
Yvar: npt.NDArray
arm_names: list[str | None]


Expand Down
10 changes: 4 additions & 6 deletions ax/metrics/branin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,24 @@
# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import aug_branin, branin


class BraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
x1, x2 = x
return checked_cast(float, branin(x1=x1, x2=x2))


class NegativeBraninMetric(BraninMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
fpos = super().f(x)
return -fpos


class AugmentedBraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, aug_branin(x))
7 changes: 3 additions & 4 deletions ax/metrics/branin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -117,8 +118,7 @@ def fetch_trial_data(

# pyre-fixme[14]: `f` overrides method defined in `NoisyFunctionMapMetric`
# inconsistently.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray, timestamp: int) -> Mapping[str, Any]:
def f(self, x: npt.NDArray, timestamp: int) -> Mapping[str, Any]:
x1, x2 = x

if self.rate is not None:
Expand Down Expand Up @@ -161,8 +161,7 @@ def fetch_trial_data(
**kwargs,
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
def f(self, x: npt.NDArray) -> Mapping[str, Any]:
if self.index < len(FIDELITY):
self.index += 1

Expand Down
7 changes: 3 additions & 4 deletions ax/metrics/hartmann6.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6


class Hartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, hartmann6(x))


class AugmentedHartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, aug_hartmann6(x))
4 changes: 2 additions & 2 deletions ax/metrics/l2norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric


class L2NormMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return np.sqrt((x**2).sum())
4 changes: 2 additions & 2 deletions ax/metrics/noisy_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
Expand Down Expand Up @@ -104,8 +105,7 @@ def _evaluate(self, params: TParameterization) -> float:
x = np.array([params[p] for p in self.param_names])
return self.f(x)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions ax/metrics/noisy_function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -112,7 +113,6 @@ def fetch_trial_data(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
def f(self, x: npt.NDArray) -> Mapping[str, Any]:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
4 changes: 2 additions & 2 deletions ax/metrics/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
Expand Down Expand Up @@ -46,8 +47,7 @@ class SklearnDataset(Enum):

@lru_cache(maxsize=8)
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def _get_data(dataset) -> dict[str, np.ndarray]:
def _get_data(dataset) -> dict[str, npt.NDArray]:
"""Return sklearn dataset, loading and caching if necessary."""
if dataset is SklearnDataset.DIGITS:
return datasets.load_digits()
Expand Down
13 changes: 5 additions & 8 deletions ax/modelbridge/best_model_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Union

import numpy as np
import numpy.typing as npt
from ax.exceptions.core import UserInputError
from ax.modelbridge.model_spec import ModelSpec
from ax.utils.common.base import Base
Expand Down Expand Up @@ -45,17 +46,13 @@ class ReductionCriterion(Enum):

# NOTE: Callables need to be wrapped in `partial` to be registered as members.
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean)
MEAN: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.mean)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min)
MIN: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.min)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max)
MAX: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.max)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
def __call__(self, array_like: ARRAYLIKE) -> npt.NDArray:
return self.value(array_like)


Expand Down
Loading

0 comments on commit c8a0f38

Please sign in to comment.