Faster Custom Paged Attention kernels (#372)
* integrate new CPA (custom paged attention) kernel, update tests and benchmark

* added comments to mfma4 kernel

* further comments for mfma16 kernel

* clang-format

* Lint

* add flag for logits rtz (round-toward-zero) conversion and disable by default (see the sketch after the commit message)

* lint

* [Bugfix]: Fix paged attention unit tests of #372 (#389)

* [Bugfix]: fix paged attention tests based on the updated kernels in `csrc/attention/paged_attention_v1.cu`, `csrc/attention/paged_attention_v2.cu` and `csrc/rocm/attention.cu`.

* improve code documentation.

* lint

---------

Co-authored-by: vllmellm <[email protected]>

---------

Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Joe Shajrawi <[email protected]>
Co-authored-by: TJian <[email protected]>
Co-authored-by: vllmellm <[email protected]>
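
On the "logits rtz conversion" flag in the notes above: RTZ (round-toward-zero) float32 to float16 conversion is generally cheaper than the default round-to-nearest-even but slightly lossier, which is presumably why it ships disabled by default. A NumPy emulation of the difference, for illustration only (the kernel performs this conversion in hardware):

```python
import numpy as np

def f32_to_f16_rtz(x: np.ndarray) -> np.ndarray:
    """Emulate float32 -> float16 conversion with round-toward-zero."""
    h = x.astype(np.float16)  # NumPy casts round to nearest even
    # Wherever rounding moved the value away from zero, step one fp16 ulp
    # back toward zero; the result equals plain truncation (RTZ).
    rounded_away = np.abs(h.astype(np.float32)) > np.abs(x)
    h[rounded_away] = np.nextafter(h[rounded_away], np.float16(0))
    return h

x = np.float32([1.0006])  # between the fp16 neighbors 1.0 and 1.0009765625
print(np.float16(x))      # round-to-nearest-even -> 1.001
print(f32_to_f16_rtz(x))  # round-toward-zero     -> 1.0
```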
6 people authored Jan 30, 2025
1 parent 7a292f9 commit 273c949
Showing 3 changed files with 1,016 additions and 402 deletions.
benchmarks/kernels/benchmark_paged_attention.py (16 changes: 11 additions & 5 deletions)
```diff
@@ -9,8 +9,9 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)
 
-NUM_BLOCKS = 1024 * 1024
+NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256
 
 
 @torch.inference_mode()
```
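
Why the smaller NUM_BLOCKS matters: the benchmark preallocates its KV cache with this many blocks, so the change cuts that allocation eightfold. A back-of-the-envelope sketch, assuming fp16 and typical benchmark defaults (block_size=16, 8 KV heads, head_size=128; the actual values come from the CLI arguments):

```python
def kv_cache_bytes(num_blocks: int, block_size: int = 16,
                   num_kv_heads: int = 8, head_size: int = 128,
                   dtype_bytes: int = 2) -> int:
    # K and V caches each hold num_blocks blocks of
    # (block_size, num_kv_heads, head_size) elements.
    return 2 * num_blocks * block_size * num_kv_heads * head_size * dtype_bytes

print(kv_cache_bytes(1024 * 1024) / 2**30)  # old: 64.0 GiB
print(kv_cache_bytes(128 * 1024) / 2**30)   # new:  8.0 GiB
```

Under these assumptions the old value needs 64 GiB for the cache alone, which only the largest accelerators can hold.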
```diff
@@ -78,9 +79,12 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
-        if current_platform.is_rocm() and not args.custom_paged_attn:
+        if current_platform.is_rocm():
             global PARTITION_SIZE
-            PARTITION_SIZE = 1024
+            if not args.custom_paged_attn:
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
```
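
The v2 path now picks the partition size per platform and per kernel: 1024 for the stock ROCm path, PARTITION_SIZE_ROCM (256) for the custom kernel, and the original 512 elsewhere. A minimal standalone sketch of that selection and of the ceiling division above (pick_num_partitions is an illustrative helper, not a function in the benchmark):

```python
PARTITION_SIZE = 512        # default (non-ROCm)
PARTITION_SIZE_ROCM = 256   # custom ROCm paged-attention kernel

def pick_num_partitions(max_seq_len: int, is_rocm: bool,
                        custom_paged_attn: bool) -> int:
    partition_size = PARTITION_SIZE
    if is_rocm:
        partition_size = PARTITION_SIZE_ROCM if custom_paged_attn else 1024
    # Ceiling division: each sequence is split into fixed-size partitions
    # whose partial results the v2 kernel reduces in a second pass.
    return (max_seq_len + partition_size - 1) // partition_size

assert pick_num_partitions(4096, is_rocm=True, custom_paged_attn=True) == 16
assert pick_num_partitions(4096, is_rocm=False, custom_paged_attn=False) == 8
```

A smaller partition size yields more partitions per sequence, trading a larger second-pass reduction for more parallelism in the first pass.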
```diff
@@ -163,6 +167,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                 kv_cache_dtype,
                 k_scale,
                 v_scale,
+                None,
+                PARTITION_SIZE,
             )
         else:
             raise ValueError(f"Invalid version: {version}")
```
```diff
@@ -176,13 +182,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     # Warmup.
     print("Warming up...")
     run_benchmark = run_cuda_benchmark
-    run_benchmark(num_iters=3, profile=False)
+    run_benchmark(num_iters=500, profile=False)
 
     # Benchmark.
     if do_profile:
         latency = run_benchmark(num_iters=1, profile=True)
     else:
-        latency = run_benchmark(num_iters=1000, profile=False)
+        latency = run_benchmark(num_iters=10000, profile=False)
     print(f"Kernel running time: {latency * 1000000:.3f} us")
```


[Diffs for the remaining two changed files are not shown.]
