Skip to content

Commit 2d009f2

Browse files
committed
🩹 add overload for ndarray.__matmul__
1 parent 51a21ef commit 2d009f2

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

‎src/numpy-stubs/__init__.pyi

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import abc
23
import ctypes as ct
34
import datetime as dt
@@ -611,6 +612,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
611612
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
612613
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
613614

615+
_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
616+
614617
###
615618
# Type Aliases (for internal use only)
616619

@@ -2534,7 +2537,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25342537
# TODO(jorenham): Support the "1d @ 1d -> scalar" case
25352538
# https://github.com/numpy/numtype/issues/197
25362539
@overload
2537-
def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
2540+
def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
25382541
@overload
25392542
def __matmul__(self: NDArray[bool_], rhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
25402543
@overload
@@ -2566,12 +2569,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25662569
@overload
25672570
def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
25682571
@overload
2569-
def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
2572+
def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
25702573
@overload
25712574
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
25722575

25732576
# keep in sync with __matmul__
25742577
@overload
2578+
def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
2579+
@overload
25752580
def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
25762581
@overload
25772582
def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
@@ -2604,7 +2609,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
26042609
@overload
26052610
def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
26062611
@overload
2607-
def __rmatmul__(self: NDArray[object_], lhs: object, /) -> NDArray[object_]: ...
2612+
def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
26082613
@overload
26092614
def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
26102615

0 commit comments

Comments
 (0)