
[Pallas] Unable to Modify Input Ref in Pallas Kernel #24656

Open
shangz-ai opened this issue Nov 1, 2024 · 3 comments
Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))


shangz-ai commented Nov 1, 2024

Description

Hello,
Is it possible for a Pallas kernel to update its input refs? For example, if I want to both read from and write to the same tensor in my Pallas kernel, how can I do that?
Below is a minimal script; it doesn't behave as I expect.
Thanks!

System info (python version, jaxlib version, accelerator, etc.)

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o_ref[:] = x + y

def add_inplace_kernel(x_ref, y_ref, o_ref):
  # In this code, `x_ref`, `y_ref` and `o_ref` are (8,)-shaped `Ref`s
  x = x_ref[:]
  y = y_ref[:]
  o = x + y
  y_ref[:] = o  # writes only to y_ref; o_ref is never written

x, y = jnp.arange(8), jnp.arange(8, 16)

print("=======regular add=========")
add = pl.pallas_call(add_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
print("x ", x)
print("y ", y)
o = add(x, y)
print("o ", o)
print("=======regular add=========")

print("=======inplace add=========")
inplace_add = pl.pallas_call(add_inplace_kernel, out_shape=jax.ShapeDtypeStruct((8,), jnp.int32))
print("x ", x)
print("y ", y)
o_dummy = inplace_add(x, y)
print("after inplace add y ", y)
print("o_dummy ", o_dummy)
print("=======inplace add=========")

Running this script prints:

=======regular add=========
x  [0 1 2 3 4 5 6 7]
y  [ 8  9 10 11 12 13 14 15]
o  [ 8 10 12 14 16 18 20 22]
=======regular add=========
=======inplace add=========
x  [0 1 2 3 4 5 6 7]
y  [ 8  9 10 11 12 13 14 15]
after inplace add y  [ 8  9 10 11 12 13 14 15]
o_dummy  [0 0 0 0 0 0 0 0]
=======inplace add=========

But I want inplace_add to give me an updated y of [ 8 10 12 14 16 18 20 22]. Is that possible?

shangz-ai added the bug label Nov 1, 2024

shangz-ai (Author) commented:

I also found this related issue: #22276

shangz-ai changed the title [Pallas] Unable to Update Input Ref in Pallas Kernel [Pallas] Unable to Modify Input Ref in Pallas Kernel Nov 1, 2024

justinjfu (Collaborator) commented Nov 1, 2024

This doesn't work as intended because x and y (the JAX arrays) live in HBM, and Pallas copies them into the innermost level of the memory hierarchy (e.g. VMEM on TPUs) before invoking the kernel. So when you modify y_ref, you are only modifying that copy, not the actual y resident in HBM.

Try aliasing the input and output to the same ref using the input_output_aliases argument to pallas_call. For your in-place add, that looks like:

inplace_add = pl.pallas_call(add_inplace_kernel,
                             out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
                             input_output_aliases={1: 0})  # input 1 (y) -> output 0

Pallas will copy outputs back to HBM, so this will trigger Pallas to copy your updates to y_ref back to HBM.
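Below is a minimal, self-contained sketch of the aliasing approach, not an official recipe. It makes a few assumptions of its own: it passes `interpret=True` so it runs on CPU (drop that on a TPU or GPU); the names `add_into_y_kernel`, `make_inplace_add`, and `add_into_y` are hypothetical; and the kernel writes the sum through the aliased output ref `o_ref` rather than `y_ref`, so it does not rely on writes made only through an input ref being copied back on a given backend. Also note that JAX arrays are immutable: the updated values come back as the return value of the call, while the Python variable `y` keeps its old values. The second half donates the array under `jax.jit`, which, as far as I understand, lets XLA reuse its buffer for the aliased output instead of copying it first.

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_into_y_kernel(x_ref, y_ref, o_ref):
  # Read both inputs and write the sum through o_ref, the output that is
  # aliased to y (hypothetical kernel name).
  o_ref[:] = x_ref[:] + y_ref[:]


def make_inplace_add(interpret=True):
  return pl.pallas_call(
      add_into_y_kernel,
      out_shape=jax.ShapeDtypeStruct((8,), jnp.int32),
      input_output_aliases={1: 0},  # input 1 (y) <-> output 0
      interpret=interpret,  # assumption: interpret mode so this runs on CPU
  )


x = jnp.arange(8, dtype=jnp.int32)
y = jnp.arange(8, 16, dtype=jnp.int32)

new_y = make_inplace_add()(x, y)
print("new y", new_y)  # expected: [ 8 10 12 14 16 18 20 22]
print("old y", y)      # y itself is unchanged: JAX arrays are immutable


# Donating y under jit lets XLA hand its buffer to the aliased output.
# A donated array must not be read afterwards, so rebind the name.
@functools.partial(jax.jit, donate_argnums=1)
def add_into_y(x, y):
  return make_inplace_add()(x, y)


acc = jnp.arange(8, 16, dtype=jnp.int32)
acc = add_into_y(x, acc)  # acc now refers to the returned (aliased) buffer
acc = add_into_y(x, acc)  # a second step accumulates x again
print("acc", acc)  # expected: [ 8 11 14 17 20 23 26 29]
```

The jitted pattern is the one to carry into a real kernel: pass the array in, take the aliased result back, and rebind the name so the donated buffer is never read again.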

justinjfu added the pallas label Nov 1, 2024

shangz-ai (Author) commented:

Thanks a lot! I think this is exactly what I want. I'll try it on my real kernel and see whether it does what I need.
