Skip to content

Commit

Permalink
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 555263389
  • Loading branch information
hawkinsp authored and The jax_triton Authors committed Aug 9, 2023
1 parent 75d47fb commit a25b1ba
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/pallas/blocksparse_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def body(k, acc):

@jax.tree_util.register_pytree_node_class
class BlockELL:
blocks: jnp.array # float32[n_rows, n_blocks, *block_size]
blocks_per_row: jnp.array # int32[n_rows, n_blocks]
indices: jnp.array # int32[n_rows, max_num_blocks_per_row, 2]
blocks: jnp.ndarray # float32[n_rows, n_blocks, *block_size]
blocks_per_row: jnp.ndarray # int32[n_rows, n_blocks]
indices: jnp.ndarray # int32[n_rows, max_num_blocks_per_row, 2]
shape: Tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1])

ndim: int = property(lambda self: len(self.shape))
Expand Down

0 comments on commit a25b1ba

Please sign in to comment.