diff --git a/examples/pallas/blocksparse_matmul.py b/examples/pallas/blocksparse_matmul.py index cf2936e1..eb77c6cd 100644 --- a/examples/pallas/blocksparse_matmul.py +++ b/examples/pallas/blocksparse_matmul.py @@ -61,9 +61,9 @@ def body(k, acc): @jax.tree_util.register_pytree_node_class class BlockELL: - blocks: jnp.array # float32[n_rows, n_blocks, *block_size] - blocks_per_row: jnp.array # int32[n_rows, n_blocks] - indices: jnp.array # int32[n_rows, max_num_blocks_per_row, 2] + blocks: jnp.ndarray # float32[n_rows, n_blocks, *block_size] + blocks_per_row: jnp.ndarray # int32[n_rows, n_blocks] + indices: jnp.ndarray # int32[n_rows, max_num_blocks_per_row, 2] shape: Tuple[int, int] # (n_rows * block_size[0], n_cols * block_size[1]) ndim: int = property(lambda self: len(self.shape))