Hello,
I'm wondering whether a Pallas kernel can update its input refs. For example, if I want to both read from and write to the same tensor in my Pallas kernel, how can I achieve that?
Below is a minimal script; it doesn't seem to work as I expect.
Thanks!
System info (python version, jaxlib version, accelerator, etc.)
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
    # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
    x = x_ref[:]
    y = y_ref[:]
    o_ref[:] = x + y

def add_inplace_kernel(x_ref, y_ref, o_ref):
    # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
    x = x_ref[:]
    y = y_ref[:]
    o = x + y
    y_ref[:] = o
x, y = jnp.arange(8), jnp.arange(8, 16)
print("=======regular add=========")
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
print("x ", x)
print("y ", y)
o = add(x, y)
print("o ", o)
print("=======regular add=========")
print("=======inplace add=========")
inplace_add = pl.pallas_call(add_inplace_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
print("x ", x)
print("y ", y)
o_dummy = inplace_add(x, y)
print("after inplace add y ", y)
print("o_dummy ", o_dummy)
print("=======inplace add=========")
This doesn't work as intended because x and y (the JAX arrays) live in HBM, and Pallas copies them into the innermost level of the memory hierarchy (e.g. VMEM on TPUs) before invoking the kernel. So when you modify y_ref, you're only modifying the copy, not the actual y that's resident in HBM.
Try aliasing the input and output to the same ref using the input_output_aliases argument to pl.pallas_call. In your case for the in-place add, you need to use:
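A minimal sketch of that aliasing, assuming the mapping `{1: 0}` (input index 1, i.e. `y`, shares its buffer with output index 0). The kernel writes through `o_ref` rather than `y_ref`, and the `interpret=` expression is my own addition so the snippet also runs on a CPU-only machine; treat both as assumptions rather than the exact code from the original reply.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_inplace_kernel(x_ref, y_ref, o_ref):
    # With input_output_aliases={1: 0}, o_ref is backed by the same
    # buffer as y, so writing o_ref plays the role of updating y.
    o_ref[:] = x_ref[:] + y_ref[:]


x, y = jnp.arange(8), jnp.arange(8, 16)

inplace_add = pl.pallas_call(
    add_inplace_kernel,
    out_shape=jax.ShapeDtypeStruct(y.shape, y.dtype),
    # Assumption: alias input 1 (y) to output 0.
    input_output_aliases={1: 0},
    # Assumption: fall back to interpret mode when no accelerator is present.
    interpret=jax.default_backend() == "cpu",
)

# The update comes back as the return value, so rebind the name.
y_new = inplace_add(x, y)
print("after inplace add y ", y_new)  # the question asks for [ 8 10 12 14 16 18 20 22]
```

JAX arrays are immutable from Python's point of view, so the "in-place" result is still obtained by rebinding, e.g. `y = inplace_add(x, y)`; the aliasing only tells the compiler that the output may reuse `y`'s buffer.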
From this script I don't get the result I want: I'm trying to get inplace_add to give me y = [ 8 10 12 14 16 18 20 22]. Is that possible?