is flash_attn_with_kvcache() supposed to work for seqlen > 1 ? #1402

Open

vince62s opened this issue Dec 20, 2024 · 0 comments
vince62s commented Dec 20, 2024

I am trying to figure out whether the output of flash_attn_with_kvcache() matches SDPA.

I am running this script without rotary embeddings first (the transposes are there because flash_attn_with_kvcache expects tensors in (batch, seqlen, nheads, headdim) layout):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from flash_attn import flash_attn_with_kvcache

batchsize=4
heads=16
dim_per_head = 16

for seqlen in [1, 2]:
    query = torch.rand((batchsize, heads, seqlen, dim_per_head), device=torch.device("cuda")).half()
    key = query
    value = query

    ###########  Compute attn with SDPA
    with sdpa_kernel([SDPBackend.MATH]):
        attn_output1 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output2 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        attn_output3 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        attn_output4 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    
    ########### flash kvcache
    kcache = torch.zeros_like(key)
    vcache = torch.zeros_like(value)

    attn_output5 = flash_attn_with_kvcache(
                        query.transpose(1, 2),
                        kcache.transpose(1, 2),
                        vcache.transpose(1, 2),
                        key.transpose(1, 2),
                        value.transpose(1, 2),
                        cache_seqlens=0,
                    ).transpose(1, 2)

    print("======>  results for sequence length: ", seqlen)
    for atol in [1e-3, 1e-4, 1e-5]:
        print("atol: ", atol)
        print("MATH vs EFFICIENT: ", torch.allclose(attn_output1, attn_output2, atol=atol))
        print("EFFICIENT vs SDPAFLASH: ", torch.allclose(attn_output2, attn_output3, atol=atol))
        print("MATH vs SDPAFLASH: ", torch.allclose(attn_output1, attn_output3, atol=atol))
        print("MATH vs CUDNN: ", torch.allclose(attn_output1, attn_output4, atol=atol))                
        print("EFFICIENT vs CUDNN: ", torch.allclose(attn_output2, attn_output4, atol=atol))
        print("SDPAFLASH vs CUDNN: ", torch.allclose(attn_output3, attn_output4, atol=atol))
        print("EFFICIENT vs FLASHCACHE: ", torch.allclose(attn_output2, attn_output5, atol=atol))
        print("SDPAFLASH vs FLASHCACHE: ", torch.allclose(attn_output3, attn_output5, atol=atol))
        print()
    print()                                

The output is as follows:

======>  results for sequence length:  1
atol:  0.001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  0.0001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  1e-05
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True


======>  results for sequence length:  2
atol:  0.001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  0.0001
MATH vs EFFICIENT:  False
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  False
MATH vs CUDNN:  False
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  1e-05
MATH vs EFFICIENT:  False
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  False
MATH vs CUDNN:  False
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

We can see that for seqlen=1 all outputs match down to atol 1e-5 (and beyond).
When we switch to seqlen=2 (or more):
at atol 1e-3, everything matches;
starting at atol 1e-4, everything matches except MATH, probably because MATH computes in float32.
If I switch the inputs to float32, then MATH and EFFICIENT match.

Since SDPAFLASH and FLASHCACHE match, let's say it's just a precision issue.
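
For reference, the float32 check I mention is just a sketch along these lines (reusing the last query from the loop above; the variable names here are mine, not from the script):

query32 = query.float()
with sdpa_kernel([SDPBackend.MATH]):
    out_math = scaled_dot_product_attention(query32, query32, query32, None, 0.0, is_causal=False)
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out_eff = scaled_dot_product_attention(query32, query32, query32, None, 0.0, is_causal=False)
# with float32 inputs, MATH and EFFICIENT match again
print("MATH vs EFFICIENT (float32): ", torch.allclose(out_math, out_eff, atol=1e-5))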

NOW, I am trying to do the same with rotary embeddings, using the following code:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from flash_attn import flash_attn_with_kvcache

maxseqlen=512
batchsize=2
heads=16
dim_per_head = 16

rotary_dim = 16
rotary_theta = 10000
inv_freq = 1.0 / (
    rotary_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
).to(torch.device("cuda"))
tmax = torch.arange(maxseqlen, device=torch.device("cuda"))
rope = torch.outer(tmax, inv_freq)

cos = torch.cos(rope)
sin = torch.sin(rope)
cos = torch.cat((cos, cos), dim=-1)  # Double the size by repeating `cos`
sin = torch.cat((sin, sin), dim=-1)  # Double the size by repeating `sin`
position_embeddings = (cos, sin)

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
    
def apply_rotary_emb(query, key, rope):
    # now rope is a tuple (cos, sin)
    cos, sin = rope
    q_embed = (query * cos) + (rotate_half(query) * sin)
    k_embed = (key * cos) + (rotate_half(key) * sin)
    return q_embed.type_as(query), k_embed.type_as(key)


for seqlen in [1, 2]:
    query = torch.rand((batchsize, heads, seqlen, dim_per_head), device=torch.device("cuda")).half()
    key = query
    value = query

    ###########  Compute attn with SDPA
    start_pos = 0
    cos = position_embeddings[0][start_pos : start_pos + seqlen].to(query.dtype)
    sin = position_embeddings[1][start_pos : start_pos + seqlen].to(query.dtype)
    query, key = apply_rotary_emb(query, key, (cos, sin))

    with sdpa_kernel([SDPBackend.MATH]):
        attn_output1 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
        attn_output2 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
        attn_output3 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
        attn_output4 = scaled_dot_product_attention(query, key, value, None, 0.0, is_causal=False)
    
    ########### flash kvcache
    kcache = torch.zeros_like(key)
    vcache = torch.zeros_like(value)

    # flash expects dim1 of size half the rotary_dim
    cos = position_embeddings[0][:, : cos.size(1) // 2].to(query.dtype)
    sin = position_embeddings[1][:, : sin.size(1) // 2].to(query.dtype)

    attn_output5 = flash_attn_with_kvcache(
                        query.transpose(1, 2),
                        kcache.transpose(1, 2),
                        vcache.transpose(1, 2),
                        key.transpose(1, 2),
                        value.transpose(1, 2),
                        rotary_cos=cos,
                        rotary_sin=sin,
                        cache_seqlens=0,
                        rotary_interleaved=False,
                    ).transpose(1, 2)

    print("results for sequence length: ", seqlen)
    for atol in [1e-3, 1e-4, 1e-5]:
        print("atol: ", atol)
        print("MATH vs EFFICIENT: ", torch.allclose(attn_output1, attn_output2, atol=atol))
        print("EFFICIENT vs SDPAFLASH: ", torch.allclose(attn_output2, attn_output3, atol=atol))
        print("MATH vs SDPAFLASH: ", torch.allclose(attn_output1, attn_output3, atol=atol))
        print("MATH vs CUDNN: ", torch.allclose(attn_output1, attn_output4, atol=atol))                
        print("EFFICIENT vs CUDNN: ", torch.allclose(attn_output2, attn_output4, atol=atol))
        print("SDPAFLASH vs CUDNN: ", torch.allclose(attn_output3, attn_output4, atol=atol))
        print("EFFICIENT vs FLASHCACHE: ", torch.allclose(attn_output2, attn_output5, atol=atol))
        print("SDPAFLASH vs FLASHCACHE: ", torch.allclose(attn_output3, attn_output5, atol=atol))
        print()
    print()                                

The output is:

results for sequence length:  1
atol:  0.001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  0.0001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True

atol:  1e-05
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
EFFICIENT vs FLASHCACHE:  True
SDPAFLASH vs FLASHCACHE:  True


results for sequence length:  2
atol:  0.001
MATH vs EFFICIENT:  True
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  True
MATH vs CUDNN:  True
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
**EFFICIENT vs FLASHCACHE:  False**
**SDPAFLASH vs FLASHCACHE:  False**

atol:  0.0001
MATH vs EFFICIENT:  False
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  False
MATH vs CUDNN:  False
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
**EFFICIENT vs FLASHCACHE:  False**
**SDPAFLASH vs FLASHCACHE:  False**

atol:  1e-05
MATH vs EFFICIENT:  False
EFFICIENT vs SDPAFLASH:  True
MATH vs SDPAFLASH:  False
MATH vs CUDNN:  False
EFFICIENT vs CUDNN:  True
SDPAFLASH vs CUDNN:  True
**EFFICIENT vs FLASHCACHE:  False**
**SDPAFLASH vs FLASHCACHE:  False**

Again, for seqlen=1 everything is fine.
For seqlen=2 (or more), FLASHCACHE no longer matches SDPAFLASH (or EFFICIENT).
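
As an extra data point, here is the variant I would try next (a sketch only, I have not verified it): since query and key above already have rotary applied via apply_rotary_emb before the SDPA calls, calling flash_attn_with_kvcache without rotary_cos/rotary_sin should reduce to the first (no-rotary) setup, and I would expect it to match again. If it does, the divergence would seem to come from how the kernel applies rotary to the new tokens when seqlen > 1. The kcache2/vcache2 names are just mine:

kcache2 = torch.zeros_like(key)
vcache2 = torch.zeros_like(value)

attn_output6 = flash_attn_with_kvcache(
                    query.transpose(1, 2),   # rotary already applied above
                    kcache2.transpose(1, 2),
                    vcache2.transpose(1, 2),
                    key.transpose(1, 2),     # rotary already applied above
                    value.transpose(1, 2),
                    cache_seqlens=0,         # no rotary_cos/rotary_sin, so the kernel rotates nothing
                ).transpose(1, 2)

print("SDPAFLASH vs FLASHCACHE (rotary applied outside): ",
      torch.allclose(attn_output3, attn_output6, atol=1e-3))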
