Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing pyre mode header #2907

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ax/analysis/old/helpers/cross_validation_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def get_plotting_limit_ignore_outliers(

x_np = np.array(x)
# TODO: replace interpolation->method once it becomes standard.
# pyre-fixme[28]: Unexpected keyword argument `interpolation`.
q1 = np.nanpercentile(x_np, q=25, interpolation="lower").min()
# pyre-fixme[28]: Unexpected keyword argument `interpolation`.
q3 = np.nanpercentile(x_np, q=75, interpolation="higher").max()
quartile_difference = q3 - q1

Expand Down
4 changes: 4 additions & 0 deletions ax/analysis/plotly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import numpy as np
import torch
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -90,6 +92,8 @@ def get_constraint_violated_probabilities(
list(feasibility_probabilities.values()), axis=0
)

# pyre-fixme[7]: Expected `Dict[str, List[float]]` but got `Dict[str,
# ndarray[typing.Any, dtype[typing.Any]]]`.
return {
metric_name: 1 - feasibility_probabilities[metric_name]
for metric_name in feasibility_probabilities
Expand Down
2 changes: 2 additions & 0 deletions ax/benchmark/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@


def compute_score_trace(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_trace: np.ndarray,
num_baseline_trials: int,
problem: BenchmarkProblem,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
"""Computes a score trace from the optimization trace."""

Expand Down
5 changes: 5 additions & 0 deletions ax/benchmark/benchmark_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ class BenchmarkResult(Base):
name: str
seed: int

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
oracle_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
inference_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
optimization_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
score_trace: ndarray

fit_time: float
Expand Down Expand Up @@ -156,6 +160,7 @@ def from_benchmark_results(


def _get_stats(
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
step_data: Iterable[np.ndarray],
percentiles: list[float],
) -> dict[str, list[float]]:
Expand Down
5 changes: 5 additions & 0 deletions ax/benchmark/problems/hpo/torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def train_and_evaluate(
total += labels.size(0)
correct += (predicted == labels).sum()

# pyre-fixme[7]: Expected `Tensor` but got `float`.
return correct / total


Expand Down Expand Up @@ -165,7 +166,11 @@ def evaluate_true(self, params: Mapping[str, int | float]) -> Tensor:
return train_and_evaluate(
**params,
device=self.device,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# attribute `train_loader`.
train_loader=self.train_loader,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# attribute `test_loader`.
test_loader=self.test_loader,
)

Expand Down
5 changes: 4 additions & 1 deletion ax/benchmark/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class BenchmarkRunner(Runner, ABC):
"""

outcome_names: list[str]
# pyre-fixme[8]: Pyre doesn't understand InitVars
# pyre-fixme[16]: Pyre doesn't understand InitVars
search_space_digest: InitVar[SearchSpaceDigest | None] = None
target_fidelity_and_task: Mapping[str, float | int] = field(init=False)

Expand All @@ -68,6 +68,7 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
...

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
"""
Evaluate oracle metric values at a parameterization. In the base class,
Expand Down Expand Up @@ -145,6 +146,8 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
# budget allocation to each arm. This works b/c (i) we assume that
# observations per unit sample budget are i.i.d. and (ii) the
# normalized weights sum to one.
# pyre-fixme[61]: `nlzd_arm_weights` is undefined, or not always
# defined.
std = noise_stds_tsr.to(Y_true) / sqrt(nlzd_arm_weights[arm])
Ystds[arm.name] = std.tolist()
Ys[arm.name] = (Y_true + std * torch.randn_like(Y_true)).tolist()
Expand Down
1 change: 1 addition & 0 deletions ax/benchmark/runners/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,6 @@ def __eq__(self, other: Base) -> bool:
(self.name == other.name)
and (self.outcome_names == other.outcome_names)
and (self.noise_stds == other.noise_stds)
# pyre-fixme[16]: `SurrogateRunner` has no attribute `search_space_digest`.
and (self.search_space_digest == other.search_space_digest)
)
2 changes: 2 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,10 @@ def __init__(
for different purposes (e.g., transfer learning).
"""
# appease pyre
# pyre-fixme[13]: Attribute `_search_space` is never initialized.
self._search_space: SearchSpace
self._status_quo: Arm | None = None
# pyre-fixme[13]: Attribute `_is_test` is never initialized.
self._is_test: bool

self._name = name
Expand Down
8 changes: 7 additions & 1 deletion ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,13 @@ class ObservationData(Base):
"""

def __init__(
self, metric_names: list[str], means: np.ndarray, covariance: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
self,
metric_names: list[str],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
means: np.ndarray,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
covariance: np.ndarray,
) -> None:
k = len(metric_names)
if means.shape != (k,):
Expand Down
12 changes: 12 additions & 0 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,13 @@


class ParameterType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
BOOL: int = 0
# pyre-fixme[35]: Target cannot be annotated.
INT: int = 1
# pyre-fixme[35]: Target cannot be annotated.
FLOAT: int = 2
# pyre-fixme[35]: Target cannot be annotated.
STRING: int = 3

@property
Expand Down Expand Up @@ -143,6 +147,7 @@ def dependents(self) -> dict[TParamValue, list[str]]:
"Only choice hierarchical parameters are currently supported."
)

# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
def clone(self) -> Parameter:
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
pass
Expand Down Expand Up @@ -214,10 +219,17 @@ def summary_dict(
if flags:
summary_dict["flags"] = ", ".join(flags)
if getattr(self, "is_fidelity", False) or getattr(self, "is_task", False):
# pyre-fixme[6]: For 2nd argument expected `str` but got `Union[None,
# bool, float, int, str]`.
summary_dict["target_value"] = self.target_value
if getattr(self, "is_hierarchical", False):
# pyre-fixme[6]: For 2nd argument expected `str` but got
# `Dict[Union[None, bool, float, int, str], List[str]]`.
summary_dict["dependents"] = self.dependents

# pyre-fixme[7]: Expected `Dict[str, Union[None, List[Union[None, bool,
# float, int, str]], List[str], bool, float, int, str]]` but got `Dict[str,
# str]`.
return summary_dict


Expand Down
2 changes: 2 additions & 0 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,9 @@ class RobustSearchSpaceDigest:
Only relevant if paired with a `distribution_sampler`.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_param_perturbations: Callable[[], np.ndarray] | None = None
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_environmental: Callable[[], np.ndarray] | None = None
environmental_variables: list[str] = field(default_factory=list)
multiplicative: bool = False
Expand Down
6 changes: 6 additions & 0 deletions ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,9 @@ def test_ObservationsFromDataWithFidelities(self) -> None:
self.assertEqual(obs.features.parameters, t["updated_parameters"])
self.assertEqual(obs.features.trial_index, t["trial_index"])
self.assertEqual(obs.data.metric_names, [t["metric_name"]])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, t["mean_t"]))
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"]))
self.assertEqual(obs.arm_name, t["arm_name"])

Expand Down Expand Up @@ -484,7 +486,9 @@ def test_ObservationsFromMapData(self) -> None:
self.assertEqual(obs.features.parameters, t["updated_parameters"])
self.assertEqual(obs.features.trial_index, t["trial_index"])
self.assertEqual(obs.data.metric_names, [t["metric_name"]])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, t["mean_t"]))
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"]))
self.assertEqual(obs.arm_name, t["arm_name"])
self.assertEqual(obs.features.metadata, {"timestamp": t["timestamp"]})
Expand Down Expand Up @@ -828,8 +832,10 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
0,
)
self.assertEqual(obs.data.metric_names, obs_truth["metric_names"][i])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i]))
self.assertTrue(
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtyp...
np.array_equal(obs.data.covariance, obs_truth["covariance"][i])
)
self.assertEqual(obs.arm_name, obs_truth["arm_name"][i])
Expand Down
5 changes: 5 additions & 0 deletions ax/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TParamValueList = list[TParamValue] # a parameterization without the keys
TContextStratum = Optional[dict[str, Union[str, float, int]]]

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
TBounds = Optional[tuple[np.ndarray, np.ndarray]]
TModelMean = dict[str, list[float]]
TModelCov = dict[str, dict[str, list[float]]]
Expand All @@ -29,6 +30,8 @@
# ( { metric -> mean }, { metric -> { other_metric -> covariance } } ).
TModelPredictArm = tuple[dict[str, float], Optional[dict[str, dict[str, float]]]]

# pyre-fixme[24]: Generic type `np.floating` expects 1 type parameter.
# pyre-fixme[24]: Generic type `np.integer` expects 1 type parameter.
FloatLike = Union[int, float, np.floating, np.integer]
SingleMetricDataTuple = tuple[FloatLike, Optional[FloatLike]]
SingleMetricData = Union[FloatLike, tuple[FloatLike, Optional[FloatLike]]]
Expand Down Expand Up @@ -70,7 +73,9 @@
class ComparisonOp(enum.Enum):
"""Class for enumerating comparison operations."""

# pyre-fixme[35]: Target cannot be annotated.
GEQ: int = 0
# pyre-fixme[35]: Target cannot be annotated.
LEQ: int = 1


Expand Down
6 changes: 5 additions & 1 deletion ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ def _get_missing_arm_trial_pairs(data: Data, metric_name: str) -> set[TArmTrial]


def best_feasible_objective(
optimization_config: OptimizationConfig, values: dict[str, np.ndarray]
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_config: OptimizationConfig,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
values: dict[str, np.ndarray],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
"""Compute the best feasible objective value found by each iteration.

Expand Down
3 changes: 3 additions & 0 deletions ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@ class EarlyStoppingTrainingData:
which data come from the same arm.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
X: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Y: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Yvar: np.ndarray
arm_names: list[str | None]

Expand Down
3 changes: 3 additions & 0 deletions ax/metrics/branin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,20 @@


class BraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
x1, x2 = x
return checked_cast(float, branin(x1=x1, x2=x2))


class NegativeBraninMetric(BraninMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
fpos = super().f(x)
return -fpos


class AugmentedBraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, aug_branin(x))
2 changes: 2 additions & 0 deletions ax/metrics/branin_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def fetch_trial_data(

# pyre-fixme[14]: `f` overrides method defined in `NoisyFunctionMapMetric`
# inconsistently.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray, timestamp: int) -> Mapping[str, Any]:
x1, x2 = x

Expand Down Expand Up @@ -160,6 +161,7 @@ def fetch_trial_data(
**kwargs,
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
if self.index < len(FIDELITY):
self.index += 1
Expand Down
2 changes: 2 additions & 0 deletions ax/metrics/chemistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@


class ChemistryProblemType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
SUZUKI: str = "suzuki"
# pyre-fixme[35]: Target cannot be annotated.
DIRECT_ARYLATION: str = "direct_arylation"


Expand Down
2 changes: 2 additions & 0 deletions ax/metrics/hartmann6.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@


class Hartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, hartmann6(x))


class AugmentedHartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, aug_hartmann6(x))
1 change: 1 addition & 0 deletions ax/metrics/l2norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@


class L2NormMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return np.sqrt((x**2).sum())
1 change: 1 addition & 0 deletions ax/metrics/noisy_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def _evaluate(self, params: TParameterization) -> float:
x = np.array([params[p] for p in self.param_names])
return self.f(x)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions ax/metrics/noisy_function_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def fetch_trial_data(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
6 changes: 6 additions & 0 deletions ax/metrics/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,24 @@


class SklearnModelType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
RF: str = "rf"
# pyre-fixme[35]: Target cannot be annotated.
NN: str = "nn"


class SklearnDataset(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DIGITS: str = "digits"
# pyre-fixme[35]: Target cannot be annotated.
BOSTON: str = "boston"
# pyre-fixme[35]: Target cannot be annotated.
CANCER: str = "cancer"


@lru_cache(maxsize=8)
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def _get_data(dataset) -> dict[str, np.ndarray]:
"""Return sklearn dataset, loading and caching if necessary."""
if dataset is SklearnDataset.DIGITS:
Expand Down
1 change: 1 addition & 0 deletions ax/metrics/tests/test_chemistry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


class DummyEnum(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DUMMY: str = "dummy"


Expand Down
1 change: 1 addition & 0 deletions ax/metrics/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


class DummyEnum(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DUMMY: str = "dummy"


Expand Down
Loading