From 9665120e4b50b4a594a57ee08fdefb17eb4ff720 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 20 Jul 2022 15:00:29 -0500 Subject: [PATCH] Enable type checking for NumPy types --- .pre-commit-config.yaml | 1 + aesara/graph/basic.py | 6 +++--- aesara/link/c/op.py | 24 +++++++++--------------- aesara/printing.py | 14 +++++++------- aesara/tensor/__init__.py | 21 ++++++++++++++------- aesara/tensor/random/type.py | 2 +- setup.cfg | 1 + 7 files changed, 36 insertions(+), 33 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0b747db183..8073e0f67b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,5 +51,6 @@ repos: hooks: - id: mypy additional_dependencies: + - numpy>=1.20 - types-filelock - types-setuptools diff --git a/aesara/graph/basic.py b/aesara/graph/basic.py index 923a557006..ae97742432 100644 --- a/aesara/graph/basic.py +++ b/aesara/graph/basic.py @@ -1671,14 +1671,14 @@ def equal_computations( for x, y in zip(xs, ys): if not isinstance(x, Variable) and not isinstance(y, Variable): - return cast(bool, np.array_equal(x, y)) + return np.array_equal(x, y) if not isinstance(x, Variable): if isinstance(y, Constant): - return cast(bool, np.array_equal(y.data, x)) + return np.array_equal(y.data, x) return False if not isinstance(y, Variable): if isinstance(x, Constant): - return cast(bool, np.array_equal(x.data, y)) + return np.array_equal(x.data, y) return False if x.owner and not y.owner: return False diff --git a/aesara/link/c/op.py b/aesara/link/c/op.py index 3627eb6d20..b15679324a 100644 --- a/aesara/link/c/op.py +++ b/aesara/link/c/op.py @@ -544,25 +544,19 @@ def get_c_macros( vname = variable_names[i] - macro_name = "DTYPE_" + vname - macro_value = "npy_" + v.type.dtype - - define_macros.append(define_template % (macro_name, macro_value)) - undef_macros.append(undef_template % macro_name) + macro_items = (f"DTYPE_{vname}", f"npy_{v.type.dtype}") + define_macros.append(define_template % macro_items) + undef_macros.append(undef_template % macro_items[0]) d = np.dtype(v.type.dtype) - macro_name = "TYPENUM_" + vname - macro_value = d.num - - define_macros.append(define_template % (macro_name, macro_value)) - undef_macros.append(undef_template % macro_name) - - macro_name = "ITEMSIZE_" + vname - macro_value = d.itemsize + macro_items_2 = (f"TYPENUM_{vname}", d.num) + define_macros.append(define_template % macro_items_2) + undef_macros.append(undef_template % macro_items_2[0]) - define_macros.append(define_template % (macro_name, macro_value)) - undef_macros.append(undef_template % macro_name) + macro_items_3 = (f"ITEMSIZE_{vname}", d.itemsize) + define_macros.append(define_template % macro_items_3) + undef_macros.append(undef_template % macro_items_3[0]) # Generate a macro to mark code as being apply-specific define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}")) diff --git a/aesara/printing.py b/aesara/printing.py index e0fe359662..7f5fd2155d 100644 --- a/aesara/printing.py +++ b/aesara/printing.py @@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str] def debugprint( - obj: Union[ + graph_like: Union[ Union[Variable, Apply, Function, FunctionGraph], Sequence[Union[Variable, Apply, Function, FunctionGraph]], ], @@ -139,7 +139,7 @@ def debugprint( Parameters ---------- - obj + graph_like The object(s) to be printed. depth Print graph to this depth (``-1`` for unlimited). @@ -149,7 +149,7 @@ def debugprint( When `file` extends `TextIO`, print to it; when `file` is equal to ``"str"``, return a string; when `file` is ``None``, print to `sys.stdout`. - ids + id_type Determines the type of identifier used for `Variable`\s: - ``"id"``: print the python id value, - ``"int"``: print integer character, @@ -213,12 +213,12 @@ def debugprint( topo_orders: List[Optional[List[Apply]]] = [] storage_maps: List[Optional[StorageMapType]] = [] - if isinstance(obj, (list, tuple, set)): - lobj = obj + if isinstance(graph_like, (list, tuple, set)): + graphs = graph_like else: - lobj = [obj] + graphs = (graph_like,) - for obj in lobj: + for obj in graphs: if isinstance(obj, Variable): outputs_to_print.append(obj) profile_list.append(None) diff --git a/aesara/tensor/__init__.py b/aesara/tensor/__init__.py index ab7e6c604f..ef865e1bec 100644 --- a/aesara/tensor/__init__.py +++ b/aesara/tensor/__init__.py @@ -1,14 +1,21 @@ """Symbolic tensor types and constructor functions.""" from functools import singledispatch -from typing import Any, Callable, NoReturn, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Sequence, Union from aesara.graph.basic import Constant, Variable from aesara.graph.op import Op +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + +TensorLike = Union[Variable, Sequence[Variable], "ArrayLike"] + + def as_tensor_variable( - x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs + x: TensorLike, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs ) -> "TensorVariable": """Convert `x` into an equivalent `TensorVariable`. @@ -44,12 +51,12 @@ def as_tensor_variable( @singledispatch def _as_tensor_variable( - x, name: Optional[str], ndim: Optional[int], **kwargs + x: TensorLike, name: Optional[str], ndim: Optional[int], **kwargs ) -> "TensorVariable": - raise NotImplementedError(f"Cannot convert {x} to a tensor variable.") + raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.") -def get_vector_length(v: Any): +def get_vector_length(v: TensorLike) -> int: """Return the run-time length of a symbolic vector, when possible. Parameters @@ -80,13 +87,13 @@ def get_vector_length(v: Any): @singledispatch -def _get_vector_length(op: Union[Op, Variable], var: Variable): +def _get_vector_length(op: Union[Op, Variable], var: Variable) -> int: """`Op`-based dispatch for `get_vector_length`.""" raise ValueError(f"Length of {var} cannot be determined") @_get_vector_length.register(Constant) -def _get_vector_length_Constant(var_inst, var): +def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int: return len(var.data) diff --git a/aesara/tensor/random/type.py b/aesara/tensor/random/type.py index cdea596347..5c897473dc 100644 --- a/aesara/tensor/random/type.py +++ b/aesara/tensor/random/type.py @@ -28,7 +28,7 @@ class RandomType(Type[T]): @staticmethod def may_share_memory(a: T, b: T): - return a._bit_generator is b._bit_generator + return a._bit_generator is b._bit_generator # type: ignore[attr-defined] class RandomStateType(RandomType[np.random.RandomState]): diff --git a/setup.cfg b/setup.cfg index 6c6b4dea35..68d01e2545 100644 --- a/setup.cfg +++ b/setup.cfg @@ -83,6 +83,7 @@ warn_unreachable = True show_error_codes = True allow_redefinition = False files = aesara,tests +plugins = numpy.typing.mypy_plugin [mypy-versioneer] check_untyped_defs = False