Skip to content

Commit

Permalink
fix decoding kernel for deepseekv2 (#2688)
Browse files Browse the repository at this point in the history
  • Loading branch information
grimoire authored Nov 6, 2024
1 parent cc14215 commit 354028b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/kernels/cuda/pagedattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def _fwd_grouped_split_kernel(
cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
mask_h = mask_h & (cur_head < num_heads_q)
if BLOCK_H < kv_group_num:
cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

q_seqlen = 1
kv_seqlen = tl.load(KV_seqlens + cur_batch)
Expand Down Expand Up @@ -366,6 +368,8 @@ def _fwd_grouped_split_quant_kernel(
cur_head = cur_kv_head * HEAD_PER_CTA + tl.arange(0, BLOCK_H)
mask_h = cur_head < cur_kv_head * HEAD_PER_CTA + HEAD_PER_CTA
mask_h = mask_h & (cur_head < num_heads_q)
if BLOCK_H < kv_group_num:
cur_kv_head = (cur_kv_head * HEAD_PER_CTA) // kv_group_num

q_seqlen = 1
kv_seqlen = tl.load(KV_seqlens + cur_batch)
Expand Down
3 changes: 2 additions & 1 deletion tests/pytorch/kernel/test_paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def conti_gt(self, gt, seq_lens):

@pytest.mark.parametrize('feat_dim', [48, 32], indirect=True)
@pytest.mark.parametrize('feat_dim_v', [32], indirect=True)
@pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(8, 2), (2, 2)],
@pytest.mark.parametrize(['num_heads_q', 'num_heads_k'], [(128, 2), (8, 2),
(2, 2)],
indirect=True)
@pytest.mark.parametrize(['seq_lens', 'history_lens'],
[([30, 50, 70, 90], [50, 40, 30, 20]),
Expand Down

0 comments on commit 354028b

Please sign in to comment.