Avoid padding computation with cu_seqlens #1228

Open
imoneoi opened this issue Sep 14, 2024 · 3 comments

imoneoi commented Sep 14, 2024

To work with torch.compile, which is more efficient on static shapes, I pad extra tokens at the end so that q, k, v have a static shape, e.g. [N, D].

Can I set the last element of cu_seqlens in the varlen API to be less than N so that the padding is not computed? And is the backward pass still accurate in that case?
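
For context, here is a minimal sketch of the setup I have in mind (the concrete sizes and the n_real / n_padded names are illustrative, not from my actual code):

import torch
from flash_attn import flash_attn_varlen_func

n_padded, num_heads, head_dim = 4096, 8, 64   # static (padded) token count, heads, head dim
n_real = 3000                                  # actual number of tokens, < n_padded

q = torch.randn(n_padded, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(n_padded, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(n_padded, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)

# cu_seqlens ends at n_real rather than n_padded, so the trailing
# n_padded - n_real padding tokens would not need to be computed.
cu_seqlens = torch.tensor([0, 1500, n_real], device="cuda", dtype=torch.int32)
max_seqlen = 1500

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# out has shape [n_padded, num_heads, head_dim]; rows n_real and onwards correspond to padding.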

tridao (Contributor) commented Sep 14, 2024

Yes, I think that should work. You should still test it, though.

imoneoi (Author) commented Sep 15, 2024

Thanks! I have tested the kernel and it does work. However, the padding elements of the output may be uninitialized, which can produce NaN/Inf in the forward and backward passes. Can we include a fix that simply zeros these elements?
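
For anyone else hitting this in the meantime, a possible workaround (a sketch, assuming the setup from my first comment above) is to zero the padded tail of the output manually after the call:

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# The kernel never writes the rows past cu_seqlens[-1], so they may hold
# uninitialized values (possibly NaN/Inf). Zero them before any downstream
# op that reduces over the full padded length.
out[cu_seqlens[-1]:] = 0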

imoneoi (Author) commented Sep 15, 2024

BTW, here is the code used for testing:

import torch

from tqdm import tqdm
from flash_attn import flash_attn_varlen_func


def test_flash_attn_padding(
    seed: int = 0,
    test_rounds: int = 10,
    num_heads: int = 8,
    head_size: int = 64,
    seq_len: int = 160,
    batch_size: int = 131_072,

    dtype: torch.dtype = torch.bfloat16
):
    torch.manual_seed(seed)
    torch.set_default_device("cuda")

    # Construct test data
    q = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    k = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    v = torch.randn((batch_size, num_heads, head_size), dtype=dtype)

    seqlens = torch.cat([
        torch.full((batch_size // seq_len, ), seq_len, dtype=torch.int32),
        torch.full((1, ), batch_size % seq_len, dtype=torch.int32)
    ])

    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=seqlens.dtype), (1, 0))
    max_seqlen = seqlens.max().item()  # pass a plain int to the varlen API

    # Run multiple rounds so that buffers allocated with torch.empty() are more likely to hold stale/garbage values
    for _ in tqdm(range(test_rounds)):
        # Fwd
        gt_out    = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)

        assert torch.allclose(nopad_out[:cu_seqlens[-2]], gt_out[:cu_seqlens[-2]])
        # assert torch.allclose(nopad_out[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Disabled: padded rows may be uninitialized, not necessarily zero

        # Bwd
        # ground truth
        dgrad = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()

        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * gt_out).sum().backward()

        gt_dq = q.grad
        gt_dk = k.grad
        gt_dv = v.grad
        q.grad = None
        k.grad = None
        v.grad = None

        # unpadded
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen) 
        (dgrad * nopad_out).sum().backward()

        assert torch.allclose(q.grad[:cu_seqlens[-2]], gt_dq[:cu_seqlens[-2]])
        # assert torch.allclose(q.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Disabled: padded rows may be uninitialized, not necessarily zero

        assert torch.allclose(k.grad[:cu_seqlens[-2]], gt_dk[:cu_seqlens[-2]])
        # assert torch.allclose(k.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Disabled: padded rows may be uninitialized, not necessarily zero

        assert torch.allclose(v.grad[:cu_seqlens[-2]], gt_dv[:cu_seqlens[-2]])
        # assert torch.allclose(v.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Disabled: padded rows may be uninitialized, not necessarily zero

        q.grad = None
        k.grad = None
        v.grad = None

if __name__ == "__main__":
    test_flash_attn_padding()
