Skip to content

Commit

Permalink
Enable type checking for NumPy types
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jul 20, 2022
1 parent bb40791 commit 9665120
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 33 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,6 @@ repos:
hooks:
- id: mypy
additional_dependencies:
- numpy>=1.20
- types-filelock
- types-setuptools
6 changes: 3 additions & 3 deletions aesara/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,14 +1671,14 @@ def equal_computations(

for x, y in zip(xs, ys):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return cast(bool, np.array_equal(x, y))
return np.array_equal(x, y)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return cast(bool, np.array_equal(y.data, x))
return np.array_equal(y.data, x)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return cast(bool, np.array_equal(x.data, y))
return np.array_equal(x.data, y)
return False
if x.owner and not y.owner:
return False
Expand Down
24 changes: 9 additions & 15 deletions aesara/link/c/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,25 +544,19 @@ def get_c_macros(

vname = variable_names[i]

macro_name = "DTYPE_" + vname
macro_value = "npy_" + v.type.dtype

define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_items = (f"DTYPE_{vname}", f"npy_{v.type.dtype}")
define_macros.append(define_template % macro_items)
undef_macros.append(undef_template % macro_items[0])

d = np.dtype(v.type.dtype)

macro_name = "TYPENUM_" + vname
macro_value = d.num

define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)

macro_name = "ITEMSIZE_" + vname
macro_value = d.itemsize
macro_items_2 = (f"TYPENUM_{vname}", d.num)
define_macros.append(define_template % macro_items_2)
undef_macros.append(undef_template % macro_items_2[0])

define_macros.append(define_template % (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
macro_items_3 = (f"ITEMSIZE_{vname}", d.itemsize)
define_macros.append(define_template % macro_items_3)
undef_macros.append(undef_template % macro_items_3[0])

# Generate a macro to mark code as being apply-specific
define_macros.append(define_template % ("APPLY_SPECIFIC(str)", f"str##_{name}"))
Expand Down
14 changes: 7 additions & 7 deletions aesara/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]


def debugprint(
obj: Union[
graph_like: Union[
Union[Variable, Apply, Function, FunctionGraph],
Sequence[Union[Variable, Apply, Function, FunctionGraph]],
],
Expand Down Expand Up @@ -139,7 +139,7 @@ def debugprint(
Parameters
----------
obj
graph_like
The object(s) to be printed.
depth
Print graph to this depth (``-1`` for unlimited).
Expand All @@ -149,7 +149,7 @@ def debugprint(
When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to
`sys.stdout`.
ids
id_type
Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
Expand Down Expand Up @@ -213,12 +213,12 @@ def debugprint(
topo_orders: List[Optional[List[Apply]]] = []
storage_maps: List[Optional[StorageMapType]] = []

if isinstance(obj, (list, tuple, set)):
lobj = obj
if isinstance(graph_like, (list, tuple, set)):
graphs = graph_like
else:
lobj = [obj]
graphs = (graph_like,)

for obj in lobj:
for obj in graphs:
if isinstance(obj, Variable):
outputs_to_print.append(obj)
profile_list.append(None)
Expand Down
21 changes: 14 additions & 7 deletions aesara/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Symbolic tensor types and constructor functions."""

from functools import singledispatch
from typing import Any, Callable, NoReturn, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Sequence, Union

from aesara.graph.basic import Constant, Variable
from aesara.graph.op import Op


if TYPE_CHECKING:
from numpy.typing import ArrayLike, NDArray


TensorLike = Union[Variable, Sequence[Variable], "ArrayLike"]


def as_tensor_variable(
x: Any, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
x: TensorLike, name: Optional[str] = None, ndim: Optional[int] = None, **kwargs
) -> "TensorVariable":
"""Convert `x` into an equivalent `TensorVariable`.
Expand Down Expand Up @@ -44,12 +51,12 @@ def as_tensor_variable(

@singledispatch
def _as_tensor_variable(
x, name: Optional[str], ndim: Optional[int], **kwargs
x: TensorLike, name: Optional[str], ndim: Optional[int], **kwargs
) -> "TensorVariable":
raise NotImplementedError(f"Cannot convert {x} to a tensor variable.")
raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")


def get_vector_length(v: Any):
def get_vector_length(v: TensorLike) -> int:
"""Return the run-time length of a symbolic vector, when possible.
Parameters
Expand Down Expand Up @@ -80,13 +87,13 @@ def get_vector_length(v: Any):


@singledispatch
def _get_vector_length(op: Union[Op, Variable], var: Variable):
def _get_vector_length(op: Union[Op, Variable], var: Variable) -> int:
"""`Op`-based dispatch for `get_vector_length`."""
raise ValueError(f"Length of {var} cannot be determined")


@_get_vector_length.register(Constant)
def _get_vector_length_Constant(var_inst, var):
def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
return len(var.data)


Expand Down
2 changes: 1 addition & 1 deletion aesara/tensor/random/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class RandomType(Type[T]):

@staticmethod
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]


class RandomStateType(RandomType[np.random.RandomState]):
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ warn_unreachable = True
show_error_codes = True
allow_redefinition = False
files = aesara,tests
plugins = numpy.typing.mypy_plugin

[mypy-versioneer]
check_untyped_defs = False
Expand Down

0 comments on commit 9665120

Please sign in to comment.