Fix PA custom and PA v2 tests and partition sizes (#196)
* Update custom PA kernel with support for fp8 kv cache dtype; change custom PA partition size to 512 to prefer throughput scenarios at the cost of latency (see the partition-size sketch after the file summary below)

* Fix lint

* Fix BF16 with FP8 KV cache (scaled conversion incorrectly done in fp16)

* Fix custom PA tests

* Merge branch 'main' of [email protected]:ROCm/vllm.git into mawong/fix_custom_pa_tests

* Fix partition sizes for PAv2, PAcustom

* Fix linting

* Fix a few names and variable scopes

* Rename custom to rocm as per suggestion

---------

Co-authored-by: Shomy Sanyal <[email protected]>
mawong-amd and sanyalington authored Sep 18, 2024
1 parent d21cf99 commit a67b65b
Showing 7 changed files with 81 additions and 444 deletions (two of the changed files are reproduced below).
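Why the partition size matters: the paged attention v2 and custom ROCm kernels split each sequence into fixed-size partitions and then reduce across them, so a larger partition size means fewer partitions and smaller intermediate buffers per sequence, which favors throughput at the cost of latency. Below is a minimal sketch of that arithmetic, using the same ceil-division expression as the benchmark and test code in the diff; the example sequence length is illustrative only.

# Illustrative sketch: partition count as a function of PARTITION_SIZE.
# The ceil-division mirrors the expression used in the diff below.
def num_partitions(max_seq_len: int, partition_size: int) -> int:
    return (max_seq_len + partition_size - 1) // partition_size

if __name__ == "__main__":
    max_seq_len = 4096  # hypothetical decode-time sequence length
    for partition_size in (256, 512, 1024):
        print(partition_size, num_partitions(max_seq_len, partition_size))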
benchmarks/kernels/benchmark_paged_attention.py: 4 additions & 4 deletions
@@ -6,10 +6,10 @@

 from vllm import _custom_ops as ops
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
-                        create_kv_caches_with_random)
+                        create_kv_caches_with_random, is_hip)

 NUM_BLOCKS = 1024 * 1024
-PARTITION_SIZE = 256
+PARTITION_SIZE = 512


 @torch.inference_mode()
@@ -80,9 +80,9 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
-        if not args.custom_paged_attn:
+        if is_hip() and not args.custom_paged_attn:
             global PARTITION_SIZE
-            PARTITION_SIZE = 512
+            PARTITION_SIZE = 1024
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
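For reference, the benchmark's effective partition size after this change can be summarized as follows. This is a sketch that mirrors the hunk above rather than the full script; the on_hip and custom_paged_attn parameters stand in for is_hip() and the --custom-paged-attn flag.

# Sketch of the post-change selection logic in the benchmark (illustrative).
PARTITION_SIZE = 512  # module default, now sized for the custom ROCm kernel

def effective_partition_size(on_hip: bool, custom_paged_attn: bool) -> int:
    # On ROCm, plain paged_attention_v2 uses larger 1024-token partitions,
    # while the custom ROCm kernel keeps the 512 default.
    if on_hip and not custom_paged_attn:
        return 1024
    return PARTITION_SIZE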
tests/kernels/test_attention.py: 65 additions & 190 deletions
@@ -31,8 +31,7 @@

 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
-              ] if not is_hip() else [64, 80, 96, 112, 128]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
@@ -114,7 +113,8 @@ def ref_single_query_cached_kv_attention(
         output[i].copy_(out, non_blocking=True)


-@pytest.mark.parametrize("version", ["v1", "v2"])
+@pytest.mark.parametrize(
+    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -137,7 +137,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
-    if kv_cache_dtype == "fp8" and head_size % 16:
+    if ((kv_cache_dtype == "fp8" and head_size % 16)
+            or (version == "rocm" and head_size not in (64, 128))):
         pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
@@ -208,7 +209,9 @@ def test_paged_attention(
                  kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                 cond=(head_size == HEAD_SIZES[0]))

-    elif version == "v2":
+    elif version in ("v2", "rocm"):
+        if is_hip():
+            PARTITION_SIZE = 1024 if version == "v2" else 512
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
         num_seqs, num_heads, head_size = output.shape
@@ -221,32 +224,62 @@
             dtype=torch.float32,
         )
         max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            block_tables,
-            seq_lens,
-            block_size,
-            max_seq_len,
-            alibi_slopes,
-            kv_cache_dtype,
-            k_scale,
-            v_scale,
-        )
-
-        opcheck(torch.ops._C.paged_attention_v2,
-                (output, exp_sums, max_logits, tmp_output, query, key_cache,
-                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
-                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-                 k_scale, v_scale, 0, 0, 0, 64, 0),
-                cond=(head_size == HEAD_SIZES[0]))
+        if version == "v2":
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._C.paged_attention_v2,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                    cond=(head_size == HEAD_SIZES[0]))
+
+        else:
+            ops.paged_attention_rocm(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                block_tables,
+                seq_lens,
+                block_size,
+                max_seq_len,
+                alibi_slopes,
+                kv_cache_dtype,
+                k_scale,
+                v_scale,
+            )
+
+            opcheck(torch.ops._rocm_C.paged_attention,
+                    (output, exp_sums, max_logits, tmp_output, query,
+                     key_cache, value_cache, num_kv_heads, scale, block_tables,
+                     seq_lens, block_size, max_seq_len, alibi_slopes,
+                     kv_cache_dtype, k_scale, v_scale),
+                    cond=(head_size == 64))

     else:
         raise AssertionError(f"Unknown version: {version}")
@@ -330,173 +363,15 @@ def ref_multi_query_kv_attention(
     return torch.cat(ref_outputs, dim=0)


@pytest.mark.parametrize("version", ["rocm"])
@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", [64, 128]) # only test 64 128
@pytest.mark.parametrize("use_alibi", USE_ALIBI)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(not is_hip(), reason="only for rocm")
def test_paged_attention_rocm(
kv_cache_factory,
version: str,
num_seqs: int,
num_heads: Tuple[int, int],
head_size: int,
use_alibi: bool,
block_size: int,
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)

assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
#context_lens = [8192 for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int)
#print('>>> ctx lens', context_lens)

# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
block_tables = []
for _ in range(num_seqs):
block_table = [
random.randint(0, NUM_BLOCKS - 1)
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int)

# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
device)
key_cache, value_cache = key_caches[0], value_caches[0]

# TODO(charlifu) enable fp8 kv cache
# Using default kv_scale
# kv_scale = 1.0

# Call the paged attention kernel.
output = torch.empty_like(query)
PARTITION_SIZE_ROCM = 256
num_partitions = ((max_context_len + PARTITION_SIZE_ROCM - 1) //
PARTITION_SIZE_ROCM)
assert PARTITION_SIZE_ROCM % block_size == 0
num_seqs, num_heads, head_size = output.shape
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
)
max_logits = torch.empty_like(exp_sums)
if version == "rocm":
ops.paged_attention_rocm(
output,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
)
else:
raise AssertionError(f"Unknown version: {version}")

# Run the reference implementation.
if kv_cache_dtype == "fp8":
# Convert cache data back to dtype.
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x,
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache

value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=device)
ops.convert_fp8(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache

ref_output = torch.empty_like(query)
ref_single_query_cached_kv_attention(
ref_output,
query,
num_queries_per_kv,
key_cache,
value_cache,
block_tables,
context_lens,
scale,
alibi_slopes,
)

# NOTE(woosuk): Due to the kernel-level differences in the two
# implementations, there is a small numerical difference in the two
# outputs. Thus, we use a relaxed tolerance for the test.
atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5

# NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
# so we use a relaxed tolerance for the test.
atol, rtol = 1e-4, 1e-5
if dtype == torch.bfloat16:
atol, rtol = 2e-4, 1e-5
if use_alibi:
if dtype == torch.half:
atol, rtol = 5e-4, 1e-5
if dtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)


 # TODO(woosuk): Add tests for USE_ALIBI=True.
 @pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(), reason="skip for rocm")
+@pytest.mark.skipif(is_hip(),
+                    reason="Xformers backend is not supported on ROCm.")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
     num_seqs: int,
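Note that the standalone test_paged_attention_rocm removed above is folded into test_paged_attention via the extra "rocm" version parameter, so the custom ROCm kernel now shares the same test body and reference check as v1/v2. On a ROCm build, one possible way to run just those cases is to filter on the parametrized test IDs; the file path comes from the diff, while the filter string is illustrative.

# Illustrative: run only the "rocm" version cases of the consolidated test.
# Assumes a ROCm build of vLLM; on other platforms these cases are not generated.
import pytest

if __name__ == "__main__":
    raise SystemExit(
        pytest.main([
            "tests/kernels/test_attention.py::test_paged_attention",
            "-k", "rocm",
        ]))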
