Skip to content

Commit

Permalink
Fix np.ndarray type annotation (#2983)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #2983

Reviewed By: paschai

Differential Revision: D65146743
  • Loading branch information
mpolson64 authored and facebook-github-bot committed Oct 29, 2024
1 parent fed5c55 commit fee5632
Show file tree
Hide file tree
Showing 71 changed files with 324 additions and 608 deletions.
7 changes: 3 additions & 4 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from time import monotonic, time

import numpy as np
import numpy.typing as npt

from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.benchmark.benchmark_problem import BenchmarkProblem
Expand All @@ -41,12 +42,10 @@


def compute_score_trace(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_trace: np.ndarray,
optimization_trace: npt.NDArray,
num_baseline_trials: int,
problem: BenchmarkProblem,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
) -> npt.NDArray:
"""Computes a score trace from the optimization trace."""

# Use the first GenerationStep's best found point as baseline. Sometimes (ex. in
Expand Down
19 changes: 7 additions & 12 deletions ax/benchmark/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from collections.abc import Iterable
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt
from ax.core.experiment import Experiment
from ax.utils.common.base import Base
from numpy import nanmean, nanquantile, ndarray
from numpy import nanmean, nanquantile
from pandas import DataFrame
from scipy.stats import sem

Expand Down Expand Up @@ -78,14 +78,10 @@ class BenchmarkResult(Base):
name: str
seed: int

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
oracle_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
inference_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
optimization_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
score_trace: ndarray
oracle_trace: npt.NDArray
inference_trace: npt.NDArray
optimization_trace: npt.NDArray
score_trace: npt.NDArray

fit_time: float
gen_time: float
Expand Down Expand Up @@ -160,8 +156,7 @@ def from_benchmark_results(


def _get_stats(
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
step_data: Iterable[np.ndarray],
step_data: Iterable[npt.NDArray],
percentiles: list[float],
) -> dict[str, list[float]]:
quantiles = []
Expand Down
1 change: 0 additions & 1 deletion ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def train_and_evaluate(
total += labels.size(0)
correct += (predicted == labels).sum().item()

# pyre-fixme[7]: Expected `Tensor` but got `float`.
return correct / total


Expand Down
6 changes: 3 additions & 3 deletions ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from math import sqrt
from typing import Any

import numpy.typing as npt

import torch
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
Expand All @@ -22,7 +24,6 @@
from ax.utils.common.serialization import TClassDecoderRegistry, TDecoderRegistry

from ax.utils.common.typeutils import checked_cast
from numpy import ndarray
from torch import Tensor


Expand Down Expand Up @@ -68,8 +69,7 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
...

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> npt.NDArray:
"""
Evaluate oracle metric values at a parameterization. In the base class,
oracle values are underlying noiseless function values evaluated at the
Expand Down
2 changes: 0 additions & 2 deletions ax/benchmark/tests/runners/test_botorch_test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,6 @@ def test_synthetic_runner(self) -> None:
nullcontext()
if not isinstance(test_problem, SurrogateTestFunction)
else patch.object(
# pyre-fixme: ParamBasedTestProblem` has no attribute
# `_surrogate`.
runner.test_problem._surrogate,
"predict",
return_value=({"branin": [4.2]}, None),
Expand Down
1 change: 0 additions & 1 deletion ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,6 @@ def normalized_arm_weights(
weights = np.array(self.weights)
if trunc_digits is not None:
atomic_weight = 10**-trunc_digits
# pyre-fixme[16]: `float` has no attribute `astype`.
int_weights = (
(total / atomic_weight) * (weights / np.sum(weights))
).astype(int)
Expand Down
1 change: 0 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
}


# pyre-fixme[13]: Attribute `_search_space` is never initialized.
class Experiment(Base):
"""Base class for defining an experiment."""

Expand Down
8 changes: 3 additions & 5 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import ax.core.experiment as experiment
import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import NON_ABANDONED_STATUSES, TrialStatus
Expand Down Expand Up @@ -187,13 +188,10 @@ class ObservationData(Base):
"""

def __init__(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
self,
metric_names: list[str],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
means: np.ndarray,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
covariance: np.ndarray,
means: npt.NDArray,
covariance: npt.NDArray,
) -> None:
k = len(metric_names)
if means.shape != (k,):
Expand Down
1 change: 0 additions & 1 deletion ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ def dependents(self) -> dict[TParamValue, list[str]]:

# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
def clone(self) -> Parameter:
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
pass

@property
Expand Down
8 changes: 3 additions & 5 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from random import choice, uniform
from typing import Sequence

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax import core
from ax.core.arm import Arm
Expand Down Expand Up @@ -1103,10 +1103,8 @@ class RobustSearchSpaceDigest:
Only relevant if paired with a `distribution_sampler`.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_param_perturbations: Callable[[], np.ndarray] | None = None
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_environmental: Callable[[], np.ndarray] | None = None
sample_param_perturbations: Callable[[], npt.NDArray] | None = None
sample_environmental: Callable[[], npt.NDArray] | None = None
environmental_variables: list[str] = field(default_factory=list)
multiplicative: bool = False

Expand Down
8 changes: 3 additions & 5 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import NamedTuple

import numpy as np
import numpy.typing as npt
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import BatchTrial
Expand Down Expand Up @@ -129,12 +130,9 @@ def _get_missing_arm_trial_pairs(data: Data, metric_name: str) -> set[TArmTrial]


def best_feasible_objective(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_config: OptimizationConfig,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
values: dict[str, np.ndarray],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
values: dict[str, npt.NDArray],
) -> npt.NDArray:
"""Compute the best feasible objective value found by each iteration.
Args:
Expand Down
11 changes: 4 additions & 7 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from dataclasses import dataclass
from logging import Logger

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import TrialStatus
from ax.core.data import Data
Expand Down Expand Up @@ -54,12 +54,9 @@ class EarlyStoppingTrainingData:
which data come from the same arm.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
X: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Y: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Yvar: np.ndarray
X: npt.NDArray
Y: npt.NDArray
Yvar: npt.NDArray
arm_names: list[str | None]


Expand Down
11 changes: 4 additions & 7 deletions ax/metrics/branin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,24 @@

# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import aug_branin, branin


class BraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
x1, x2 = x
return checked_cast(float, branin(x1=x1, x2=x2))


class NegativeBraninMetric(BraninMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
fpos = super().f(x)
return -fpos


class AugmentedBraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, aug_branin(x))
7 changes: 3 additions & 4 deletions ax/metrics/branin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -117,8 +118,7 @@ def fetch_trial_data(

# pyre-fixme[14]: `f` overrides method defined in `NoisyFunctionMapMetric`
# inconsistently.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray, timestamp: int) -> Mapping[str, Any]:
def f(self, x: npt.NDArray, timestamp: int) -> Mapping[str, Any]:
x1, x2 = x

if self.rate is not None:
Expand Down Expand Up @@ -161,8 +161,7 @@ def fetch_trial_data(
**kwargs,
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
def f(self, x: npt.NDArray) -> Mapping[str, Any]:
if self.index < len(FIDELITY):
self.index += 1

Expand Down
8 changes: 3 additions & 5 deletions ax/metrics/hartmann6.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,17 @@

# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric
from ax.utils.common.typeutils import checked_cast
from ax.utils.measurement.synthetic_functions import aug_hartmann6, hartmann6


class Hartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, hartmann6(x))


class AugmentedHartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return checked_cast(float, aug_hartmann6(x))
4 changes: 2 additions & 2 deletions ax/metrics/l2norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# pyre-strict

import numpy as np
import numpy.typing as npt
from ax.metrics.noisy_function import NoisyFunctionMetric


class L2NormMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
return np.sqrt((x**2).sum())
4 changes: 2 additions & 2 deletions ax/metrics/noisy_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.data import Data
Expand Down Expand Up @@ -104,8 +105,7 @@ def _evaluate(self, params: TParameterization) -> float:
x = np.array([params[p] for p in self.param_names])
return self.f(x)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
def f(self, x: npt.NDArray) -> float:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError

Expand Down
4 changes: 2 additions & 2 deletions ax/metrics/noisy_function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.base_trial import BaseTrial
from ax.core.map_data import MapData, MapKeyInfo
Expand Down Expand Up @@ -112,7 +113,6 @@ def fetch_trial_data(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
def f(self, x: npt.NDArray) -> Mapping[str, Any]:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
5 changes: 2 additions & 3 deletions ax/metrics/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from math import sqrt
from typing import Any

import numpy as np
import numpy.typing as npt
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import BaseTrial
Expand Down Expand Up @@ -46,8 +46,7 @@ class SklearnDataset(Enum):

@lru_cache(maxsize=8)
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def _get_data(dataset) -> dict[str, np.ndarray]:
def _get_data(dataset) -> dict[str, npt.NDArray]:
"""Return sklearn dataset, loading and caching if necessary."""
if dataset is SklearnDataset.DIGITS:
return datasets.load_digits()
Expand Down
Loading

0 comments on commit fee5632

Please sign in to comment.