[FA3][Varlen] bug for head_dim not in [64, 128, 256] for varlen #1214

Open
YLGH opened this issue Sep 9, 2024 · 0 comments

YLGH commented Sep 9, 2024

Hi,

I'm not sure whether head_dim values outside [64, 128, 256] are fully supported yet, but I wanted to flag what looks like a bug for other head dims (I tested 192 below). FA3 seems to fail to write out the prefix of the output values for some heads.

Below I'm comparing the outputs of FA2 and FA3:

import torch
with torch.device("cuda"):
    q_bs = 3
    kv_bs = 7
    num_heads = 16
    head_dim = 192
    q = torch.randn((q_bs, num_heads, head_dim), dtype=torch.half)
    k = torch.randn((kv_bs, num_heads, head_dim), dtype=torch.half)
    v = torch.randn((kv_bs, num_heads, head_dim), dtype=torch.half)
    
    cu_seqlens_q = torch.tensor([0,q_bs], dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0,kv_bs], dtype=torch.int32)
    
    
    # FA2 varlen forward (reference)
    from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
    fa2_out = _flash_attn_varlen_forward(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_bs,
        kv_bs,
        dropout_p=0.0,
        softmax_scale=0.1147213876247406,
        causal=True,
        window_size=(-1, -1),
        alibi_slopes=None,
        return_softmax=False,
    )[0]
    
    # FA3 varlen forward
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3


    fa3_out = flash_attn_varlen_func_v3(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        q_bs,
        kv_bs,
        causal=True,
        softmax_scale=0.1147213876247406
    )[0]
    print(fa2_out[-1][0])
    """
tensor([-2.4658e-01,  8.4082e-01,  1.8982e-01,  1.8115e-01,  6.6602e-01,
         7.1924e-01, -1.5283e+00, -2.4158e-01,  9.7119e-01,  1.0834e-01,
         9.1980e-02, -6.3770e-01,  1.2197e+00, -1.3984e+00, -6.7285e-01,
        -1.9019e-01, -1.3926e+00,  2.6306e-02, -8.2581e-02, -2.0850e-01,
         5.9912e-01, -3.3252e-01,  2.6840e-02,  2.0508e-01, -5.4092e-03,
        -3.1226e-01, -2.1460e-01, -2.2461e-01,  9.2285e-02,  3.8916e-01,
        -8.7061e-01,  1.2091e-01, -9.3689e-02, -4.8309e-02, -3.8647e-01,
        -4.1113e-01, -9.0271e-02,  5.0293e-01, -6.0840e-01,  2.7368e-01,
        -3.0859e-01,  1.0144e-01, -4.8706e-01,  9.1406e-01,  2.7515e-01,
         6.9641e-02, -2.2620e-01, -5.3613e-01,  3.6792e-01, -4.2407e-01,
        -9.0088e-02, -3.2861e-01,  1.7932e-01,  2.3120e-01,  8.8477e-01,
        -1.4709e-01,  6.1084e-01, -5.1953e-01,  6.9385e-01, -1.5735e-01,
        -3.6377e-01, -1.8970e-01, -3.2080e-01, -1.9434e-01, -1.1310e-01,
         2.2168e-01, -3.5791e-01, -8.8623e-01, -5.9375e-01, -8.3557e-02,
         3.8062e-01,  5.7715e-01, -3.1641e-01, -4.7363e-01, -9.1248e-02,
         8.4766e-01,  3.9160e-01,  1.4526e-01, -2.9639e-01, -3.5327e-01,
         1.1279e-01,  1.3904e-01,  3.8330e-01,  4.7656e-01,  1.0566e+00,
         7.1631e-01,  2.4792e-01,  1.0342e+00, -5.9521e-01, -5.9717e-01,
        -7.6485e-03,  1.2549e+00, -7.0996e-01, -1.2830e-01,  7.7979e-01,
         1.3168e-02, -1.0809e-01, -1.3008e+00,  2.3877e-01, -9.2627e-01,
         6.6846e-01, -4.2065e-01,  2.9224e-01, -2.6904e-01,  9.0625e-01,
         2.9932e-01,  2.6318e-01,  4.5020e-01, -3.9648e-01, -2.3486e-01,
         2.2644e-02, -4.0503e-01, -1.0664e+00, -1.0583e-01,  1.1416e+00,
         4.1626e-01,  3.9697e-01, -8.1445e-01, -1.4014e-01, -7.9346e-01,
         1.9031e-01, -2.7866e-03,  1.2549e-01, -1.1240e+00, -2.0752e-01,
        -1.7932e-01,  4.5898e-01,  2.3865e-01,  2.2247e-02, -2.1350e-01,
        -2.7374e-02, -3.4521e-01,  2.4194e-01, -1.2741e-02, -3.7085e-01,
         4.3799e-01,  5.4492e-01,  2.2791e-01, -1.6296e-02,  2.7393e-01,
         2.0703e-01, -1.6577e-01, -5.5075e-05, -1.2227e+00,  7.6050e-02,
         9.3555e-01,  2.6636e-01,  3.2642e-01,  1.2964e-01,  6.9580e-02,
         4.1943e-01,  3.1055e-01,  2.3206e-01,  2.2314e-01,  1.7981e-01,
        -4.3628e-01, -1.0730e-01, -1.1066e-01,  3.5352e-01,  9.9182e-02,
         1.5808e-02, -2.3914e-01,  1.2598e-01,  2.9028e-01, -3.8354e-01,
        -1.8359e-01,  1.7273e-01, -2.4719e-01, -4.4043e-01, -3.0322e-01,
         1.9592e-01, -3.4912e-02, -5.0244e-01,  4.8218e-01, -1.6541e-01,
        -3.0298e-01, -4.3848e-01, -1.7334e-01,  2.5757e-01, -7.0923e-02,
         3.0444e-01, -1.4526e-02, -4.8413e-01, -1.0950e-01, -3.3032e-01,
        -7.0215e-01,  1.3477e-01,  1.0381e+00, -7.3318e-03, -4.7583e-01,
        -4.3848e-01, -3.8794e-01], device='cuda:0', dtype=torch.float16)
    """
    print(fa3_out[-1][0])
    """
tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.1310e-01,
         2.2168e-01, -3.5791e-01, -8.8623e-01, -5.9375e-01, -8.3557e-02,
         3.8062e-01,  5.7715e-01, -3.1641e-01, -4.7363e-01, -9.1248e-02,
         8.4766e-01,  3.9160e-01,  1.4526e-01, -2.9639e-01, -3.5327e-01,
         1.1279e-01,  1.3904e-01,  3.8330e-01,  4.7656e-01,  1.0566e+00,
         7.1631e-01,  2.4792e-01,  1.0342e+00, -5.9521e-01, -5.9717e-01,
        -7.6485e-03,  1.2549e+00, -7.0996e-01, -1.2830e-01,  7.7979e-01,
         1.3168e-02, -1.0809e-01, -1.3008e+00,  2.3877e-01, -9.2627e-01,
         6.6846e-01, -4.2065e-01,  2.9224e-01, -2.6904e-01,  9.0625e-01,
         2.9932e-01,  2.6318e-01,  4.5020e-01, -3.9648e-01, -2.3486e-01,
         2.2644e-02, -4.0503e-01, -1.0664e+00, -1.0583e-01,  1.1416e+00,
         4.1626e-01,  3.9697e-01, -8.1445e-01, -1.4014e-01, -7.9346e-01,
         1.9031e-01, -2.7866e-03,  1.2549e-01, -1.1240e+00, -2.0752e-01,
        -1.7932e-01,  4.5898e-01,  2.3865e-01,  2.2247e-02, -2.1350e-01,
        -2.7374e-02, -3.4521e-01,  2.4194e-01, -1.2741e-02, -3.7085e-01,
         4.3799e-01,  5.4492e-01,  2.2791e-01, -1.6296e-02,  2.7393e-01,
         2.0703e-01, -1.6577e-01, -5.5075e-05, -1.2227e+00,  7.6050e-02,
         9.3555e-01,  2.6636e-01,  3.2642e-01,  1.2964e-01,  6.9580e-02,
         4.1943e-01,  3.1055e-01,  2.3206e-01,  2.2314e-01,  1.7981e-01,
        -4.3628e-01, -1.0730e-01, -1.1066e-01,  3.5352e-01,  9.9182e-02,
         1.5808e-02, -2.3914e-01,  1.2598e-01,  2.9028e-01, -3.8354e-01,
        -1.8359e-01,  1.7273e-01, -2.4719e-01, -4.4043e-01, -3.0322e-01,
         1.9592e-01, -3.4912e-02, -5.0244e-01,  4.8218e-01, -1.6541e-01,
        -3.0298e-01, -4.3848e-01, -1.7334e-01,  2.5757e-01, -7.0923e-02,
         3.0444e-01, -1.4526e-02, -4.8413e-01, -1.0950e-01, -3.3032e-01,
        -7.0215e-01,  1.3477e-01,  1.0381e+00, -7.3318e-03, -4.7583e-01,
        -4.3848e-01, -3.8794e-01], device='cuda:0', dtype=torch.float16)
    """
    torch.testing.assert_close(fa2_out, fa3_out)
    # ^ above fails
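
From the dumps above, the first 64 of the 192 head-dim entries in the FA3 output are zero while the remaining 128 match FA2, so it looks like the leading slice of the head dimension is never written out (just a guess on my part as to the cause). A minimal sketch to localize the mismatch, assuming fa2_out, fa3_out, and head_dim from the repro above are still in scope:

# Hedged sketch: find which head-dim columns disagree anywhere across tokens/heads.
mismatch = ~torch.isclose(fa2_out.float(), fa3_out.float(), rtol=1e-3, atol=1e-3)
bad_cols = mismatch.reshape(-1, head_dim).any(dim=0)
# For the affected heads this flags the leading columns (0..63 in the dump above).
print(bad_cols.nonzero().flatten())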
   