Skip to content

Commit 8cc63f9

Browse files
committed
🩹 add overload for ndarray.__matmul__
1 parent 51a21ef commit 8cc63f9

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

‎src/numpy-stubs/__init__.pyi

+10-4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ from typing import (
2727
)
2828
from typing_extensions import Buffer, CapsuleType, LiteralString, Never, Protocol, Self, TypeVar, Unpack, deprecated, override
2929

30+
import numpy as np
31+
3032
from . import (
3133
__config__ as __config__,
3234
_array_api_info as _array_api_info,
@@ -611,6 +613,8 @@ _DT64ItemT = TypeVar("_DT64ItemT", bound=dt.date | int | None)
611613
_DT64ItemT_co = TypeVar("_DT64ItemT_co", bound=dt.date | int | None, default=dt.date | int | None, covariant=True)
612614
_TD64UnitT = TypeVar("_TD64UnitT", bound=_TD64Unit, default=_TD64Unit)
613615

616+
_Array1D: TypeAlias = np.ndarray[tuple[int], np.dtype[_ScalarT]]
617+
614618
###
615619
# Type Aliases (for internal use only)
616620

@@ -2531,8 +2535,8 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25312535
@overload
25322536
def __imul__(self: NDArray[object_], rhs: object, /) -> ndarray[_ShapeT_co, _DTypeT_co]: ...
25332537

2534-
# TODO(jorenham): Support the "1d @ 1d -> scalar" case
2535-
# https://github.com/numpy/numtype/issues/197
2538+
@overload
2539+
def __matmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
25362540
@overload
25372541
def __matmul__(self: NDArray[_NumberT], rhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
25382542
@overload
@@ -2566,12 +2570,14 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
25662570
@overload
25672571
def __matmul__(self: NDArray[bool_ | number], rhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
25682572
@overload
2569-
def __matmul__(self: NDArray[object_], rhs: object, /) -> NDArray[object_]: ...
2573+
def __matmul__(self: NDArray[object_], rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
25702574
@overload
25712575
def __matmul__(self, rhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
25722576

25732577
# keep in sync with __matmul__
25742578
@overload
2579+
def __rmatmul__(self: _Array1D[_ScalarT], rhs: _Array1D[_ScalarT], /) -> _ScalarT: ...
2580+
@overload
25752581
def __rmatmul__(self: NDArray[_NumberT], lhs: _ArrayLikeBool_co, /) -> NDArray[_NumberT]: ...
25762582
@overload
25772583
def __rmatmul__(self: NDArray[bool_], lhs: _ArrayLike[_NumberT], /) -> NDArray[_NumberT]: ...
@@ -2604,7 +2610,7 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
26042610
@overload
26052611
def __rmatmul__(self: NDArray[bool_ | number], lhs: _ArrayLikeNumber_co, /) -> NDArray[Incomplete]: ...
26062612
@overload
2607-
def __rmatmul__(self: NDArray[object_], lhs: object, /) -> NDArray[object_]: ...
2613+
def __rmatmul__(self: NDArray[object_], lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
26082614
@overload
26092615
def __rmatmul__(self, lhs: _ArrayLikeObject_co, /) -> NDArray[object_]: ...
26102616

0 commit comments

Comments
 (0)