pyre upgrade
Differential Revision: D64542424
mpolson64 authored and facebook-github-bot committed Oct 17, 2024
1 parent 600b4b1 commit f02a90a
Showing 85 changed files with 593 additions and 38 deletions.
2 changes: 2 additions & 0 deletions ax/analysis/old/helpers/cross_validation_helpers.py
@@ -141,7 +141,9 @@ def get_plotting_limit_ignore_outliers(

x_np = np.array(x)
# TODO: replace interpolation->method once it becomes standard.
# pyre-fixme[28]: Unexpected keyword argument `interpolation`.
q1 = np.nanpercentile(x_np, q=25, interpolation="lower").min()
# pyre-fixme[28]: Unexpected keyword argument `interpolation`.
q3 = np.nanpercentile(x_np, q=75, interpolation="higher").max()
quartile_difference = q3 - q1

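The TODO in this hunk refers to NumPy renaming the `interpolation` keyword of `percentile`/`nanpercentile` to `method` (NumPy 1.22+). A minimal sketch of what the call sites could look like once the old keyword is dropped, and the pyre-fixme[28] suppressions with it (illustrative data; not part of this commit):

```python
import numpy as np

x_np = np.array([1.0, 2.0, 3.0, 4.0, 100.0])  # example data with an outlier

# NumPy >= 1.22 spells the selection rule via `method=`, so pyre no longer
# flags an unexpected keyword argument.
q1 = np.nanpercentile(x_np, q=25, method="lower").min()
q3 = np.nanpercentile(x_np, q=75, method="higher").max()
quartile_difference = q3 - q1
```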
2 changes: 2 additions & 0 deletions ax/analysis/plotly/utils.py
@@ -90,6 +90,8 @@ def get_constraint_violated_probabilities(
list(feasibility_probabilities.values()), axis=0
)

# pyre-fixme[7]: Expected `Dict[str, List[float]]` but got `Dict[str,
# ndarray[typing.Any, dtype[typing.Any]]]`.
return {
metric_name: 1 - feasibility_probabilities[metric_name]
for metric_name in feasibility_probabilities
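The pyre-fixme[7] above suppresses a mismatch between the declared return type `Dict[str, List[float]]` and the NumPy arrays actually stored as values. An alternative to suppression would be converting the arrays to plain lists at the boundary; a hedged sketch using a hypothetical helper (the real function and its inputs are not reproduced here):

```python
import numpy as np

def constraint_violated_probabilities(
    feasibility_probabilities: dict[str, np.ndarray],
) -> dict[str, list[float]]:
    """Hypothetical helper: return 1 - P(feasible) per metric as float lists,
    so the annotated return type holds without a pyre-fixme."""
    return {
        metric_name: (1.0 - probs).tolist()
        for metric_name, probs in feasibility_probabilities.items()
    }

# Example usage with made-up probabilities:
print(constraint_violated_probabilities({"latency": np.array([0.9, 0.7, 0.4])}))
```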
2 changes: 2 additions & 0 deletions ax/benchmark/benchmark.py
@@ -41,9 +41,11 @@


def compute_score_trace(
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_trace: np.ndarray,
num_baseline_trials: int,
problem: BenchmarkProblem,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
"""Computes a score trace from the optimization trace."""

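pyre-fixme[24] recurs throughout this commit because `np.ndarray` is generic over shape and dtype. Instead of suppressing, annotations can be fully parameterized via `numpy.typing`; a sketch of the pattern (hypothetical function, not the actual `compute_score_trace` body):

```python
import numpy as np
import numpy.typing as npt

# `npt.NDArray[np.float64]` expands to `np.ndarray[Any, np.dtype[np.float64]]`,
# supplying both type parameters the checker asks for.
FloatArray = npt.NDArray[np.float64]

def normalize_trace(trace: FloatArray) -> FloatArray:
    """Hypothetical example: rescale a trace to [0, 1]; the point is the
    parameterized ndarray annotation, not the arithmetic."""
    lo, hi = float(trace.min()), float(trace.max())
    if hi == lo:
        return np.zeros_like(trace)
    return (trace - lo) / (hi - lo)
```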
5 changes: 5 additions & 0 deletions ax/benchmark/benchmark_result.py
@@ -78,9 +78,13 @@ class BenchmarkResult(Base):
name: str
seed: int

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
oracle_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
inference_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
optimization_trace: ndarray
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
score_trace: ndarray

fit_time: float
@@ -156,6 +160,7 @@ def from_benchmark_results(


def _get_stats(
# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
step_data: Iterable[np.ndarray],
percentiles: list[float],
) -> dict[str, list[float]]:
5 changes: 5 additions & 0 deletions ax/benchmark/problems/hpo/torchvision.py
@@ -108,6 +108,7 @@ def train_and_evaluate(
total += labels.size(0)
correct += (predicted == labels).sum()

# pyre-fixme[7]: Expected `Tensor` but got `float`.
return correct / total


@@ -165,7 +166,11 @@ def evaluate_true(self, params: Mapping[str, int | float]) -> Tensor:
return train_and_evaluate(
**params,
device=self.device,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# attribute `train_loader`.
train_loader=self.train_loader,
# pyre-fixme[16]: `PyTorchCNNTorchvisionParamBasedProblem` has no
# attribute `test_loader`.
test_loader=self.test_loader,
)

3 changes: 3 additions & 0 deletions ax/benchmark/runners/base.py
@@ -72,6 +72,7 @@ def get_Y_true(self, params: Mapping[str, TParamValue]) -> Tensor:
"""
...

# pyre-fixme[24]: Generic type `ndarray` expects 2 type parameters.
def evaluate_oracle(self, parameters: Mapping[str, TParamValue]) -> ndarray:
"""
Evaluate oracle metric values at a parameterization. In the base class,
@@ -149,6 +150,8 @@ def run(self, trial: BaseTrial) -> dict[str, Any]:
# budget allocation to each arm. This works b/c (i) we assume that
# observations per unit sample budget are i.i.d. and (ii) the
# normalized weights sum to one.
# pyre-fixme[61]: `nlzd_arm_weights` is undefined, or not always
# defined.
std = noise_stds_tsr.to(Y_true) / sqrt(nlzd_arm_weights[arm])
Ystds[arm.name] = std.tolist()
Ys[arm.name] = (Y_true + std * torch.randn_like(Y_true)).tolist()
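The comment preceding the pyre-fixme[61] rests on a standard variance argument: with i.i.d. observations per unit of sample budget and normalized arm weights, an arm receiving fraction `w` of the budget sees the standard error of its mean shrink as `1/sqrt(w * budget)`. A self-contained numerical check of that scaling (purely illustrative; `nlzd_arm_weights` and the runner internals are not reproduced):

```python
import numpy as np

rng = np.random.default_rng(0)
noise_std = 2.0                           # noise std of a single observation
budget = 1_000                            # total observations in the trial
weights = {"arm_a": 0.25, "arm_b": 0.75}  # normalized arm weights (sum to 1)

for arm, w in weights.items():
    n = int(budget * w)
    # Empirical std of the arm's mean over many simulated trials vs. the
    # predicted noise_std / sqrt(budget * w).
    means = rng.normal(0.0, noise_std, size=(2_000, n)).mean(axis=1)
    print(arm, round(means.std(), 4), round(noise_std / np.sqrt(n), 4))
```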
2 changes: 2 additions & 0 deletions ax/core/experiment.py
@@ -107,8 +107,10 @@ def __init__(
for different purposes (e.g., transfer learning).
"""
# appease pyre
# pyre-fixme[13]: Attribute `_search_space` is never initialized.
self._search_space: SearchSpace
self._status_quo: Arm | None = None
# pyre-fixme[13]: Attribute `_is_test` is never initialized.
self._is_test: bool

self._name = name
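pyre-fixme[13] flags bare attribute declarations such as `self._search_space: SearchSpace` that are annotated in `__init__` but assigned later. One way to avoid the suppression, sketched with stand-in classes rather than the real `Experiment`, is to assign a value (possibly `None` behind an Optional type) on every path through `__init__`:

```python
from __future__ import annotations

class SearchSpace:
    """Stand-in for ax.core.search_space.SearchSpace (illustration only)."""

class ExperimentSketch:
    def __init__(self, search_space: SearchSpace | None = None) -> None:
        # Assigning here means the attribute is always initialized, so the
        # checker no longer reports pyre-fixme[13].
        self._search_space: SearchSpace | None = search_space
        self._is_test: bool = False
```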
8 changes: 7 additions & 1 deletion ax/core/observation.py
@@ -187,7 +187,13 @@ class ObservationData(Base):
"""

def __init__(
self, metric_names: list[str], means: np.ndarray, covariance: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
self,
metric_names: list[str],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
means: np.ndarray,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
covariance: np.ndarray,
) -> None:
k = len(metric_names)
if means.shape != (k,):
12 changes: 12 additions & 0 deletions ax/core/parameter.py
@@ -41,9 +41,13 @@


class ParameterType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
BOOL: int = 0
# pyre-fixme[35]: Target cannot be annotated.
INT: int = 1
# pyre-fixme[35]: Target cannot be annotated.
FLOAT: int = 2
# pyre-fixme[35]: Target cannot be annotated.
STRING: int = 3

@property
Expand Down Expand Up @@ -143,6 +147,7 @@ def dependents(self) -> dict[TParamValue, list[str]]:
"Only choice hierarchical parameters are currently supported."
)

# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
def clone(self) -> Parameter:
# pyre-fixme[7]: Expected `Parameter` but got implicit return value of `None`.
pass
Expand Down Expand Up @@ -214,10 +219,17 @@ def summary_dict(
if flags:
summary_dict["flags"] = ", ".join(flags)
if getattr(self, "is_fidelity", False) or getattr(self, "is_task", False):
# pyre-fixme[6]: For 2nd argument expected `str` but got `Union[None,
# bool, float, int, str]`.
summary_dict["target_value"] = self.target_value
if getattr(self, "is_hierarchical", False):
# pyre-fixme[6]: For 2nd argument expected `str` but got
# `Dict[Union[None, bool, float, int, str], List[str]]`.
summary_dict["dependents"] = self.dependents

# pyre-fixme[7]: Expected `Dict[str, Union[None, List[Union[None, bool,
# float, int, str]], List[str], bool, float, int, str]]` but got `Dict[str,
# str]`.
return summary_dict


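pyre-fixme[35] ("Target cannot be annotated") shows up on every annotated Enum member in this diff (`ParameterType`, `ComparisonOp`, `ChemistryProblemType`, `SklearnModelType`, ...). The non-suppression alternative is simply to drop the annotation, since Enum members are not ordinary class attributes; a small sketch with a hypothetical enum:

```python
from enum import Enum

class ParameterTypeSketch(Enum):
    # Unannotated members type-check cleanly; the value type can still be
    # verified explicitly if desired.
    BOOL = 0
    INT = 1
    FLOAT = 2
    STRING = 3

assert all(isinstance(member.value, int) for member in ParameterTypeSketch)
```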
2 changes: 2 additions & 0 deletions ax/core/search_space.py
@@ -1102,7 +1102,9 @@ class RobustSearchSpaceDigest:
Only relevant if paired with a `distribution_sampler`.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_param_perturbations: Callable[[], np.ndarray] | None = None
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
sample_environmental: Callable[[], np.ndarray] | None = None
environmental_variables: list[str] = field(default_factory=list)
multiplicative: bool = False
6 changes: 6 additions & 0 deletions ax/core/tests/test_observation.py
@@ -395,7 +395,9 @@ def test_ObservationsFromDataWithFidelities(self) -> None:
self.assertEqual(obs.features.parameters, t["updated_parameters"])
self.assertEqual(obs.features.trial_index, t["trial_index"])
self.assertEqual(obs.data.metric_names, [t["metric_name"]])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, t["mean_t"]))
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"]))
self.assertEqual(obs.arm_name, t["arm_name"])

@@ -484,7 +486,9 @@ def test_ObservationsFromMapData(self) -> None:
self.assertEqual(obs.features.parameters, t["updated_parameters"])
self.assertEqual(obs.features.trial_index, t["trial_index"])
self.assertEqual(obs.data.metric_names, [t["metric_name"]])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, t["mean_t"]))
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.covariance, t["covariance_t"]))
self.assertEqual(obs.arm_name, t["arm_name"])
self.assertEqual(obs.features.metadata, {"timestamp": t["timestamp"]})
@@ -828,8 +832,10 @@ def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
0,
)
self.assertEqual(obs.data.metric_names, obs_truth["metric_names"][i])
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtype[ty...
self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i]))
self.assertTrue(
# pyre-fixme[6]: For 2nd argument expected `Union[_SupportsArray[dtyp...
np.array_equal(obs.data.covariance, obs_truth["covariance"][i])
)
self.assertEqual(obs.arm_name, obs_truth["arm_name"][i])
5 changes: 5 additions & 0 deletions ax/core/types.py
@@ -21,6 +21,7 @@
TParamValueList = list[TParamValue] # a parameterization without the keys
TContextStratum = Optional[dict[str, Union[str, float, int]]]

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
TBounds = Optional[tuple[np.ndarray, np.ndarray]]
TModelMean = dict[str, list[float]]
TModelCov = dict[str, dict[str, list[float]]]
@@ -29,6 +30,8 @@
# ( { metric -> mean }, { metric -> { other_metric -> covariance } } ).
TModelPredictArm = tuple[dict[str, float], Optional[dict[str, dict[str, float]]]]

# pyre-fixme[24]: Generic type `np.floating` expects 1 type parameter.
# pyre-fixme[24]: Generic type `np.integer` expects 1 type parameter.
FloatLike = Union[int, float, np.floating, np.integer]
SingleMetricDataTuple = tuple[FloatLike, Optional[FloatLike]]
SingleMetricData = Union[FloatLike, tuple[FloatLike, Optional[FloatLike]]]
@@ -70,7 +73,9 @@
class ComparisonOp(enum.Enum):
"""Class for enumerating comparison operations."""

# pyre-fixme[35]: Target cannot be annotated.
GEQ: int = 0
# pyre-fixme[35]: Target cannot be annotated.
LEQ: int = 1


6 changes: 5 additions & 1 deletion ax/core/utils.py
@@ -129,7 +129,11 @@ def _get_missing_arm_trial_pairs(data: Data, metric_name: str) -> set[TArmTrial]


def best_feasible_objective(
optimization_config: OptimizationConfig, values: dict[str, np.ndarray]
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
optimization_config: OptimizationConfig,
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
values: dict[str, np.ndarray],
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
) -> np.ndarray:
"""Compute the best feasible objective value found by each iteration.
3 changes: 3 additions & 0 deletions ax/early_stopping/strategies/base.py
@@ -54,8 +54,11 @@ class EarlyStoppingTrainingData:
which data come from the same arm.
"""

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
X: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Y: np.ndarray
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
Yvar: np.ndarray
arm_names: list[str | None]

3 changes: 3 additions & 0 deletions ax/metrics/branin.py
@@ -13,17 +13,20 @@


class BraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
x1, x2 = x
return checked_cast(float, branin(x1=x1, x2=x2))


class NegativeBraninMetric(BraninMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
fpos = super().f(x)
return -fpos


class AugmentedBraninMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, aug_branin(x))
2 changes: 2 additions & 0 deletions ax/metrics/branin_map.py
@@ -117,6 +117,7 @@ def fetch_trial_data(

# pyre-fixme[14]: `f` overrides method defined in `NoisyFunctionMapMetric`
# inconsistently.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray, timestamp: int) -> Mapping[str, Any]:
x1, x2 = x

@@ -160,6 +161,7 @@ def fetch_trial_data(
**kwargs,
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
if self.index < len(FIDELITY):
self.index += 1
2 changes: 2 additions & 0 deletions ax/metrics/chemistry.py
@@ -49,7 +49,9 @@


class ChemistryProblemType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
SUZUKI: str = "suzuki"
# pyre-fixme[35]: Target cannot be annotated.
DIRECT_ARYLATION: str = "direct_arylation"


2 changes: 2 additions & 0 deletions ax/metrics/hartmann6.py
@@ -13,10 +13,12 @@


class Hartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, hartmann6(x))


class AugmentedHartmann6Metric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return checked_cast(float, aug_hartmann6(x))
1 change: 1 addition & 0 deletions ax/metrics/l2norm.py
@@ -11,5 +11,6 @@


class L2NormMetric(NoisyFunctionMetric):
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
return np.sqrt((x**2).sum())
1 change: 1 addition & 0 deletions ax/metrics/noisy_function.py
@@ -104,6 +104,7 @@ def _evaluate(self, params: TParameterization) -> float:
x = np.array([params[p] for p in self.param_names])
return self.f(x)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> float:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
1 change: 1 addition & 0 deletions ax/metrics/noisy_function_map.py
@@ -112,6 +112,7 @@ def fetch_trial_data(
MetricFetchE(message=f"Failed to fetch {self.name}", exception=e)
)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def f(self, x: np.ndarray) -> Mapping[str, Any]:
"""The deterministic function that produces the metric outcomes."""
raise NotImplementedError
6 changes: 6 additions & 0 deletions ax/metrics/sklearn.py
@@ -29,18 +29,24 @@


class SklearnModelType(Enum):
# pyre-fixme[35]: Target cannot be annotated.
RF: str = "rf"
# pyre-fixme[35]: Target cannot be annotated.
NN: str = "nn"


class SklearnDataset(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DIGITS: str = "digits"
# pyre-fixme[35]: Target cannot be annotated.
BOSTON: str = "boston"
# pyre-fixme[35]: Target cannot be annotated.
CANCER: str = "cancer"


@lru_cache(maxsize=8)
# pyre-fixme[2]: Parameter must be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def _get_data(dataset) -> dict[str, np.ndarray]:
"""Return sklearn dataset, loading and caching if necessary."""
if dataset is SklearnDataset.DIGITS:
1 change: 1 addition & 0 deletions ax/metrics/tests/test_chemistry.py
@@ -19,6 +19,7 @@


class DummyEnum(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DUMMY: str = "dummy"


1 change: 1 addition & 0 deletions ax/metrics/tests/test_sklearn.py
@@ -21,6 +21,7 @@


class DummyEnum(Enum):
# pyre-fixme[35]: Target cannot be annotated.
DUMMY: str = "dummy"


8 changes: 8 additions & 0 deletions ax/modelbridge/best_model_selector.py
@@ -20,6 +20,7 @@
from ax.utils.common.base import Base
from ax.utils.common.typeutils import not_none

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
ARRAYLIKE = Union[np.ndarray, list[float], list[np.ndarray]]


@@ -43,10 +44,17 @@ class ReductionCriterion(Enum):
"""

# NOTE: Callables need to be wrapped in `partial` to be registered as members.
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MEAN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.mean)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MIN: Callable[[ARRAYLIKE], np.ndarray] = partial(np.min)
# pyre-fixme[35]: Target cannot be annotated.
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
MAX: Callable[[ARRAYLIKE], np.ndarray] = partial(np.max)

# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
def __call__(self, array_like: ARRAYLIKE) -> np.ndarray:
return self.value(array_like)

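The NOTE about wrapping callables in `partial` reflects an Enum quirk: a plain function assigned in an Enum body is treated as a method (functions are descriptors), while a `functools.partial` object becomes a member value. A small self-contained illustration of the pattern used by `ReductionCriterion` (simplified and untyped, so none of the pyre-fixme annotations above apply):

```python
from enum import Enum
from functools import partial

import numpy as np

class Reduction(Enum):
    # Wrapped in `partial`, these register as Enum members; bare `np.mean`
    # and friends would instead be picked up as methods.
    MEAN = partial(np.mean)
    MIN = partial(np.min)
    MAX = partial(np.max)

    def __call__(self, array_like):
        return self.value(array_like)

print(list(Reduction))                   # [Reduction.MEAN, Reduction.MIN, Reduction.MAX]
print(Reduction.MEAN([1.0, 2.0, 3.0]))   # 2.0
```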
