Avoid padding computation with cu_seqlens
#1228
Comments
Yes, I think that should work. You should still test it, though.
Thanks! I have tested the kernel and it does work. However, the padding elements may be uninitialized, resulting in NaN/inf values in the forward and backward passes. Can we include a fix to simply zero these elements?
BTW, here is the code used for testing:

```python
from typing import Any

import torch
from tqdm import tqdm

from flash_attn import flash_attn_varlen_func


def test_flash_attn_padding(
    seed: int = 0,
    test_rounds: int = 10,
    num_heads: int = 8,
    head_size: int = 64,
    seq_len: int = 160,
    batch_size: int = 131_072,
    dtype: Any = torch.bfloat16,
):
    torch.manual_seed(seed)
    torch.set_default_device("cuda")
    # Construct testdata
    q = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    k = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    v = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    seqlens = torch.cat([
        torch.full((batch_size // seq_len,), seq_len, dtype=torch.int32),
        torch.full((1,), batch_size % seq_len, dtype=torch.int32),
    ])
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=seqlens.dtype), (1, 0))
    max_seqlen = seqlens.max()
    # Multiple rounds so that torch.empty() might be filled with random value
    for round in tqdm(range(test_rounds)):
        # Fwd
        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        assert torch.allclose(nopad_out[:cu_seqlens[-2]], gt_out[:cu_seqlens[-2]])
        # assert torch.allclose(nopad_out[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here
        # Bwd
        # ground truth
        dgrad = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()
        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * gt_out).sum().backward()
        gt_dq = q.grad
        gt_dk = k.grad
        gt_dv = v.grad
        q.grad = None
        k.grad = None
        v.grad = None
        # unpadded
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * nopad_out).sum().backward()
        assert torch.allclose(q.grad[:cu_seqlens[-2]], gt_dq[:cu_seqlens[-2]])
        # assert torch.allclose(q.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here
        assert torch.allclose(k.grad[:cu_seqlens[-2]], gt_dk[:cu_seqlens[-2]])
        # assert torch.allclose(k.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here
        assert torch.allclose(v.grad[:cu_seqlens[-2]], gt_dv[:cu_seqlens[-2]])
        # assert torch.allclose(v.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # Might be empty here
        q.grad = None
        k.grad = None
        v.grad = None


if __name__ == "__main__":
    test_flash_attn_padding()
```
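Until the kernel itself zero-fills those rows, one possible workaround is to mask the padding rows of the output after the call. The sketch below assumes the same setup as the test above (packed `[total_tokens, num_heads, head_size]` inputs, with `cu_seqlens` deliberately excluding the trailing padding); `varlen_attn_masked` is a hypothetical helper, not part of flash-attn:

```python
import torch
from flash_attn import flash_attn_varlen_func


def varlen_attn_masked(q, k, v, cu_seqlens, max_seqlen):
    # Run the varlen kernel with a cu_seqlens that stops short of the padding
    # tokens, then zero the padding rows of the output.
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
    )
    # Rows past the last valid token are never written by the kernel, so they
    # keep whatever torch.empty() left there. Mask them out-of-place rather
    # than in-place, since `out` is saved for the backward pass.
    valid = torch.arange(out.shape[0], device=out.device) < cu_seqlens[-1]
    return out * valid.view(-1, 1, 1).to(out.dtype)
```

Note this only cleans the forward output; the backward kernel likewise never writes `dq`/`dk`/`dv` for the padding rows, so those gradients may also need to be masked or zero-initialized before they are consumed.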
To work with `torch.compile`, which is more efficient on static shapes, I pad some tokens at the end to make the shape of `q`, `k`, `v` static, e.g. `[N, D]`. Can I set the last element of `cu_seqlens` in the varlen API to be less than `N` to avoid computing the padding? Also, is the backward pass accurate in this case?
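For reference, a minimal sketch of the setup being asked about (all shapes and numbers are illustrative, not taken from the issue): the real sequences are packed into a fixed-size buffer of `N` tokens, and `cu_seqlens` stops at the last real token, so its final entry is smaller than `N`:

```python
import torch

# Illustrative setup: N is the static token count chosen for torch.compile,
# seq_lens are the real (variable) sequence lengths, summing to less than N.
N = 4096
seq_lens = torch.tensor([1024, 1000, 900], dtype=torch.int32, device="cuda")

# cu_seqlens = [0, 1024, 2024, 2924]; its last entry (2924) is < N, so the
# final N - 2924 rows of q/k/v are padding the kernel would ideally skip.
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
max_seqlen = int(seq_lens.max())

# Static-shape packed tensors: [N, num_heads, head_size].
q = torch.randn(N, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# The varlen call would then look like (shown for context only):
# out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
# out[:cu_seqlens[-1]] holds the attention results; out[cu_seqlens[-1]:] is
# the padding region discussed in the comments above.
```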