[ROCm] Faster Custom Paged Attention kernels #12348

Merged
merged 21 commits on Mar 3, 2025

Commits (21)
bcd73cb  added rocm custom paged attention. (vllmellm, Jan 22, 2025)
841e678  format code and enable test_attention.py in AMD CI (tjtanaa, Jan 23, 2025)
4f71b54  add author; update requirements-rocm.txt (sanyalington, Jan 23, 2025)
c200473  [Misc]: removed unnecessary condition in attention test. (vllmellm, Jan 23, 2025)
a60ae3f  [Kernel][Hardware][AMD] refactoring rocm custom paged attention to el… (vllmellm, Jan 24, 2025)
c411b72  Merge remote-tracking branch 'origin/main' into port-rocm-cpa-credit (vllmellm, Jan 24, 2025)
54b0249  [kernel] fix the format. (vllmellm, Jan 24, 2025)
a1a36f3  [Kernel][Hardware][AMD] improved rocm custom paged attention accuracy… (vllmellm, Jan 27, 2025)
21e7306  Merge remote-tracking branch 'origin/main' into port-rocm-cpa-credit (tjtanaa, Jan 28, 2025)
9501972  format and lint code (tjtanaa, Jan 28, 2025)
52d914e  address PR feedback and improve code quality: (vllmellm, Feb 19, 2025)
6cc59e1  change datatype in attention.cu based on reviewer; and revert some fi… (vllmellm, Feb 19, 2025)
f1c7f94  add comment to the attention.cu (poyenc, Feb 19, 2025)
0caafa9  convert datatype to int64 to prevent overflow as per reviewed (tjtanaa, Feb 20, 2025)
2ca1f7a  Merge remote-tracking branch 'origin/main' into port-rocm-cpa-credit (vllmellm, Feb 20, 2025)
5bce980  remove unused functions from attention.cu (tjtanaa, Feb 25, 2025)
65f6e3f  add comment to line 471 in attention.cu (tjtanaa, Feb 27, 2025)
00d6316  merge changes from main (tjtanaa, Feb 27, 2025)
23c13c7  fix attention.cu bug as main removed some header files (tjtanaa, Feb 28, 2025)
9b22915  fix linter (tjtanaa, Feb 28, 2025)
c3f2811  fix linter (tjtanaa, Feb 28, 2025)
1 change: 0 additions & 1 deletion .buildkite/run-amd-test.sh
@@ -77,7 +77,6 @@ echo "Commands:$commands"
 #ignore certain kernels tests
 if [[ $commands == *" kernels "* ]]; then
   commands="${commands} \
-  --ignore=kernels/test_attention.py \
   --ignore=kernels/test_attention_selector.py \
   --ignore=kernels/test_blocksparse_attention.py \
   --ignore=kernels/test_causal_conv1d.py \
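Removing this line takes kernels/test_attention.py out of the AMD CI ignore list, so the attention kernel tests run again in the kernels job, matching the commit "format code and enable test_attention.py in AMD CI". Locally, the same suite can presumably be run from the tests/ directory with `pytest kernels/test_attention.py`, assuming a ROCm build of vLLM.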
71 changes: 51 additions & 20 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -11,8 +11,9 @@
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser,
                         create_kv_caches_with_random)

-NUM_BLOCKS = 1024
+NUM_BLOCKS = 128 * 1024
 PARTITION_SIZE = 512
+PARTITION_SIZE_ROCM = 256


 @torch.inference_mode()
@@ -80,6 +81,12 @@ def main(
     # Prepare for the paged attention kernel.
     output = torch.empty_like(query)
     if version == "v2":
+        if current_platform.is_rocm():
+            global PARTITION_SIZE
+            if not args.custom_paged_attn:
+                PARTITION_SIZE = 1024
+            else:
+                PARTITION_SIZE = PARTITION_SIZE_ROCM
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         tmp_output = torch.empty(
             size=(num_seqs, num_query_heads, num_partitions, head_size),
@@ -123,25 +130,46 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     v_scale,
                 )
             elif version == "v2":
-                ops.paged_attention_v2(
-                    output,
-                    exp_sums,
-                    max_logits,
-                    tmp_output,
-                    query,
-                    key_cache,
-                    value_cache,
-                    num_kv_heads,
-                    scale,
-                    block_tables,
-                    seq_lens,
-                    block_size,
-                    max_seq_len,
-                    alibi_slopes,
-                    kv_cache_dtype,
-                    k_scale,
-                    v_scale,
-                )
+                if not args.custom_paged_attn:
+                    ops.paged_attention_v2(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
+                else:
+                    ops.paged_attention_rocm(
+                        output,
+                        exp_sums,
+                        max_logits,
+                        tmp_output,
+                        query,
+                        key_cache,
+                        value_cache,
+                        num_kv_heads,
+                        scale,
+                        block_tables,
+                        seq_lens,
+                        block_size,
+                        max_seq_len,
+                        alibi_slopes,
+                        kv_cache_dtype,
+                        k_scale,
+                        v_scale,
+                    )
             else:
                 raise ValueError(f"Invalid version: {version}")
         torch.cuda.synchronize()
@@ -195,6 +223,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                         help="Data type for kv cache storage. If 'auto', will use model "
                         "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
                         "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
+    parser.add_argument("--custom-paged-attn",
+                        action="store_true",
+                        help="Use custom paged attention")
     args = parser.parse_args()
     print(args)

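As context for readers (not part of the diff), here is a minimal sketch of the partition-size selection the benchmark now performs. The PARTITION_SIZE values are taken from the hunks above; the num_partitions helper name and the 4096-token example are purely illustrative:

# Illustrative sketch of the benchmark's partition-size logic above.
PARTITION_SIZE = 512        # default used by paged_attention_v2 on CUDA
PARTITION_SIZE_ROCM = 256   # used when --custom-paged-attn selects the ROCm kernel


def num_partitions(max_seq_len: int, is_rocm: bool, custom_paged_attn: bool) -> int:
    partition_size = PARTITION_SIZE
    if is_rocm:
        # Mirrors the ROCm branch added in main(): 1024 for the stock v2
        # kernel, 256 for the custom paged attention kernel.
        partition_size = PARTITION_SIZE_ROCM if custom_paged_attn else 1024
    # Ceiling division, as in the benchmark.
    return (max_seq_len + partition_size - 1) // partition_size


# Example: a 4096-token sequence is split into 16 partitions for the custom
# ROCm kernel, but only 4 for the non-custom ROCm path.
print(num_partitions(4096, is_rocm=True, custom_paged_attn=True))   # 16
print(num_partitions(4096, is_rocm=True, custom_paged_attn=False))  # 4

With these changes, the two ROCm code paths can presumably be compared by running the benchmark once with and once without the new flag, e.g. `python benchmarks/kernels/benchmark_paged_attention.py --version v2 --custom-paged-attn` versus the same command without `--custom-paged-attn`; the `--version v2` option is assumed from the script's existing v1/v2 switch.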