diff --git a/src/array_api_jax_compat/_creation_functions.py b/src/array_api_jax_compat/_creation_functions.py index 2f14b08..925d3df 100644 --- a/src/array_api_jax_compat/_creation_functions.py +++ b/src/array_api_jax_compat/_creation_functions.py @@ -44,7 +44,7 @@ def arange( start: ArrayLike, /, stop: ArrayLike | None = None, - step: ArrayLike = 1, + step: ArrayLike | None = None, *, dtype: DType | None = None, device: Device | None = None, @@ -57,7 +57,7 @@ def arange( start: ArrayLike, stop: ArrayLike | None = None, *, - step: ArrayLike = 1, + step: ArrayLike | None = None, dtype: DType | None = None, device: Device | None = None, ) -> ArrayLike: @@ -70,7 +70,7 @@ def arange( start: ArrayLike, *, stop: ArrayLike | None = None, - step: ArrayLike = 1, + step: ArrayLike | None = None, dtype: DType | None = None, device: Device | None = None, ) -> ArrayLike: @@ -78,6 +78,19 @@ def arange( return arange(start, stop, step, dtype=dtype, device=device) +@dispatcher # type: ignore[misc] +def arange( + *, + start: ArrayLike, + stop: ArrayLike | None = None, + step: ArrayLike | None = None, + dtype: DType | None = None, + device: Device | None = None, +) -> ArrayLike | Value: + # dispatch on `start`, `stop`, and `step` + return arange(start, stop, step, dtype=dtype, device=device) + + # ============================================================================= diff --git a/src/array_api_jax_compat/_data_type_functions.py b/src/array_api_jax_compat/_data_type_functions.py index ac90fa0..17146bc 100644 --- a/src/array_api_jax_compat/_data_type_functions.py +++ b/src/array_api_jax_compat/_data_type_functions.py @@ -4,29 +4,28 @@ from jax.experimental import array_api from jax.experimental.array_api._data_type_functions import FInfo, IInfo from jaxtyping import ArrayLike -from quax import Value from ._types import DType from ._utils import quaxify @quaxify -def astype(x: ArrayLike, dtype: DType, /, *, copy: bool = True) -> Value: +def astype(x: ArrayLike, dtype: DType, /, *, copy: bool = True) -> ArrayLike: return array_api.astype(x, dtype, copy=copy) @quaxify -def can_cast(from_: DType | Value, to: DType, /) -> bool: +def can_cast(from_: DType | ArrayLike, to: DType, /) -> bool: return array_api.can_cast(from_, to) @quaxify -def finfo(type: DType | Value, /) -> FInfo: +def finfo(type: DType | ArrayLike, /) -> FInfo: return array_api.finfo(type) @quaxify -def iinfo(type: DType | Value, /) -> IInfo: +def iinfo(type: DType | ArrayLike, /) -> IInfo: return array_api.iinfo(type)