[Neuron][Kernel] Support Longer Sequences in NKI-based Flash PagedAttention and Improve Efficiency (vllm-project#12921)

Signed-off-by: Lingfan Yu <[email protected]>
lingfanyu authored Feb 12, 2025
1 parent 842b0fd commit e92694b
Showing 2 changed files with 154 additions and 180 deletions.
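The heart of this change is visible in the assertions further down: the padded context KV length previously had to equal exactly one large tile (context_kv_len == LARGE_TILE_SZ), and now only has to be a whole number of tiles (context_kv_len % LARGE_TILE_SZ == 0). A minimal sketch of the relaxed constraint, with toy numbers not taken from the test:

block_size, LARGE_TILE_SZ = 32, 2048
context_kv_len = 3 * LARGE_TILE_SZ  # 6144; three tiles' worth of paged KV cache

# old test requirement: exactly one tile
# assert context_kv_len == LARGE_TILE_SZ   # would fail here
# new test requirement: any whole number of tiles
assert context_kv_len % LARGE_TILE_SZ == 0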
118 changes: 67 additions & 51 deletions tests/neuron/test_prefix_prefill.py
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

import random
from typing import Optional

import pytest
@@ -171,19 +170,31 @@ def ref_context_attention(
return output


@pytest.mark.parametrize(
"block_size, large_tile_size",
[
(32, 2048), # 64 blocks
(32, 4096), # 128 blocks
(32, 8192), # 256 blocks
(64, 8192), # 128 blocks
],
)
@pytest.mark.parametrize(
"num_heads,num_queries_per_kv,head_size,mixed_precision",
[
(4, 2, 8, False),
(4, 2, 8, True),
(32, 8, 64, True),
(16, 2, 128, True),
],
)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
num_queries_per_kv: int,
head_size: int,
block_size: int,
large_tile_size,
mixed_precision: bool,
) -> None:
import os
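Each (block_size, large_tile_size) pair parametrized above is annotated with a block count; a quick sketch verifying those inline comments and the divisibility the test asserts below:

for block_size, large_tile_size in [(32, 2048), (32, 4096), (32, 8192), (64, 8192)]:
    assert large_tile_size % block_size == 0
    print(f"{large_tile_size // block_size} blocks per tile")  # 64, 128, 256, 128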
@@ -192,40 +203,46 @@ def test_contexted_kv_attention(

from vllm.attention.ops.nki_flash_attn import flash_attn_varlen_nkifunc

assert large_tile_size % block_size == 0

device = xm.xla_device()

os.environ["NEURON_CC_FLAGS"] = (
" --model-type=transformer -O1 "
" --internal-hlo2tensorizer-options='--verify-hlo' ")
compiler_flags = [
"--model-type=transformer -O1",
"--internal-hlo2tensorizer-options='--verify-hlo'",
"--retry_failed_compilation",
]
compiler_flags_str = " ".join(compiler_flags)
os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

random.seed(0)
torch.manual_seed(0)
torch.set_printoptions(sci_mode=False)

min_ctx_len = 2
max_ctx_len = 64
min_query_len = 2
max_query_len = 64
prefill_batch_size = 2
decode_batch_size = 6
min_ctx_len = 32
max_ctx_len = 1024
min_query_len = 16
max_query_len = 512
prefill_batch_size = 4
decode_batch_size = 12
batch_size = prefill_batch_size + decode_batch_size
block_size = 32
max_model_len = (max_query_len + max_ctx_len) * 4

max_block_per_request = max_model_len // block_size
dtype = torch.float32
cache_size = (batch_size * max_block_per_request) + 2
ctx_lens = [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(prefill_batch_size)
] + [
random.randint(min_ctx_len, max_ctx_len)
for _ in range(decode_batch_size)
]
query_lens = [
random.randint(min_query_len, max_query_len)
for _ in range(prefill_batch_size)
] + [1 for _ in range(decode_batch_size)]
prefill_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (prefill_batch_size, ),
dtype=torch.long).tolist()
decode_ctx_lens = torch.randint(min_ctx_len,
max_ctx_len + 1, (decode_batch_size, ),
dtype=torch.long).tolist()
ctx_lens = prefill_ctx_lens + decode_ctx_lens
query_lens = torch.randint(
min_query_len,
max_query_len + 1,
(prefill_batch_size, ),
dtype=torch.long,
).tolist() + [1 for _ in range(decode_batch_size)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
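Prefill requests draw a random query length, while each decode request contributes exactly one new query token; a small illustration of the resulting ragged batch (toy sizes, not the test's values):

import torch

torch.manual_seed(0)
prefill_batch_size, decode_batch_size = 2, 3
ctx_lens = torch.randint(8, 17, (prefill_batch_size + decode_batch_size, ),
                         dtype=torch.long).tolist()
query_lens = torch.randint(2, 5, (prefill_batch_size, ),
                           dtype=torch.long).tolist() + [1] * decode_batch_size
seq_lens = [q + c for q, c in zip(query_lens, ctx_lens)]  # query + context per request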

@@ -254,7 +271,6 @@ def test_contexted_kv_attention(
values = values[torch.randperm(cache_size)]
block_table = values[:batch_size * max_block_per_request].view(
batch_size, max_block_per_request)
torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
dtype=torch.long),
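b_start_loc gives each request's offset into the packed query tensor: prepend a zero, drop the last length, and take a running sum. The cumsum dim argument sits in the elided context; dim=0 is assumed in this sketch:

import torch

query_lens = [3, 2, 1]
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long),
                           dim=0)
print(b_start_loc)  # tensor([0, 3, 5]); request i's queries start at b_start_loc[i]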
@@ -311,9 +327,7 @@ def test_contexted_kv_attention(
# build neuron program
return_debug_tensors = False
B_P_SIZE = 128
LARGE_TILE_SZ = 2048
max_num_queries = (
(sum(query_lens) + block_size - 1) // block_size) * block_size
LARGE_TILE_SZ = large_tile_size

def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
num_blocks):
@@ -332,26 +346,28 @@ def get_active_block_tables(block_tables, query_lens, seq_lens, block_size,
0,
)

def shift_bit_length(x):
return 1 << (x - 1).bit_length()
def ceil_div(a, b):
return (a + b - 1) // b

def pad_to_multiple(a, b):
return ceil_div(a, b) * b

def pad_to_next_power_of_2(a):
assert a > 0
return 2**int(a - 1).bit_length()
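These three helpers drive all the padding arithmetic that follows; a few sample values, using the definitions above:

assert ceil_div(150, 64) == 3              # 150 items fill 3 groups of 64
assert pad_to_multiple(150, 64) == 192     # round 150 up to a multiple of 64
assert pad_to_next_power_of_2(150) == 256  # round 150 up to a power of two
assert pad_to_next_power_of_2(128) == 128  # exact powers of two are unchanged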

# calculate input shapes
max_num_queries_shifted = shift_bit_length(max_num_queries)
max_num_queries_factor = B_P_SIZE // max_num_queries_shifted
max_num_queries_padded = max_num_queries_shifted * max_num_queries_factor
assert (max_num_queries_padded == B_P_SIZE
), "invalid {max_num_queries_padded=}"
max_num_queries = pad_to_multiple(sum(query_lens), block_size)
max_num_queries = pad_to_next_power_of_2(max_num_queries)
head_size_padded = B_P_SIZE
assert head_size_padded >= head_size
context_lens = torch.tensor(seq_lens) - torch.tensor(query_lens)
num_active_blocks_shifted = shift_bit_length(
((context_lens + block_size - 1) // block_size).sum().item())
num_active_blocks_factor = (LARGE_TILE_SZ // block_size //
num_active_blocks_shifted)
num_active_blocks = num_active_blocks_shifted * num_active_blocks_factor
assert (num_active_blocks *
block_size) == LARGE_TILE_SZ, "invalid {num_active_blocks=}"
num_active_blocks = ceil_div(context_lens, block_size).sum().item()
num_active_blocks = pad_to_multiple(num_active_blocks,
LARGE_TILE_SZ // block_size)
context_kv_len = num_active_blocks * block_size
assert context_kv_len == LARGE_TILE_SZ, f"invalid {context_kv_len=}"
assert (context_kv_len %
LARGE_TILE_SZ == 0), f"invalid context_kv_len={context_kv_len}"
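Tracing the new tiling arithmetic with the smallest parametrized configuration, block_size=32 and LARGE_TILE_SZ=2048 (64 blocks per tile), and toy context lengths:

import torch

block_size, LARGE_TILE_SZ = 32, 2048
blocks_per_tile = LARGE_TILE_SZ // block_size                    # 64

context_lens = torch.tensor([1000, 700, 300, 1, 1])              # toy values
num_active_blocks = ((context_lens + block_size - 1) // block_size).sum().item()
# per-request blocks: 32 + 22 + 10 + 1 + 1 = 66
num_active_blocks = ((num_active_blocks + blocks_per_tile - 1)
                     // blocks_per_tile) * blocks_per_tile       # padded to 128
context_kv_len = num_active_blocks * block_size                  # 4096, i.e. 2 tiles
assert context_kv_len % LARGE_TILE_SZ == 0                       # multi-tile contexts now pass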

# pad QKV tensors
pad_dims = (
@@ -360,7 +376,7 @@ def shift_bit_length(x):
0,
0,
0,
max_num_queries_padded - query.shape[0],
max_num_queries - query.shape[0],
)
query = F.pad(query, pad_dims, "constant", 0)
k = F.pad(k, pad_dims, "constant", 0)
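torch.nn.functional.pad consumes its padding pairs from the last dimension backwards, so the final (0, max_num_queries - query.shape[0]) pair in pad_dims grows the first, packed-token dimension. A toy check:

import torch
import torch.nn.functional as F

x = torch.randn(5, 4, 8)                         # (tokens, heads, head_size)
padded = F.pad(x, (0, 0, 0, 0, 0, 3), "constant", 0)
assert padded.shape == (8, 4, 8)                 # only the token dimension grew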
@@ -397,7 +413,7 @@ def shift_bit_length(x):
0,
context_kv_len - prior_mask.shape[1],
0,
B_P_SIZE - prior_mask.shape[0],
max_num_queries - prior_mask.shape[0],
),
"constant",
0,
@@ -406,9 +422,9 @@ def shift_bit_length(x):
active_mask,
(
0,
B_P_SIZE - active_mask.shape[1],
max_num_queries - active_mask.shape[1],
0,
B_P_SIZE - active_mask.shape[0],
max_num_queries - active_mask.shape[0],
),
"constant",
0,
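Both masks are padded to rectangular shapes the kernel can tile: prior_mask (queries against the paged context KV) to (max_num_queries, context_kv_len), and active_mask (queries against the in-flight chunk, padded just above) to (max_num_queries, max_num_queries). A shape-only sketch for the prior mask:

import torch
import torch.nn.functional as F

max_num_queries, context_kv_len = 128, 2048
prior_mask = torch.ones(70, 1500)                  # toy unpadded mask
prior_mask = F.pad(prior_mask,
                   (0, context_kv_len - prior_mask.shape[1],
                    0, max_num_queries - prior_mask.shape[0]),
                   "constant", 0)
assert prior_mask.shape == (max_num_queries, context_kv_len)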
@@ -430,6 +446,8 @@ def shift_bit_length(x):
n_kv_head=num_kv_heads,
head_size=head_size,
mixed_precision=mixed_precision,
LARGE_TILE_SZ=LARGE_TILE_SZ,
return_debug_tensors=return_debug_tensors,
)

if return_debug_tensors:
@@ -439,17 +457,15 @@ def shift_bit_length(x):
output_nki = flash_attn_varlen_nkifunc(*input_args, **input_kwargs)
debug_tensors = []

output_nki = torch.tensor(output_nki).cpu()
debug_tensors = [torch.tensor(dt).cpu() for dt in debug_tensors]

num_actual_tokens = sum(query_lens)
print(f"{num_actual_tokens=}")
# - o: shape (bs, n_heads, seq_q, d) -> (bs, seq_q, n_heads, d)
output_nki = output_nki.permute(
0, 2, 1, 3)[:, :, :, :head_size].cpu()[0, :num_actual_tokens, :, :]
output_nki = output_nki.cpu().permute(0, 2, 1, 3)[:, :, :, :head_size]
output_nki = output_nki[0, :num_actual_tokens, :, :]
output_ref_padded = F.pad(
output_ref,
(0, 0, 0, 0, 0, 0, 0, max_num_queries_padded - output_ref.shape[0]),
(0, 0, 0, 0, 0, 0, 0, max_num_queries - output_ref.shape[0]),
"constant",
0,
)
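The kernel output comes back as (bs, n_heads, seq_q, d); the permute reorders it to (bs, seq_q, n_heads, d), the slices drop head-size padding and padded tokens, and the reference output is padded to the same token count (the comparison itself sits in the elided tail of the diff). A shape walk-through with toy sizes:

import torch

bs, n_heads, seq_q, d_pad = 1, 4, 128, 128
head_size, num_actual_tokens = 64, 100
out = torch.randn(bs, n_heads, seq_q, d_pad)
out = out.permute(0, 2, 1, 3)[:, :, :, :head_size]   # (1, 128, 4, 64)
out = out[0, :num_actual_tokens, :, :]               # (100, 4, 64)
assert out.shape == (num_actual_tokens, n_heads, head_size)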