Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🩹 add overload for ndarray.__matmul__ #286

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ _IntegerT = TypeVar("_IntegerT", bound=integer)
_SignedIntegerT = TypeVar("_SignedIntegerT", bound=signedinteger)
_UnsignedIntegerT = TypeVar("_UnsignedIntegerT", bound=unsignedinteger)
_CharT = TypeVar("_CharT", bound=character)
_IntegralT = TypeVar("_IntegralT", bound=bool_ | number | object_)

_NBitT = TypeVar("_NBitT", bound=NBitBase, default=Any)
_NBitT1 = TypeVar("_NBitT1", bound=NBitBase, default=Any)
Expand All @@ -611,6 +612,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)

_Array1D: TypeAlias = ndarray[tuple[int], dtype[_ScalarT]]

###
# Type Aliases (for internal use only)

Expand Down Expand Up @@ -2531,8 +2534,9 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
@overload
def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...

# TODO(jorenham): Support the "1d @ 1d -> scalar" case
# https://github.com/numpy/numtype/issues/197
#
@overload
def __matmul__(self: _Array1D[_IntegralT], rhs: _Array1D[_IntegralT], /) -> _IntegralT: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
Copy link
Member

@jorenham jorenham Mar 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, the overlapping overloads are actually an issue in this case. In several cases mypy incorrectly reports overload-overlap, but when pyright reports it, it's usually for a good reason. And if both report it, then only in very rare circumstances can it be safe to ignore (and this isn't one of those).

For example, let's assume that self is 1d and float64, and rhs is the union of a 1d and 2d bool arrays. Then the first overload does not match, because that requires rhs to only be 1d, but rhs could be both 1d or 2d (as it's a union type). But the second overload does match, because there, the shape of rhs doesn't matter. The return type is therefore inferred as NDArray[float64]. But that's incompatible with the return type of the first overload, because it doesn't account for the possibility that rhs is 1d, which would result in a scalar.

It's very tricky to solve this unfortunately, which is why I made a separate issue for it. I've been building the _numtype internal type-check-only package so that we can deal with situation like these, but it's changing quite fast, so I'm a bit hesitant to start using it in something as important as ndarray at this point.
So until _numtype has a somewhat stable API, and is actually tested (which is isn't right now), it's probably for the best to put this PR in the freezer for the time being.

This shows one of the (many) reasons why shape-typing is so difficult, and why it has been taking so long to make progress on. We'll get there eventually, but I'm careful not to rush into it.

I realize that __matmul__ is currently also incorrect, so it might seem weird that I don't want to fix it for this specific situation, as it seems like a net win. But the problem with overlapping overloads is, especially in case of overloads, a very pernicious one, and it could (and almost certainly will) lead to unexpected issues in which will be very difficult to debug. It's one of the most complicated parts of the Python typing system, and I probably don't fully understand it myself, so I'll just leave it at that for now.

In numpy/numpy#27032 (comment) explain it in a bit more detail, but even so, that only scratches the surface of this can of worms I'm afraid 😅.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation! Totally get the challenges with overlapping overloads and shape typing in ndarray. It makes sense to wait until _numtype is more stable before moving forward. Looking forward to the progress! 🚀

@overload
def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
@overload
Expand Down