Skip to content

Commit

Permalink
upgrade overrides
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Feb 8, 2024
1 parent ff2a12e commit 6afcede
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
19 changes: 16 additions & 3 deletions src/array_api_jax_compat/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def arange(
start: ArrayLike,
/,
stop: ArrayLike | None = None,
step: ArrayLike = 1,
step: ArrayLike | None = None,
*,
dtype: DType | None = None,
device: Device | None = None,
Expand All @@ -57,7 +57,7 @@ def arange(
start: ArrayLike,
stop: ArrayLike | None = None,
*,
step: ArrayLike = 1,
step: ArrayLike | None = None,
dtype: DType | None = None,
device: Device | None = None,
) -> ArrayLike:
Expand All @@ -70,14 +70,27 @@ def arange(
start: ArrayLike,
*,
stop: ArrayLike | None = None,
step: ArrayLike = 1,
step: ArrayLike | None = None,
dtype: DType | None = None,
device: Device | None = None,
) -> ArrayLike:
# dispatch on `start`, `stop`, and `step`
return arange(start, stop, step, dtype=dtype, device=device)


@dispatcher # type: ignore[misc]
def arange(
*,
start: ArrayLike,
stop: ArrayLike | None = None,
step: ArrayLike | None = None,
dtype: DType | None = None,
device: Device | None = None,
) -> ArrayLike | Value:
# dispatch on `start`, `stop`, and `step`
return arange(start, stop, step, dtype=dtype, device=device)


# =============================================================================


Expand Down
9 changes: 4 additions & 5 deletions src/array_api_jax_compat/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,28 @@
from jax.experimental import array_api
from jax.experimental.array_api._data_type_functions import FInfo, IInfo
from jaxtyping import ArrayLike
from quax import Value

from ._types import DType
from ._utils import quaxify


@quaxify
def astype(x: ArrayLike, dtype: DType, /, *, copy: bool = True) -> Value:
def astype(x: ArrayLike, dtype: DType, /, *, copy: bool = True) -> ArrayLike:
return array_api.astype(x, dtype, copy=copy)


@quaxify
def can_cast(from_: DType | Value, to: DType, /) -> bool:
def can_cast(from_: DType | ArrayLike, to: DType, /) -> bool:
return array_api.can_cast(from_, to)


@quaxify
def finfo(type: DType | Value, /) -> FInfo:
def finfo(type: DType | ArrayLike, /) -> FInfo:
return array_api.finfo(type)


@quaxify
def iinfo(type: DType | Value, /) -> IInfo:
def iinfo(type: DType | ArrayLike, /) -> IInfo:
return array_api.iinfo(type)


Expand Down

0 comments on commit 6afcede

Please sign in to comment.