
Support for partitioning/sharded data with Pallas kernels? #72

Open
G-Levine opened this issue Mar 4, 2024 · 10 comments

Comments

@G-Levine

G-Levine commented Mar 4, 2024

I'm trying to train a model with a custom linear attention kernel I wrote in Pallas, but I'm hitting the following error (it only happens when the input data is sharded across multiple TPU devices):

jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Mosaic kernels cannot be automatically partitioned. Please wrap the call in a shard_map or xmap.

Here's the code that I'm trying to run: https://github.com/G-Levine/levanter/blob/9e78ab17e416d5e471f27255d13888d5fb98e632/src/levanter/models/linear_attention.py

Is there a recommended way to achieve this with Haliax? I tried to find examples of people using the Pallas FlashAttention kernel with Haliax/Levanter, but it appears nobody has tried this yet. It seems like an important use case to support, for anyone who wants to efficiently train transformer models on multiple TPUs.

@dlwh
Member

dlwh commented Mar 4, 2024

Hey,

I'm not aware of anybody trying yet. Do you have an example of it working in e.g. Flax? My guess is that we need to use shard_map as they say, but I don't have much experience with that yet either.

@dlwh
Member

dlwh commented Mar 4, 2024

OK, the best example I see is from MaxText: https://github.com/google/maxtext/blob/10a7c473e9feb1107894e7588b283b1bcfcbd679/MaxText/layers/attentions.py#L213

I think the basic idea is to get a PSpec for each input array (and similarly for the expected output shape), then call shard_map(kernel), and then just double-check that there's no sharding of axes that the kernel assumes to be single-device?

So, I think what you'll want to do is call hax.partitioning.pspec_for_axis(a.axes, mapping) for every named array argument (and figure something out for non-named args, I guess). Then, for each axis that needs to be on a single device (e.g. non-batch axes if there's no communication), raise an error if it's sharded in the pspec.
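
Something like this rough sketch, e.g. (untested; check_unsharded and must_be_local are names I'm making up for illustration, pspec_for_axis is the real Haliax call):

import haliax as hax
from haliax import NamedArray


def check_unsharded(x: NamedArray, must_be_local: tuple) -> None:
    # pspec_for_axis uses the currently active axis mapping if you don't pass one
    pspec = hax.partitioning.pspec_for_axis(x.axes)
    # the PartitionSpec is tuple-like: one entry per axis, None means "not sharded"
    for axis, resource in zip(x.axes, pspec):
        if axis.name in must_be_local and resource is not None:
            raise ValueError(
                f"Axis {axis.name} is sharded over mesh axis {resource}, "
                "but the kernel expects it to live on a single device."
            )

and call it on each NamedArray argument, with whatever axes your kernel assumes are unsharded, before handing the raw arrays to the kernel.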

@G-Levine
Author

G-Levine commented Mar 4, 2024

Thanks, that helps a lot. I'm able to call the kernel without errors now. However, I'm still trying to figure out how to get the kernel output (a plain JAX array) back into a Haliax named array with the correct sharding. Here's my current code:

def linear_attention(
    query: NamedArray,
    key: NamedArray,
    value: NamedArray,
) -> NamedArray:
    @functools.partial(
        shard_map.shard_map,
        mesh=hax.partitioning._get_mesh(),
        in_specs=(
            hax.partitioning.pspec_for_axis(query.axes),
            hax.partitioning.pspec_for_axis(key.axes),
            hax.partitioning.pspec_for_axis(value.axes),
        ),
        out_specs=hax.partitioning.pspec_for_axis(value.axes),
        check_rep=False,
    )
    def attn_sharded(query, key, value):
        q = query.array
        k = key.array
        v = value.array
        kv_carry = jnp.zeros_like(k)
        k_carry = jnp.zeros_like(k)
        y, _, _ = attn(q, k, v, kv_carry, k_carry)
        named_y = hax.named(y, tuple((axis.name for axis in value.axes)))
        return named_y
    return attn_sharded(query, key, value)

When I try to use the output of this function in the model, it results in this error: ValueError: Shape of underlying array (256, 1024, 768) does not match shape of axes (Axis(name='batch', size=64), Axis(name='position', size=1024), Axis(name='embed', size=768)). I assume this means the sharding information was dropped somewhere (the data is sharded across 4 devices, so a batch of 64 out of 256 is what's expected on a single device).
It's not clear to me how the MaxText example handles the output sharding (it looks like it just returns the output of the kernel directly?)

@dlwh
Member

dlwh commented Mar 4, 2024

The issue is that the output array is the "local" array inside the shard_map, so Haliax infers that batch is 64. Outside the shard_map, the raw JAX array is the concatenated/global one, but JAX doesn't know about Haliax's named axes, so the axis sizes don't change. (I should change the way Haliax works to make this easier...)

The easiest thing to do is to return a plain JAX array from attn_sharded and then wrap the array before returning from linear_attention.
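
To make the local-vs-global shapes concrete, here's a toy standalone sketch (nothing Haliax-specific; it assumes you have at least 4 devices and uses a made-up "data" mesh axis):

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ("data",))

def f(x):
    print("local shape:", x.shape)  # (64, 768): each device sees only its shard
    return x

g = shard_map.shard_map(f, mesh=mesh, in_specs=P("data"), out_specs=P("data"))
y = g(jnp.zeros((256, 768)))
print("global shape:", y.shape)  # (256, 768): reassembled outside the shard_map

If you wrap x in a NamedArray inside f, Haliax records batch=64, but the array you get back outside is the full 256, which is exactly the mismatch you're seeing.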

@dlwh
Member

dlwh commented Mar 4, 2024

(I'm glad this turned out to be relatively straightforward!)

@G-Levine
Author

G-Levine commented Mar 4, 2024

Great, it's all working now! Here's the final code I ended up with.

import functools

import jax.numpy as jnp
from jax.experimental import shard_map

import haliax as hax
from haliax import NamedArray


def linear_attention(
    query: NamedArray,
    key: NamedArray,
    value: NamedArray,
) -> NamedArray:
    @functools.partial(
        shard_map.shard_map,
        mesh=hax.partitioning._get_mesh(),
        in_specs=(
            hax.partitioning.pspec_for_axis(query.axes),
            hax.partitioning.pspec_for_axis(key.axes),
            hax.partitioning.pspec_for_axis(value.axes),
        ),
        out_specs=hax.partitioning.pspec_for_axis(value.axes),
        check_rep=False,
    )
    def attn_sharded(query, key, value):
        # inside the shard_map each device only sees its local shard
        q = query.array
        k = key.array
        v = value.array
        kv_carry = jnp.zeros_like(k)
        k_carry = jnp.zeros_like(k)
        # attn is the custom Pallas linear attention kernel (defined elsewhere in this module)
        y, _, _ = attn(q, k, v, kv_carry, k_carry)
        return y  # return the plain jax array; wrap it outside the shard_map
    y = attn_sharded(query, key, value)
    return hax.named(y, value.axes)

@dlwh
Member

dlwh commented Mar 4, 2024

Sweet! I'll leave this open just as a "make it easy for people to do this"/make a tutorial issue.

@dlwh
Member

dlwh commented Mar 4, 2024

also, could you let me know what kind of speedup you get? We can try to prioritize getting it into Levanter if it's nontrivial

@G-Levine
Author

G-Levine commented Mar 4, 2024

For the Pallas linear attention kernel? My testing so far shows a very significant speedup across all sequence lengths. (This is the runtime for the forward + backward pass.)
[attention_runtime plot: forward + backward runtime across sequence lengths]

@dlwh
Member

dlwh commented Mar 5, 2024

That's nice! I actually just meant Pallas flash attention vs. pure JAX attention on TPU.
