Flatten cache and add flashattention (#2676)
* add flash attention

* add flash attention

* fix

* remove paged attention prefill

* remove auto tuning

* fix triton2

* fix ut

* fix sliding window

* fill last block
grimoire authored Nov 8, 2024
1 parent a4012ef commit 2bed018
Showing 14 changed files with 1,460 additions and 840 deletions.
98 changes: 74 additions & 24 deletions lmdeploy/pytorch/backends/cuda/attention.py
@@ -1,14 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass
from typing import Literal

import torch

from lmdeploy.pytorch.distributed import get_world_rank

from ..attention import AttentionBuilder, AttentionImpl, AttentionMetadata


@dataclass
class TritonAttentionMetadata(AttentionMetadata):
"""triton attention metadata."""
pass
is_decoding: bool
block_offsets: torch.Tensor
q_start_loc: torch.Tensor = None
q_seqlens: torch.Tensor = None
kv_start_loc: torch.Tensor = None
kv_seqlens: torch.Tensor = None
fill_seqlens: torch.Tensor = None
quant_policy: Literal[0, 4, 8] = 0
kv_flatten_size: int = None


def _cdiv(a, b):
"""perform div up."""
return (a + b - 1) // b


class TritonAttentionImpl(AttentionImpl[TritonAttentionMetadata]):
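
For reference, the _cdiv helper added above is plain ceiling division. The prefill branch later uses it to round the flattened KV size up to whole cache blocks and then pads one extra block ("pad one more block to avoid invalid kv visit" in the diff, "fill last block" in the commit message). A tiny worked example with made-up values, not taken from the diff:

def _cdiv(a, b):
    """Ceiling division, as defined in the diff above."""
    return (a + b - 1) // b

BLOCK_BS = 64            # tokens per KV cache block (illustrative value)
kv_flatten_size = 1000   # total KV tokens across the batch (illustrative value)
out_size = _cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS + BLOCK_BS
# _cdiv(1000, 64) == 16, so out_size == 16 * 64 + 64 == 1088
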
@@ -40,10 +57,14 @@ def __init__(

from lmdeploy.pytorch.kernels.cuda import (alibi_paged_attention_fwd,
fill_kv_cache,
flash_attention_fwd,
flatten_kv_cache,
paged_attention_fwd)
self.fill_kv_cache = fill_kv_cache
self.paged_attention_fwd = paged_attention_fwd
self.alibi_paged_attention_fwd = alibi_paged_attention_fwd
self.flatten_kv_cache = flatten_kv_cache
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_world_rank()
@@ -69,7 +90,9 @@ def forward(
fill_q_start_loc = q_start_loc
q_seqlens = attn_metadata.q_seqlens
fill_seqlens = q_seqlens
kv_start_loc = attn_metadata.kv_start_loc
kv_seqlens = attn_metadata.kv_seqlens
kv_flatten_size = attn_metadata.kv_flatten_size
quant_policy = attn_metadata.quant_policy
max_q_seqlen = query.numel() // (query.size(-1) * query.size(-2))
fill_max_q_seqlen = max_q_seqlen
@@ -95,31 +118,58 @@ def forward(
quant_policy=quant_policy,
)

if inplace:
attn_output = query[..., :self.v_head_size]
else:
q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_size, )
attn_output = query.new_empty(o_shape)
q_shape = query.shape
o_shape = q_shape[:-1] + (self.v_head_size, )
attn_output = query.new_empty(o_shape)

is_decoding = attn_metadata.is_decoding
if not self.alibi:
self.paged_attention_fwd(
query,
k_cache,
v_cache,
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
max_seqlen=max_q_seqlen,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
)
if is_decoding:
self.paged_attention_fwd(
query,
k_cache,
v_cache,
attn_output,
block_offsets,
kv_seqlens=kv_seqlens,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
)
else:
BLOCK_BS = k_cache.size(1)
# pad one more block to avoid invalid kv visit
out_size = (_cdiv(kv_flatten_size, BLOCK_BS) * BLOCK_BS +
BLOCK_BS)
flatten_k, flatten_v = self.flatten_kv_cache(
k_cache,
v_cache,
kv_seqlens,
block_offsets,
start_loc=kv_start_loc,
out_size=out_size,
out_dtype=query.dtype,
k_scales_zeros=k_scales_zeros,
v_scales_zeros=v_scales_zeros,
quant_policy=quant_policy,
)
self.flash_attention_fwd(
query,
flatten_k,
flatten_v,
attn_output,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
max_seqlen=max_q_seqlen,
window_size=self.sliding_window,
sm_scale=self.scale,
logit_softcapping=self.logit_softcapping,
)
else:
self.alibi_paged_attention_fwd(
query,
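
With this change, decoding still runs the paged-attention kernel directly over the block table, while prefill first gathers each sequence's scattered KV blocks into one contiguous, block-padded buffer and then runs the new flash-attention kernel over it. The sketch below is a plain-PyTorch reference for that gather step only; the tensor layouts are assumptions chosen for illustration, and the real flatten_kv_cache Triton kernel additionally dequantizes int4/int8 caches when a quant policy is set.

import torch

def flatten_kv_cache_sketch(k_cache, v_cache, kv_seqlens, block_offsets,
                            start_loc, out_size):
    """Reference sketch of the flatten step (assumed layouts, not the real kernel).

    Assumes k_cache/v_cache are [num_blocks, block_size, num_heads, head_dim],
    block_offsets is [batch, max_blocks_per_seq], and start_loc/kv_seqlens give
    each sequence's offset and length inside an output of length out_size.
    """
    _, block_size, num_heads, head_dim = k_cache.shape
    flat_k = k_cache.new_zeros((out_size, num_heads, head_dim))
    flat_v = v_cache.new_zeros((out_size, num_heads, head_dim))
    for i in range(kv_seqlens.numel()):
        seqlen = int(kv_seqlens[i])
        loc = int(start_loc[i])
        copied = 0
        for blk in block_offsets[i].tolist():
            if copied >= seqlen:
                break
            n = min(block_size, seqlen - copied)
            # Copy the valid tokens of this block into the contiguous buffer.
            flat_k[loc + copied:loc + copied + n] = k_cache[blk, :n]
            flat_v[loc + copied:loc + copied + n] = v_cache[blk, :n]
            copied += n
    return flat_k, flat_v

Flash attention then consumes the flattened K/V together with q_start_loc, q_seqlens, kv_start_loc and kv_seqlens, as shown in the flash_attention_fwd call site above.
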
20 changes: 18 additions & 2 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -104,12 +104,20 @@ def update_step_context(cls, step_context):
attn_meta_cls = cls.get_attention_metadata_cls()
q_seqlens = step_context.q_seqlens
q_start_loc = q_seqlens.cumsum(0) - q_seqlens
kv_seqlens = step_context.kv_seqlens
kv_start_loc = None
kv_flatten_size = None
if not step_context.is_decoding:
kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens
kv_flatten_size = kv_seqlens.sum().item()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=step_context.kv_seqlens,
kv_start_loc=kv_start_loc,
kv_seqlens=kv_seqlens,
kv_flatten_size=kv_flatten_size,
quant_policy=step_context.kv_quant_policy,
)

@@ -120,12 +128,20 @@ def update_step_context(cls, step_context):
for idx, state in enumerate(step_context.cross_attention_states):
if state is not None:
fill_seqlens[idx] = state.shape[-2]
cross_kv_seqlens = step_context.cross_kv_seqlens
cross_kv_start_loc = None
cross_kv_flatten_size = None
if not step_context.is_decoding and cross_kv_seqlens is not None:
cross_kv_start_loc = cross_kv_seqlens.cumsum(0) - cross_kv_seqlens
cross_kv_flatten_size = cross_kv_seqlens.sum().item()
cross_attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=step_context.cross_kv_seqlens,
kv_start_loc=cross_kv_start_loc,
kv_seqlens=cross_kv_seqlens,
kv_flatten_size=cross_kv_flatten_size,
fill_seqlens=fill_seqlens,
quant_policy=step_context.kv_quant_policy,
)
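
The start-location bookkeeping in this file is an exclusive prefix sum over the per-sequence lengths, and kv_flatten_size is simply their total, which sizes the flattened cache buffer. A quick check with made-up lengths (illustrative values only):

import torch

kv_seqlens = torch.tensor([5, 3, 7])              # per-sequence KV lengths
kv_start_loc = kv_seqlens.cumsum(0) - kv_seqlens  # tensor([0, 5, 8])
kv_flatten_size = kv_seqlens.sum().item()         # 15
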
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -2,6 +2,8 @@
from .alibi_pagedattention import alibi_paged_attention_fwd
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .flashattention import flash_attention_fwd
from .flatten_kv_cache import flatten_kv_cache
from .fused_moe import fused_moe
from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
@@ -24,4 +26,6 @@
'per_channel_quant',
'per_token_quant_int8',
'rms_norm_dynamic_quant',
'flash_attention_fwd',
'flatten_kv_cache',
]
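
Both new kernels are re-exported from the package root, so call sites only need the package import; the keyword arguments they accept can be read off the attention.py call sites above. A minimal usage note (only the import path below comes from this file; it assumes a checkout that includes this commit):

from lmdeploy.pytorch.kernels.cuda import flash_attention_fwd, flatten_kv_cache
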