In the conventional layout, the shape of $Q$ is $(batch, num\_heads, seqlen\_q, head\_dim)$ and the shapes of $K$ and $V$ are $(batch, num\_kv\_heads, seqlen\_kv, head\_dim)$.
In this operator, however, the shape of $Q$ is $(batch, seqlen\_q, num\_heads, head\_dim)$ and the shapes of $K$ and $V$ are $(batch, seqlen\_kv, num\_kv\_heads, head\_dim)$, so a transpose is required before applying attention.
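As an illustration, here is a minimal sketch of that transpose (assuming PyTorch; the tensor names `q`, `k`, `v` and the concrete sizes are hypothetical):

```python
import torch

batch, seqlen_q, seqlen_kv = 2, 8, 8
num_heads, num_kv_heads, head_dim = 16, 4, 64

# This operator's layout: (batch, seqlen, heads, head_dim)
q = torch.randn(batch, seqlen_q, num_heads, head_dim)
k = torch.randn(batch, seqlen_kv, num_kv_heads, head_dim)
v = torch.randn(batch, seqlen_kv, num_kv_heads, head_dim)

# Conventional layout: (batch, heads, seqlen, head_dim)
q_t = q.transpose(1, 2)  # (batch, num_heads, seqlen_q, head_dim)
k_t = k.transpose(1, 2)  # (batch, num_kv_heads, seqlen_kv, head_dim)
v_t = v.transpose(1, 2)  # (batch, num_kv_heads, seqlen_kv, head_dim)
```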
Attributes/Parameters
num_heads: int
Number of attention (query) heads.
head_dim: int
Dimension of each head, where $head\_dim \times num\_heads = hidden\_dim$.
is_causal: bool
Whether to apply a causal mask when the sequence length is greater than 1.
is_alibi: bool(default: False)
Whether to apply the ALiBi mask within the operator. When it is True, the ALiBi mask does not need to be set in attn_mask.
num_kv_heads: int(default: 0)
For Grouped-Query Attention. If num_kv_heads and num_heads are not equal, key and value are repeated $num\_heads / num\_kv\_heads$ times before applying ${\rm MHA}$ for each token (see the sketch after this parameter list). num_heads must be divisible by num_kv_heads. The default is 0, in which case num_heads is used as num_kv_heads.
attn_mask: tensor(optional)
Optional custom mask; see the broadcasting sketch below. If the shape is $(seqlen\_q, \ge seqlen\_kv)$, attn_mask will be broadcast.
Note: the last dimension of the mask can be larger than $seqlen\_kv$, because some flash attention implementations may force it to be aligned to a specific padded size.
Shape: $(seqlen\_q, \ge seqlen\_kv)$, $(num\_heads, seqlen\_q, \ge seqlen\_kv)$, or $(batch, num\_heads, seqlen\_q, \ge seqlen\_kv)$
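For the num_kv_heads parameter above, here is a minimal sketch (assuming PyTorch; all tensor names and sizes are hypothetical) of the key/value repetition used for Grouped-Query Attention:

```python
import torch

batch, seqlen_kv, head_dim = 2, 8, 64
num_heads, num_kv_heads = 16, 4
assert num_heads % num_kv_heads == 0  # required by the operator
group_size = num_heads // num_kv_heads

k = torch.randn(batch, seqlen_kv, num_kv_heads, head_dim)
v = torch.randn(batch, seqlen_kv, num_kv_heads, head_dim)

# Repeat each kv head num_heads / num_kv_heads times along the head axis,
# so that K and V match the num_heads query heads before applying MHA.
k_rep = k.repeat_interleave(group_size, dim=2)  # (batch, seqlen_kv, num_heads, head_dim)
v_rep = v.repeat_interleave(group_size, dim=2)  # (batch, seqlen_kv, num_heads, head_dim)
```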
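And here is a minimal sketch (again assuming PyTorch; all names are hypothetical) of how a 2-D attn_mask of shape $(seqlen\_q, seqlen\_kv)$ broadcasts against per-batch, per-head attention scores, using a causal additive mask as the example:

```python
import torch

batch, num_heads, seqlen_q, seqlen_kv = 2, 16, 8, 8

# A causal additive mask, as implied by is_causal=True:
# positions with j > i receive -inf and are masked out by softmax.
causal = torch.triu(torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool), diagonal=1)
attn_mask = torch.zeros(seqlen_q, seqlen_kv).masked_fill(causal, float("-inf"))

scores = torch.randn(batch, num_heads, seqlen_q, seqlen_kv)
# (seqlen_q, seqlen_kv) broadcasts to (batch, num_heads, seqlen_q, seqlen_kv).
probs = (scores + attn_mask).softmax(dim=-1)
```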