Skip to content

TYP: replace basedmypy with mypy #329

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
362 changes: 308 additions & 54 deletions pixi.lock

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ array-api-extra = { path = ".", editable = true }
typing-extensions = ">=4.13.2"
pre-commit = ">=4.2.0"
pylint = ">=3.3.7"
basedmypy = ">=2.10.0"
mypy = ">=1.16.0"
basedpyright = ">=1.29.2"
numpydoc = ">=1.8.0,<2"
# import dependencies for mypy:
Expand Down Expand Up @@ -227,16 +227,17 @@ python_version = "3.10"
warn_unused_configs = true
strict = true
enable_error_code = ["ignore-without-code", "truthy-bool"]
# https://github.com/data-apis/array-api-typing
disallow_any_expr = false
# false positives with input validation
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
disable_error_code = ["no-any-return"]

[[tool.mypy.overrides]]
# slow or unavailable on Windows; do not add to the lint env
module = ["cupy.*", "jax.*", "sparse.*", "torch.*"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["tests/*"]
disable_error_code = ["no-untyped-def"] # test(...) without -> None

# pyright

[tool.basedpyright]
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class _AtOp(Enum):
MAX = "max"

# @override from Python 3.12
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride]
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride]
"""
Return string representation (useful for pytest logs).

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Backend(Enum): # numpydoc ignore=PR02
JAX = "jax.numpy"
JAX_GPU = "jax.numpy:gpu"

def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
def __str__(self) -> str: # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
"""Pretty-print parameterized test names."""
return (
self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
Expand Down
12 changes: 6 additions & 6 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@


@overload
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
def apply_where( # numpydoc ignore=GL08
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
Expand All @@ -46,7 +46,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G


@overload
def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=GL08
def apply_where( # numpydoc ignore=GL08
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
Expand All @@ -57,7 +57,7 @@ def apply_where( # type: ignore[explicit-any,decorated-any] # numpydoc ignore=G
) -> Array: ...


def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
def apply_where( # numpydoc ignore=PR01,PR02
cond: Array,
args: Array | tuple[Array, ...],
f1: Callable[..., Array],
Expand Down Expand Up @@ -143,7 +143,7 @@ def apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,PR02
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)


def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
def _apply_where( # numpydoc ignore=PR01,RT01
cond: Array,
f1: Callable[..., Array],
f2: Callable[..., Array] | None,
Expand Down Expand Up @@ -813,8 +813,7 @@ def pad(
else:
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))

# https://github.com/python/typeshed/issues/13376
slices: list[slice] = [] # type: ignore[explicit-any]
slices: list[slice] = []
newshape: list[int] = []
for ax, w_tpl in enumerate(pad_width_seq):
if len(w_tpl) != 2:
Expand All @@ -826,6 +825,7 @@ def pad(
if w_tpl[0] == 0 and w_tpl[1] == 0:
sl = slice(None, None, None)
else:
stop: int | None
start, stop = w_tpl
stop = None if stop == 0 else -stop

Expand Down
12 changes: 6 additions & 6 deletions src/array_api_extra/_lib/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import numpy as np
from numpy.typing import ArrayLike

NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic # type: ignore[explicit-any]
NumPyObject: TypeAlias = np.ndarray[Any, Any] | np.generic
else:
# Sphinx hack
NumPyObject = Any
Expand All @@ -31,7 +31,7 @@


@overload
def lazy_apply( # type: ignore[decorated-any, valid-type]
def lazy_apply( # type: ignore[valid-type]
func: Callable[P, Array | ArrayLike],
*args: Array | complex | None,
shape: tuple[int | None, ...] | None = None,
Expand All @@ -43,7 +43,7 @@ def lazy_apply( # type: ignore[decorated-any, valid-type]


@overload
def lazy_apply( # type: ignore[decorated-any, valid-type]
def lazy_apply( # type: ignore[valid-type]
func: Callable[P, Sequence[Array | ArrayLike]],
*args: Array | complex | None,
shape: Sequence[tuple[int | None, ...]],
Expand Down Expand Up @@ -313,7 +313,7 @@ def _is_jax_jit_enabled(xp: ModuleType) -> bool: # numpydoc ignore=PR01,RT01
return True


def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
def _lazy_apply_wrapper( # numpydoc ignore=PR01,RT01
func: Callable[..., Array | ArrayLike | Sequence[Array | ArrayLike]],
as_numpy: bool,
multi_output: bool,
Expand All @@ -331,7 +331,7 @@ def _lazy_apply_wrapper( # type: ignore[explicit-any] # numpydoc ignore=PR01,R

# On Dask, @wraps causes the graph key to contain the wrapped function's name
@wraps(func)
def wrapper( # type: ignore[decorated-any,explicit-any]
def wrapper(
*args: Array | complex | None, **kwargs: Any
) -> tuple[Array, ...]: # numpydoc ignore=GL08
args_list = []
Expand All @@ -343,7 +343,7 @@ def wrapper( # type: ignore[decorated-any,explicit-any]
if as_numpy:
import numpy as np

arg = cast(Array, np.asarray(arg)) # type: ignore[bad-cast] # noqa: PLW2901
arg = cast(Array, np.asarray(arg)) # noqa: PLW2901
args_list.append(arg)
assert device is not None

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def _is_materializable(x: Array) -> bool:
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]


def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]:
"""
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
"""
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def is_torch_array(x: object, /) -> TypeGuard[Array]: ...
def is_lazy_array(x: object, /) -> TypeGuard[Array]: ...
def is_writeable_array(x: object, /) -> TypeGuard[Array]: ...
def size(x: Array, /) -> int | None: ...
def to_device( # type: ignore[explicit-any]
def to_device(
x: Array,
device: Device, # pylint: disable=redefined-outer-name
/,
Expand Down
8 changes: 4 additions & 4 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def asarrays(
float: ("real floating", "complex floating"),
complex: "complex floating",
}
kind = same_dtype[type(cast(complex, b))] # type: ignore[index]
kind = same_dtype[type(cast(complex, b))]
if xp.isdtype(a.dtype, kind):
xb = xp.asarray(b, dtype=a.dtype)
else:
Expand Down Expand Up @@ -458,7 +458,7 @@ def persistent_id(
return instances, (f.getvalue(), *rest)


def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any:
"""
Reverse of ``pickle_flatten``.

Expand Down Expand Up @@ -521,7 +521,7 @@ def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
self.obj = obj

@classmethod
def _register(cls): # numpydoc ignore=SS06
def _register(cls) -> None: # numpydoc ignore=SS06
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
Expand Down Expand Up @@ -583,7 +583,7 @@ def f(x: Array, y: float, plus: bool) -> Array:
import jax

@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
def inner( # numpydoc ignore=GL08
wargs: _AutoJITWrapper[Any],
) -> _AutoJITWrapper[T]:
args, kwargs = wargs.obj
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_lib/_utils/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ class DType(Protocol): # pylint: disable=missing-class-docstring
class Device(Protocol): # pylint: disable=missing-class-docstring
pass

SetIndex: TypeAlias = ( # type: ignore[explicit-any]
SetIndex: TypeAlias = (
int | slice | EllipsisType | Array | tuple[int | slice | EllipsisType | Array, ...]
)
GetIndex: TypeAlias = ( # type: ignore[explicit-any]
GetIndex: TypeAlias = (
SetIndex | None | tuple[int | slice | EllipsisType | None | Array, ...]
)

Expand Down
16 changes: 9 additions & 7 deletions src/array_api_extra/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def override(func):
P = ParamSpec("P")
T = TypeVar("T")

_ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any]
_ufuncs_tags: dict[object, dict[str, Any]] = {}


class Deprecated(enum.Enum):
Expand All @@ -48,7 +48,7 @@ class Deprecated(enum.Enum):
DEPRECATED = Deprecated.DEPRECATED


def lazy_xp_function( # type: ignore[explicit-any]
def lazy_xp_function(
func: Callable[..., Any],
*,
allow_dask_compute: bool | int = False,
Expand Down Expand Up @@ -257,12 +257,12 @@ def xp(request, monkeypatch):
mod = cast(ModuleType, request.module)
mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]

def iter_tagged() -> ( # type: ignore[explicit-any]
def iter_tagged() -> (
Iterator[tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]]
):
for mod in mods:
for name, func in mod.__dict__.items():
tags: dict[str, Any] | None = None # type: ignore[explicit-any]
tags: dict[str, Any] | None = None
with contextlib.suppress(AttributeError):
tags = func._lazy_xp_function # pylint: disable=protected-access
if tags is None:
Expand Down Expand Up @@ -313,15 +313,17 @@ def __init__(self, max_count: int, msg: str): # numpydoc ignore=GL08
self.msg = msg

@override
def __call__(self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any) -> Any: # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
def __call__(
self, dsk: Graph, keys: Sequence[Key] | Key, **kwargs: Any
) -> Any: # numpydoc ignore=GL08
import dask

self.count += 1
# This should yield a nice traceback to the
# offending line in the user's code
assert self.count <= self.max_count, self.msg

return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
return dask.get(dsk, keys, **kwargs) # type: ignore[attr-defined] # pyright: ignore[reportPrivateImportUsage]


def _dask_wrap(
Expand Down Expand Up @@ -354,7 +356,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
# `pytest.raises` and `pytest.warns` to work as expected. Note that this would
# not work on scheduler='distributed', as it would not block.
arrays, rest = pickle_flatten(out, da.Array)
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage]
arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call] # pyright: ignore[reportPrivateImportUsage]
return pickle_unflatten(arrays, rest) # pyright: ignore[reportUnknownArgumentType]

return wrapper
10 changes: 5 additions & 5 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def at_op(
just a workaround for when one wants to apply jax.jit to `at()` directly,
which is not a common use case.
"""
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value))
return meth(y, copy=copy, xp=xp)


Expand Down Expand Up @@ -157,7 +157,7 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
"""
x = xp.asarray([1.0, 10.0, 20.0])
expect_copy = not is_writeable_array(x)
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[explicit-any]
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value))
with assert_copy(x, None, expect_copy):
_ = meth(2.0)

Expand All @@ -166,7 +166,7 @@ def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
# even if the arrays are writeable.
expect_copy = not is_writeable_array(x) or library is Backend.DASK
idx = xp.asarray([True, True, False])
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any]
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value))
with assert_copy(x, None, expect_copy):
_ = meth(2.0)

Expand All @@ -178,7 +178,7 @@ def test_copy_invalid():


def test_xp():
a = cast(Array, np.asarray([1, 2, 3])) # type: ignore[bad-cast]
a = cast(Array, np.asarray([1, 2, 3]))
_ = at(a, 0).set(4, xp=np)
_ = at(a, 0).add(4, xp=np)
_ = at(a, 0).subtract(4, xp=np)
Expand All @@ -190,7 +190,7 @@ def test_xp():


def test_alternate_index_syntax():
xp = cast(ModuleType, np) # pyright: ignore[reportInvalidCast]
xp = cast(ModuleType, np) # type: ignore[redundant-cast] # pyright: ignore[reportInvalidCast]
a = cast(Array, xp.asarray([1, 2, 3]))
xp_assert_equal(at(a, 0).set(4, copy=True), xp.asarray([4, 2, 3]))
xp_assert_equal(at(a)[0].set(4, copy=True), xp.asarray([4, 2, 3]))
Expand Down
5 changes: 1 addition & 4 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@

from .conftest import NUMPY_VERSION

# some xp backends are untyped
# mypy: disable-error-code=no-untyped-def

lazy_xp_function(apply_where)
lazy_xp_function(atleast_nd)
lazy_xp_function(cov)
Expand Down Expand Up @@ -213,7 +210,7 @@ def test_device(self, xp: ModuleType, device: Device):
p=st.floats(min_value=0, max_value=1),
data=st.data(),
)
def test_hypothesis( # type: ignore[explicit-any,decorated-any]
def test_hypothesis(
self,
n_arrays: int,
rng_seed: int,
Expand Down
13 changes: 6 additions & 7 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
def override(func):
return func

# mypy: disable-error-code=no-untyped-usage

T = TypeVar("T")

Expand Down Expand Up @@ -387,7 +386,7 @@ def test_static_hashable(self, jnp: ModuleType):
"""Static argument/return value is hashable, but not serializable"""

class C:
def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
def __reduce__(self) -> object: # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride]
raise Exception()

@jax_autojit
Expand All @@ -399,12 +398,12 @@ def f(x: object) -> object:
assert out is inp

# Serializable opaque input contains non-serializable object plus array
inp = Wrapper((C(), jnp.asarray([1, 2])))
out = f(inp)
winp = Wrapper((C(), jnp.asarray([1, 2])))
out = f(winp)
assert isinstance(out, Wrapper)
assert out.x[0] is inp.x[0]
assert out.x[1] is not inp.x[1]
xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType]
assert out.x[0] is winp.x[0]
assert out.x[1] is not winp.x[1]
xp_assert_equal(out.x[1], winp.x[1]) # pyright: ignore[reportUnknownArgumentType]

def test_arraylikes_are_static(self):
pytest.importorskip("jax")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def f(x: Array) -> Array:
xp = array_namespace(x)
return xp.sum(x, axis=0) + x

x_np = cast(Array, np.arange(15).reshape(5, 3)) # type: ignore[bad-cast]
x_np = cast(Array, np.arange(15).reshape(5, 3))
expect = da.asarray(f(x_np))
x_da = da.asarray(x_np).rechunk(3)

Expand Down
Loading