Skip to content

Commit

Permalink
[misc] add forward context for attention (vllm-project#9029)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Oct 3, 2024
1 parent 63e3993 commit 9aaf14c
Show file tree
Hide file tree
Showing 8 changed files with 250 additions and 334 deletions.
56 changes: 7 additions & 49 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import pytest
import torch

import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
Expand Down Expand Up @@ -112,36 +112,17 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand Down Expand Up @@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = torch.ops.vllm.flash_attn_varlen_func(
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
softcap=soft_cap if soft_cap is not None else 0,
)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
Expand Down
Loading

0 comments on commit 9aaf14c

Please sign in to comment.