The original definition of `dynamic_batching.MultiHeadAttention` refers to here.
The original definition of `dynamic_batching.KeyValueCache` refers to here.
`dynamic_batching.MultiHeadCacheAttention` simply fuses `dynamic_batching.MultiHeadAttention` and `dynamic_batching.KeyValueCache` together.
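As a rough, non-authoritative sketch of the fused semantics, the following PyTorch snippet handles a single batch element with an unquantized, single-layer cache and `num_kv_heads == num_heads`; all function and variable names here are illustrative, not part of the operator.

```python
import torch

def cache_attention_sketch(query, current_key, current_value,
                           k_cache, v_cache, start_pos, is_causal=True):
    """query/current_key/current_value: (seqlen, num_heads, head_dim).
    k_cache/v_cache: (max_tokens, num_heads, head_dim) for one layer."""
    seqlen = query.shape[0]
    # KeyValueCache part: store current key/value at start_pos (in place).
    k_cache[start_pos:start_pos + seqlen] = current_key
    v_cache[start_pos:start_pos + seqlen] = current_value
    # Gather the full key/value (past + current) back from the cache.
    key = k_cache[:start_pos + seqlen]
    value = v_cache[:start_pos + seqlen]
    # MultiHeadAttention part: per-head scaled dot-product attention.
    q = query.permute(1, 0, 2)                       # (H, S, Dh)
    k = key.permute(1, 0, 2)                         # (H, T, Dh)
    v = value.permute(1, 0, 2)                       # (H, T, Dh)
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    if is_causal and seqlen > 1:
        # Each query token may attend to all past tokens plus itself.
        allowed = torch.ones(seqlen, key.shape[0]).tril(
            diagonal=key.shape[0] - seqlen).bool()
        scores = scores.masked_fill(~allowed, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v          # (H, S, Dh)
    return out.permute(1, 0, 2)                      # (S, H, Dh)
```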
Number of heads.
Dimension of each head, where $num\_heads * head\_dim = hidden\_dim$.
Whether to apply a causal mask when sequence length > 1.
Whether to apply the ALiBi mask within the operator. There is no need to set the ALiBi mask in `attn_mask` when this is `True`.
For Grouped-Query Attention. If `num_kv_heads` and `num_heads` are not equal, key and value should be repeated `num_heads/num_kv_heads` times before applying attention, so `num_heads` must be divisible by `num_kv_heads`. Default is 0, in which case `num_heads` is used as `num_kv_heads`.
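A minimal sketch of this head expansion, assuming tensors shaped `(total_tokens, num_kv_heads, head_dim)`; the helper name `repeat_kv` is illustrative, not part of the operator.

```python
import torch

def repeat_kv(key: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Repeats each kv head num_heads // num_kv_heads times so key/value
    line up with the query heads. key: (total_tokens, num_kv_heads, head_dim)."""
    total_tokens, num_kv_heads, head_dim = key.shape
    assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
    group = num_heads // num_kv_heads
    if group == 1:
        return key
    key = key[:, :, None, :].expand(total_tokens, num_kv_heads, group, head_dim)
    return key.reshape(total_tokens, num_heads, head_dim)
```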
Number of attention layers.
Attention layer index for cache and scale.
Quantization bit width for cache compression. For example, 8 means int8 compression; 0 means disabled.
Group size of elements that share one quantization scale.
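The exact quantization formula is not specified here; the following is only a plausible sketch of group-wise int8 compression consistent with the `Dh/quant_group` scale shapes listed under `cache_layout` below.

```python
import torch

def quantize_group_int8(x: torch.Tensor, quant_group: int = 8):
    """x: (..., head_dim). Every quant_group consecutive elements share one
    scale, so the scale tensor's last dim is head_dim // quant_group."""
    *lead, head_dim = x.shape
    assert head_dim % quant_group == 0
    groups = x.reshape(*lead, head_dim // quant_group, quant_group)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
    q = torch.round(groups / scale).clamp(-127, 127).to(torch.int8)
    return q.reshape(*lead, head_dim), scale.squeeze(-1)

def dequantize_group_int8(q: torch.Tensor, scale: torch.Tensor, quant_group: int = 8):
    *lead, head_dim = q.shape
    groups = q.reshape(*lead, head_dim // quant_group, quant_group).to(scale.dtype)
    return (groups * scale.unsqueeze(-1)).reshape(*lead, head_dim)
```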
Define cache indexing mode. Default is zero. Both modes are illustrated in the sketch after this list.

- When `cache_mode` is `0`, the cache is indexed in offset mode. The shape of `cachestarts` is $(B)$. For each batch $b$, `cachestarts[b]` is the beginning index of that batch's cache in the $MaxT$ dimension of `cache` and `scale`. Note that `cachestarts[b+1]-cachestarts[b]` cannot be used to compute the cache length of batch $b$.
- When `cache_mode` is `1`, the cache is indexed in page table mode, also known as Paged Attention. The shape of `cachestarts` is $(B, MaxP)$. For each batch $b$, `cachestarts[b, :]` contains the beginning index of each page in the $MaxT$ dimension of `cache` and `scale`. Example for `batch = 2, page_size = 256`: $$cachestarts=[[0,256,\cdots],[1024,2048,\cdots]]$$
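A hedged sketch of how the two modes locate the token slots of one batch in the $MaxT$ dimension, assuming the token slot is the leading cache dimension (as with `cache_layout` `0` below); the helper name is hypothetical.

```python
import torch

def kv_slot_indices(cachestarts, b, start_pos, seqlen, cache_mode, page_size=256):
    """Returns the MaxT indices where tokens [start_pos, start_pos + seqlen)
    of batch b live in `cache` (and `scale`)."""
    token_pos = torch.arange(start_pos, start_pos + seqlen)
    if cache_mode == 0:
        # Offset mode: cachestarts has shape (B); one contiguous region per batch.
        return cachestarts[b] + token_pos
    # Page table mode (Paged Attention): cachestarts has shape (B, MaxP) and each
    # entry is the first MaxT index of a page holding page_size tokens.
    page_ids = token_pos // page_size
    return cachestarts[b, page_ids] + token_pos % page_size
```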
Define data layout of `cache` and `scale`. Default is zero.

Meaning of numbers:

- `0`: $cache(MaxT,L,2,H,Dh)$ and $scale(MaxT,L,2,H,Dh/quant\_group)$
- `1`: $cache(L,MaxT,2,H,Dh)$ and $scale(L,MaxT,2,H,Dh/quant\_group)$
- `2`: $cache(L,2,MaxT,H,Dh)$ and $scale(L,2,MaxT,H,Dh/quant\_group)$
- `3`: $cache(L,2,H,MaxT,Dh)$ and $scale(L,2,H,MaxT,Dh/quant\_group)$
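An illustrative helper (not part of the operator) that expands each `cache_layout` value into the `cache` and `scale` shapes listed above.

```python
def cache_shapes(cache_layout, max_tokens, num_layer, num_heads, head_dim, quant_group):
    """Returns (cache_shape, scale_shape); MaxT = max_tokens, L = num_layer,
    H = num_heads, Dh = head_dim."""
    layouts = {
        0: (max_tokens, num_layer, 2, num_heads, head_dim),
        1: (num_layer, max_tokens, 2, num_heads, head_dim),
        2: (num_layer, 2, max_tokens, num_heads, head_dim),
        3: (num_layer, 2, num_heads, max_tokens, head_dim),
    }
    cache_shape = layouts[cache_layout]
    # scale shares the layout, but its last dim is Dh / quant_group.
    return cache_shape, cache_shape[:-1] + (head_dim // quant_group,)
```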
Page size in Paged Attention (when `cache_mode` is `1`).
Input Query tensor
Shape:
Input Key tensor
Shape:
Input Value tensor
Shape:
`seqstarts[:B]` contains the position of the first token in `query` for each batch, and `seqstarts[B]` contains the total length of `query`.

Note that `seqstarts[b+1]-seqstarts[b]` gives the sequence length of batch $b$.

Shape: $(B+1)$
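For illustration, the per-batch query lengths (and the `max_seqlen` input described later) can be recovered from `seqstarts` like this:

```python
import torch

# Two batches with query lengths 3 and 5 packed back to back.
seqstarts = torch.tensor([0, 3, 8])              # shape (B + 1)
seqlens = seqstarts[1:] - seqstarts[:-1]         # tensor([3, 5])
max_seqlen = int(seqlens.max())                  # 5
batch1_rows = slice(int(seqstarts[1]), int(seqstarts[2]))  # rows 3..7 of query
```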
`kvstarts[:B]` contains the position of the first token in `key = cat(past_key, current_key)` and `value = cat(past_value, current_value)` for each batch, where `key` and `value` are originally provided by the operator `KeyValueCache`. And `kvstarts[B]` contains the total length of `key` and `value`.

Note that `kvstarts[b+1]-kvstarts[b]` gives the key and value length of batch $b$.

Shape: $(B+1)$
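Illustrative only: one way `kvstarts` could be assembled from the cached (past) token counts and the current query lengths; `past_lens` is a hypothetical name, not an operator input.

```python
import torch

past_lens = torch.tensor([10, 0])             # batch 0 decoding, batch 1 prefilling
seqlens = torch.tensor([1, 5])                # current query lengths
kvlens = past_lens + seqlens                  # key/value length per batch
kvstarts = torch.zeros(len(kvlens) + 1, dtype=torch.long)
kvstarts[1:] = torch.cumsum(kvlens, dim=0)    # tensor([0, 11, 16])
max_kvlen = int(kvlens.max())                 # 11
```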
Indexing cache position in `cache` and `scale`. Behavior is determined by `cache_mode`.

Shape: $(B)$ when `cache_mode` is `0`, or $(B, MaxP)$ when `cache_mode` is `1`.
Sequence position at which `current_key` and `current_value` begin to be stored for each batch.
Shape:
Describes how many batches at the front are in the decoding stage; the remaining batches, which are not decoding, need the causal mask.
Maximum sequence length of `query`, equal to `max(seqstarts[1:]-seqstarts[:B])`. For parallel computing.
Maximum sequence length of `key` and `value`, equal to `max(kvstarts[1:]-kvstarts[:B])`. For parallel computing.
Shape: determined by `cache_layout`.

Contains the key and value caches of the attention layers. When `cache_layout` is `0`, the subspace at index 0 of the size-2 dimension holds the key cache and index 1 holds the value cache. Data in this tensor will be modified.
Shape: determined by `cache_layout`.

Contains the key and value cache quantization scales of the attention layers, laid out the same way as `cache` with the last dimension reduced to $Dh/quant\_group$. Must appear when `quant_bit` is not zero. Data in this tensor will be modified.
Optional custom mask.
`seqlens = seqstarts[1:] - seqstarts[:B]` is a sequence containing the length of `query` for each batch.

`kvlens = kvstarts[1:] - kvstarts[:B]` is a sequence containing the length of `key` and `value` for each batch.
Note: The last dim of mask could be bigger than
Shape:
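As a sketch of how `seqlens` and `kvlens` relate to a custom mask, the snippet below builds one additive causal block of shape `(seqlens[b], kvlens[b])` per batch; how the blocks are packed into `attn_mask` must follow the operator's shape requirement, which is not reproduced here.

```python
import torch

def per_batch_causal_masks(seqstarts, kvstarts):
    """Returns a list of (seqlens[b], kvlens[b]) masks: 0 where attention is
    allowed, -inf where it is masked."""
    seqlens = seqstarts[1:] - seqstarts[:-1]
    kvlens = kvstarts[1:] - kvstarts[:-1]
    masks = []
    for q_len, kv_len in zip(seqlens.tolist(), kvlens.tolist()):
        # Each query token may attend to all past tokens plus itself.
        allowed = torch.ones(q_len, kv_len).tril(diagonal=kv_len - q_len).bool()
        masks.append(torch.zeros(q_len, kv_len).masked_fill(~allowed, float("-inf")))
    return masks
```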
Output feature of attention result
Shape: