Skip to content

Commit

Permalink
refactor: set numpy dispatcher to plum default (#83)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Sep 17, 2024
1 parent bc2a4f4 commit 3be8995
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/quaxed/numpy/_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,18 @@
import jax
import jax.numpy as jnp
from jaxtyping import ArrayLike
from plum import dispatch
from quax import Value

from quaxed._types import DType
from quaxed._utils import quaxify

from ._dispatch import dispatcher

T = TypeVar("T")

# =============================================================================


@dispatcher
@dispatch
def arange(
start: ArrayLike,
stop: ArrayLike | None,
Expand All @@ -45,7 +44,7 @@ def arange(
return jnp.arange(start, stop, step, dtype=dtype)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def arange(
start: ArrayLike,
stop: ArrayLike | None,
Expand All @@ -58,7 +57,7 @@ def arange(
return arange(start, stop, step, dtype=dtype)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def arange(
start: ArrayLike,
/,
Expand All @@ -71,7 +70,7 @@ def arange(
return arange(start, stop, step, dtype=dtype)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def arange(
*,
start: ArrayLike,
Expand Down Expand Up @@ -101,7 +100,7 @@ def asarray(
# =============================================================================


@dispatcher # type: ignore[misc]
@dispatch # type: ignore[misc]
def empty_like(
prototype: ArrayLike,
/,
Expand All @@ -115,7 +114,7 @@ def empty_like(
# =============================================================================


@dispatcher
@dispatch
def full(
shape: tuple[int, ...] | int,
fill_value: ArrayLike,
Expand All @@ -125,7 +124,7 @@ def full(
return jnp.full(shape, fill_value, dtype=dtype)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def full(
shape: tuple[int, ...] | int,
*,
Expand All @@ -138,7 +137,7 @@ def full(
# =============================================================================


@dispatcher
@dispatch
def full_like(
x: ArrayLike,
/,
Expand All @@ -150,7 +149,7 @@ def full_like(
return jnp.full_like(x, fill_value, dtype=dtype, shape=shape)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def full_like(
x: ArrayLike,
*,
Expand All @@ -167,7 +166,7 @@ def full_like(
# =============================================================================


@dispatcher
@dispatch
def linspace( # noqa: PLR0913
start: ArrayLike,
stop: ArrayLike,
Expand All @@ -184,7 +183,7 @@ def linspace( # noqa: PLR0913
)


@dispatcher # type: ignore[no-redef]
@dispatch # type: ignore[no-redef]
def linspace( # noqa: PLR0913
start: ArrayLike,
stop: ArrayLike,
Expand Down Expand Up @@ -215,7 +214,7 @@ def meshgrid(
# =============================================================================


@dispatcher # type: ignore[misc]
@dispatch # type: ignore[misc]
def ones_like(
x: ArrayLike, /, *, dtype: DType | None = None, shape: tuple[int, ...] | None = None
) -> ArrayLike:
Expand Down Expand Up @@ -244,7 +243,7 @@ def triu(x: ArrayLike, /, *, k: int = 0) -> ArrayLike:

# @partial(jax.jit, static_argnames=("dtype", "device"))
# @quaxify
@dispatcher # type: ignore[misc]
@dispatch # type: ignore[misc]
def zeros_like(
x: ArrayLike,
/,
Expand Down

0 comments on commit 3be8995

Please sign in to comment.