
Flash attention 3 does not use Dropout_p? #1377

Open
nighting0le01 opened this issue Dec 9, 2024 · 6 comments

Comments

@nighting0le01

Hi, I was trying to train a model by swapping out FA2 (SDPA) for FA3; however, FA3 does not seem to take dropout_p?
reference:

def flash_attn_func(

Also, the speedup I get in the forward pass is only around 10-20%.
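
For context, a minimal sketch of how that forward speedup could be measured; the FA3 import path (flash_attn_interface) and its keyword arguments are assumptions here and may differ between builds:

```python
import torch
from flash_attn import flash_attn_func as flash_attn_func_v2            # FA2
from flash_attn_interface import flash_attn_func as flash_attn_func_v3  # assumed FA3 import path

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 4, 4096, 16, 128
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

def time_fwd(fn, iters=50):
    # Warm up, then time the forward pass with CUDA events (milliseconds per call).
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

t2 = time_fwd(lambda: flash_attn_func_v2(q, k, v, causal=True))
t3 = time_fwd(lambda: flash_attn_func_v3(q, k, v, causal=True))
print(f"FA2: {t2:.3f} ms/iter  FA3: {t3:.3f} ms/iter  speedup: {t2 / t3:.2f}x")
```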

@nighting0le01
Author

@tridao could you please confirm why dropout_p is not included in FA3?

@nighting0le01
Author

def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Arguments:
    q: (batch_size, seqlen, nheads, headdim)
    k: (batch_size, seqlen, nheads_k, headdim)
    v: (batch_size, seqlen, nheads_k, headdim)
    dropout_p: float. Dropout probability.
    softmax_scale: float. The scaling of QK^T before applying softmax.
        Default to 1 / sqrt(headdim).
    causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
    window_size: (left, right). If not (-1, -1), implements sliding window local attention.
    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
        is added to the attention score of query i and key j.
    deterministic: bool. Whether to use the deterministic implementation of the backward pass,
        which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
    out: (batch_size, seqlen, nheads, headdim).
"""

The docs in the README say dropout can be passed?
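
For reference, a quick usage sketch based on the FA2 docstring quoted above, showing grouped-query attention (6 query heads over 2 KV heads) with a causal mask and a sliding window; the shapes and values are illustrative:

```python
import torch
from flash_attn import flash_attn_func  # FA2

batch, seqlen, headdim = 2, 1024, 64
# 6 query heads, 2 KV heads: heads 0-2 of Q attend to KV head 0, heads 3-5 to KV head 1.
q = torch.randn(batch, seqlen, 6, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,         # attention dropout; set to 0.0 during evaluation
    causal=True,           # auto-regressive masking
    window_size=(128, 0),  # sliding window: each query sees at most the 128 previous keys
)
print(out.shape)  # (batch, seqlen, 6, headdim): output keeps the query head count
```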

@tridao
Contributor

tridao commented Dec 9, 2024

dropout is not supported in FA3
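
One possible workaround, sketched under the assumption that FA3's flash_attn_func exposes no dropout_p argument: keep attention dropout by falling back to FA2 whenever dropout is active, and use FA3 otherwise. The FA3 import path below is an assumption and may differ in your install.

```python
from flash_attn import flash_attn_func as flash_attn_func_v2            # FA2: accepts dropout_p
from flash_attn_interface import flash_attn_func as flash_attn_func_v3  # assumed FA3 import path: no dropout_p

def attention(q, k, v, dropout_p=0.0, causal=True, training=False):
    # Dropout happens inside the fused kernel, so it cannot be bolted on after FA3;
    # fall back to FA2 whenever attention dropout is actually needed.
    if training and dropout_p > 0.0:
        return flash_attn_func_v2(q, k, v, dropout_p=dropout_p, causal=causal)
    out = flash_attn_func_v3(q, k, v, causal=causal)
    return out[0] if isinstance(out, tuple) else out  # some FA3 builds also return the softmax LSE
```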

@nighting0le01
Author

@tridao I tried swapping FA2 (SDPA) for FA3 in a model, and I see a numerical mismatch in the results.

@nighting0le01
Author

Do you have any tests that confirm numerical equivalence?

@tridao
Contributor

tridao commented Dec 10, 2024

We have tests. If there's a mismatch, please help us by providing a short script to reproduce the error.
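
A repro along those lines might look like the sketch below. It assumes the FA3 build exposes flash_attn_interface.flash_attn_func and compares against torch's SDPA in bf16, where only rounding-level differences are expected rather than exact equality:

```python
import torch
from flash_attn_interface import flash_attn_func  # assumed FA3 import path

torch.manual_seed(0)
batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# FA3 uses (batch, seqlen, nheads, headdim) layout and has no dropout_p argument.
out_fa3 = flash_attn_func(q, k, v, causal=True)
if isinstance(out_fa3, tuple):  # some FA3 builds also return the softmax LSE
    out_fa3 = out_fa3[0]

# SDPA reference expects (batch, nheads, seqlen, headdim), so transpose in and out.
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

# Expect differences on the order of bf16 rounding error, not bit-exact equality.
print((out_fa3 - out_ref).abs().max().item())
```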
