The original definition of `MultiHeadAttention` can be found here.
For `dynamic_batching.MultiHeadAttention`, the sequence length and the key/value length can differ from batch to batch.
Because dynamic batching combines decoding attention and first-fill (prefill) attention in a single call, extra information must be passed to split the input into two parts via `decoding_batches`.

The first part, at the front, is the decoding part; its size equals `decoding_batches`, and a causal mask is never applied to it. The remaining batches are the first-fill part, to which a causal mask is applied if `is_causal` is `True`.
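To make the split concrete, here is a minimal sketch (the helper name is made up for illustration, not part of the operator's API) of how the causal-mask policy depends on the batch index:

```python
def needs_causal_mask(batch_idx: int, decoding_batches: int, is_causal: bool) -> bool:
    """Illustrative helper only.

    Batches [0, decoding_batches) are decoding batches and never get a
    causal mask; the remaining first-fill batches get one when is_causal
    is True.
    """
    is_decoding = batch_idx < decoding_batches
    return is_causal and not is_decoding
```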
Number of heads
Dimension of each head, where the hidden dimension equals `num_heads * head_dim`.
Whether to apply a causal mask when the sequence length is greater than 1.
Whether to apply the ALiBi mask inside the operator. There is no need to set an ALiBi mask in `attn_mask` when this is `True`.
For Grouped-Query Attention. If `num_kv_heads` and `num_heads` are not equal, key and value are repeated `num_heads/num_kv_heads` times before attention is applied. `num_heads` must be divisible by `num_kv_heads`. The default is 0, in which case `num_heads` is used as `num_kv_heads`.
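As an illustration only, the key/value repetition for Grouped-Query Attention might look like the following PyTorch sketch (the packed layout with heads on dim 1 is an assumption for this example):

```python
import torch

def repeat_kv(key: torch.Tensor, value: torch.Tensor,
              num_heads: int, num_kv_heads: int):
    """Repeat key/value heads so they match the number of query heads.

    Assumed layout for illustration: (total_kv_tokens, num_kv_heads, head_dim).
    """
    assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
    group_size = num_heads // num_kv_heads
    if group_size == 1:
        return key, value  # MHA case: nothing to repeat
    # Each kv head serves `group_size` consecutive query heads.
    key = torch.repeat_interleave(key, group_size, dim=1)
    value = torch.repeat_interleave(value, group_size, dim=1)
    return key, value
```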
Input Query tensor
Shape:
Input Key tensor
Shape:
Input Value tensor
Shape:
`seqstarts[:B]` contains the position of the first token in `query` for each batch, and `seqstarts[B]` contains the total length of `query`.

Note that `seqstarts[b+1]-seqstarts[b]` gives the sequence length of batch `b`.
Shape:
`kvstarts[:B]` contains the position of the first token in `key` and `value` for each batch, and `kvstarts[B]` contains the total length of `key` and `value`.

Note that `kvstarts[b+1]-kvstarts[b]` gives the key/value length of batch `b`.
Shape:
Describes how many batches at the front are being decoded; these batches do not need a causal mask.
Maximum sequence length of `query`, equal to `max(seqstarts[1:]-seqstarts[:B])`. Used for parallel computing.
Maximum sequence length of `key` and `value`, equal to `max(kvstarts[1:]-kvstarts[:B])`. Used for parallel computing.
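To tie `seqstarts`, `kvstarts`, `max_seqlen`, and `max_kvlen` together, here is a small sketch; the per-batch lengths below are made-up example values:

```python
import torch

# Made-up example: two decoding batches (query length 1) followed by
# one first-fill batch (query length 5).
seqlens = torch.tensor([1, 1, 5], dtype=torch.int64)   # query length per batch
kvlens = torch.tensor([9, 17, 5], dtype=torch.int64)   # key/value length per batch
B = len(seqlens)

# Prefix sums with a leading zero give the start offsets; the last element
# is the total token count over all batches.
seqstarts = torch.zeros(B + 1, dtype=torch.int64)
seqstarts[1:] = torch.cumsum(seqlens, dim=0)
kvstarts = torch.zeros(B + 1, dtype=torch.int64)
kvstarts[1:] = torch.cumsum(kvlens, dim=0)

max_seqlen = int((seqstarts[1:] - seqstarts[:B]).max())  # == int(seqlens.max())
max_kvlen = int((kvstarts[1:] - kvstarts[:B]).max())     # == int(kvlens.max())

# Tokens of batch b live in query[seqstarts[b]:seqstarts[b+1]] and in
# key/value[kvstarts[b]:kvstarts[b+1]].
```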
Optional custom mask. `seqlens=seqstarts[1:]-seqstarts[:B]` is a sequence containing the length of `query` for each batch. `kvlens=kvstarts[1:]-kvstarts[:B]` is a sequence containing the length of `key` and `value` for each batch.
Note: the last dim of the mask can be bigger than the total key/value length.
Shape:
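As a sketch only, assuming the mask is packed with rows indexed by the concatenated query tokens and columns by the concatenated key/value tokens (this layout is an assumption for the example, not taken from the definition above), a per-batch mask could be assembled like this:

```python
import torch

def build_packed_mask(seqstarts, kvstarts, per_batch_masks, dtype=torch.float32):
    """Assemble per-batch additive masks into one packed mask (illustrative only).

    per_batch_masks[b] is assumed to have shape (seqlens[b], kvlens[b]).
    Cross-batch positions are filled with -inf so batches cannot attend
    to each other.
    """
    total_q, total_kv = int(seqstarts[-1]), int(kvstarts[-1])
    mask = torch.full((total_q, total_kv), float("-inf"), dtype=dtype)
    B = len(seqstarts) - 1
    for b in range(B):
        q0, q1 = int(seqstarts[b]), int(seqstarts[b + 1])
        k0, k1 = int(kvstarts[b]), int(kvstarts[b + 1])
        mask[q0:q1, k0:k1] = per_batch_masks[b].to(dtype)
    return mask
```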
Output feature of the attention result.
Shape:
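Finally, a naive per-batch reference in PyTorch can help tie the inputs above together. This is a sketch under assumed layouts (`query` packed as `(seqstarts[B], num_heads, head_dim)`, `key`/`value` as `(kvstarts[B], num_kv_heads, head_dim)`, and an optional packed 2-D `attn_mask`); ALiBi is not handled here, and this is not the operator's optimized implementation:

```python
import torch
import torch.nn.functional as F

def mha_dynamic_batching_reference(query, key, value, seqstarts, kvstarts,
                                   decoding_batches, is_causal,
                                   num_heads, head_dim, num_kv_heads=0,
                                   attn_mask=None):
    """Naive per-batch reference; tensor layouts are assumptions for illustration."""
    num_kv_heads = num_kv_heads or num_heads   # default 0 -> use num_heads
    group = num_heads // num_kv_heads
    B = len(seqstarts) - 1
    out = torch.empty_like(query)
    for b in range(B):
        q0, q1 = int(seqstarts[b]), int(seqstarts[b + 1])
        k0, k1 = int(kvstarts[b]), int(kvstarts[b + 1])
        # (num_heads, q_len, head_dim) and (num_heads, kv_len, head_dim)
        q = query[q0:q1].transpose(0, 1)
        k = key[k0:k1].repeat_interleave(group, dim=1).transpose(0, 1)
        v = value[k0:k1].repeat_interleave(group, dim=1).transpose(0, 1)
        scores = q @ k.transpose(-1, -2) / head_dim ** 0.5
        if attn_mask is not None:
            scores = scores + attn_mask[q0:q1, k0:k1]
        # Only first-fill batches (index >= decoding_batches) get the causal mask.
        q_len, kv_len = q1 - q0, k1 - k0
        if is_causal and b >= decoding_batches and q_len > 1:
            # Align the causal diagonal so the last query token sees all kv tokens.
            causal = torch.ones(q_len, kv_len).tril(kv_len - q_len).bool()
            scores = scores.masked_fill(~causal, float("-inf"))
        probs = F.softmax(scores, dim=-1)
        out[q0:q1] = (probs @ v).transpose(0, 1)
    return out
```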