[ROCm] Faster Custom Paged Attention kernels #12348

Open · wants to merge 10 commits into main
Conversation

@tjtanaa (Contributor) commented Jan 23, 2025

Description

This PR implements a faster Custom Paged Attention (CPA) kernel based on mfma16x16x16 instructions.
This feature is from ROCm/vllm (ROCm#372).

End-to-End Performance gain

Model: Llama-3.1-70B-Instruct
Tensor Parallelism: 1
GPU: MI300X

| CPA Version | Input length | Output length | KV-cache-dtype | Quantization | Prompt numbers | Req/s | Total Tokens/s | Output Tokens/s |
|---|---|---|---|---|---|---|---|---|
| before changes | 128 | 128 | fp8_e4m3 | fp8 | 200 | 13.05 | 3340.6 | 1670.3 |
| before changes | 128 | 256 | fp8_e4m3 | fp8 | 200 | 7.56 | 2901.31 | 1934.21 |
| before changes | 128 | 2048 | fp8_e4m3 | fp8 | 200 | 0.78 | 1698.35 | 1598.45 |
| before changes | 512 | 128 | fp8_e4m3 | fp8 | 200 | 6.44 | 4122.57 | 824.51 |
| before changes | 512 | 256 | fp8_e4m3 | fp8 | 200 | 4.48 | 3443.46 | 1147.82 |
| before changes | 512 | 2048 | fp8_e4m3 | fp8 | 200 | 0.66 | 1696.64 | 1357.31 |
| before changes | ShareGPT | | fp8_e4m3 | fp8 | 1000 | 6.22 | 2574.19 | 1234.64 |
| optimized | 128 | 128 | fp8_e4m3 | fp8 | 200 | 15.11 | 3867.75 | 1933.87 |
| optimized | 128 | 256 | fp8_e4m3 | fp8 | 200 | 9.01 | 3459.98 | 2306.65 |
| optimized | 128 | 2048 | fp8_e4m3 | fp8 | 200 | 1.2 | 2609.04 | 2455.57 |
| optimized | 512 | 128 | fp8_e4m3 | fp8 | 200 | 7.33 | 4694.05 | 938.81 |
| optimized | 512 | 256 | fp8_e4m3 | fp8 | 200 | 5.5 | 4223.29 | 1407.76 |
| optimized | 512 | 2048 | fp8_e4m3 | fp8 | 200 | 1.03 | 2648.55 | 2118.84 |
| optimized | ShareGPT | | fp8_e4m3 | fp8 | 1000 | 7.45 | 3081.14 | 1477.79 |
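
The exact benchmark script and flags used for these runs are not shown in this PR; as a rough sketch, the configuration in the table maps onto vLLM's offline API roughly as follows (model ID and prompt text are illustrative assumptions):

```python
from vllm import LLM, SamplingParams

# Illustrative setup mirroring the table: fp8 quantization, fp8_e4m3 KV cache,
# tensor parallelism of 1 (single MI300X).
llm = LLM(
    model="meta-llama/Llama-3.1-70B-Instruct",  # assumed HF model ID
    tensor_parallel_size=1,
    quantization="fp8",
    kv_cache_dtype="fp8_e4m3",
)

# 200 prompts with a fixed output length, as in the 128/128 row; ignore_eos
# forces generation to run to the full output length.
sampling_params = SamplingParams(max_tokens=128, ignore_eos=True)
outputs = llm.generate(["Hello, my name is"] * 200, sampling_params)
```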

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Jan 23, 2025
vllmellm and others added 2 commits January 23, 2025 08:50
Ported ROCm/vllm changes to upstream vLLM

This commit manually ports changes from ROCm/vllm (ROCm#372) to upstream vLLM.
The original work was done by sanyalington.

Co-authored-by: sanyalington <[email protected]>

Signed-off-by: vllmellm <[email protected]>
@tjtanaa tjtanaa force-pushed the port-rocm-cpa-credit branch 2 times, most recently from 9be5f70 to f57dcb9, on January 23, 2025 08:57
@tjtanaa tjtanaa force-pushed the port-rocm-cpa-credit branch from f57dcb9 to 4f71b54 on January 23, 2025 09:01
@tjtanaa tjtanaa changed the title [AMD] Faster Custom Paged Attention kernels [ROCm] Faster Custom Paged Attention kernels Jan 23, 2025
@tjtanaa (Contributor, Author) commented Jan 23, 2025

Regarding the API changes to paged_attention in csrc/rocm/torch_bindings.cpp: this change only affects the ROCm code path and does not interfere with the code paths of other platforms.

 rocm_ops.def(
      "paged_attention(Tensor! out, Tensor exp_sums,"
      "                Tensor max_logits, Tensor tmp_out,"
      "                Tensor query, Tensor key_cache,"
      "                Tensor value_cache, int num_kv_heads,"
      "                float scale, Tensor block_tables,"
      "                Tensor context_lens, int block_size,"
      "                int max_context_len,"
      "                Tensor? alibi_slopes,"
      "                str kv_cache_dtype,"
      "                float k_scale, float v_scale,"
      "                Tensor? fp8_out_scale,"
      "                int partition_size) -> ()");

Seeking advice on handling the variables fp8_out_scale and partition_size.

Situation: These two variables, fp8_out_scale and partition_size, have been introduced in the ROCm Custom Paged Attention, but they are not yet used by the higher-level abstractions. They are set to fp8_out_scale=None and partition_size=256. The value partition_size=256 has been found experimentally to work well on MI300.
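
For context, here is a rough sketch (not the actual backend code) of how a caller such as vllm/attention/backends/rocm_flash_attn.py would pass these two arguments under the binding shown above; the tensor variables are assumed to be prepared by the backend as usual:

```python
# Illustrative call following the paged_attention schema above; the tensor
# arguments (out, exp_sums, ..., context_lens) are assumed to have been
# allocated by the ROCm attention backend beforehand.
ops.paged_attention_rocm(
    out, exp_sums, max_logits, tmp_out,
    query, key_cache, value_cache, num_kv_heads,
    scale, block_tables, context_lens, block_size,
    max_context_len,
    alibi_slopes,
    kv_cache_dtype,
    k_scale, v_scale,
    None,  # fp8_out_scale: no fp8 output-scaling strategy wired up yet
    256,   # partition_size: experimentally tuned value for MI300
)
```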

Option 1:

  • Remove fp8_out_scale from csrc/rocm/attention.cu
  • Hard-code partition_size to 256 in csrc/rocm/attention.cu.
    This avoids changing the paged_attention API in csrc/rocm/torch_bindings.cpp

Option 2:

  • Keep the variables as is, and add a TODO so a future update remembers to introduce an fp8 output-scaling strategy for ROCm.
  • Set fp8_out_scale=None and partition_size=256 when calling ops.paged_attention_rocm in vllm/attention/backends/rocm_flash_attn.py

We have implemented Option 1.

@hongxiayang (Collaborator) commented:

@tjtanaa Please fix the DCO error:
  1. Ensure you have a local copy of your branch by checking out the pull request locally via the command line.
  2. In your local branch, run: git rebase HEAD~4 --signoff
  3. Force-push your changes to overwrite the branch: git push --force-with-lease origin port-rocm-cpa-credit

…iminate the need for additional arguments (partition_size and fp8_output_scale) in its API.

Signed-off-by: vllmellm <[email protected]>
mergify bot commented Jan 24, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 24, 2025
@mergify mergify bot removed the needs-rebase label Jan 24, 2025
… and code documentation. Updated its unit test to match the correct partition size based on the paged attention version as well as the platform type.

Signed-off-by: vllmellm <[email protected]>
@tjtanaa tjtanaa marked this pull request as ready for review January 27, 2025 12:27
@tjtanaa (Contributor, Author) commented Jan 27, 2025

> @tjtanaa Please fix the DCO error: Ensure you have a local copy of your branch by checking out the pull request locally via command line. In your local branch, run: git rebase HEAD~4 --signoff. Force push your changes to overwrite the branch: git push --force-with-lease origin port-rocm-cpa-credit

@hongxiayang We find that rebasing is difficult because we had already merged from main. While fixing the DCO we had to resolve merge conflicts twice, and doing so again would require us to retest everything. It seems there are ways to override the DCO check during merge. Could we get more input from the vLLM maintainers on this DCO issue?

@mergify mergify bot added the documentation and frontend labels Jan 27, 2025
@tjtanaa tjtanaa force-pushed the port-rocm-cpa-credit branch from e8e548c to a1a36f3 on January 28, 2025 03:08
@hongxiayang (Collaborator) left a comment:

Thanks a lot. LGTM.

@hongxiayang (Collaborator) commented:

I also verified the throughput numbers using the image built from @tjtanaa's branch.

@hongxiayang (Collaborator) commented:

@tjtanaa Can you work with @DarkLight1337 to see what else is needed in order to merge this PR?
Thanks for your effort in upstreaming this, fixing the tests, and cleaning up other spelling errors as well.

@DarkLight1337 (Member) commented:

@tlrmchlsmth @WoosukKwon can either of you take a look at this PR?

Comment on lines +148 to +149:

    global PARTITION_SIZE

@tlrmchlsmth (Collaborator) commented:

why make PARTITION_SIZE a global here? Not sure what PARTITION_SIZE does, or why it would be different on ROCm

@tjtanaa (Contributor, Author) commented:

@tlrmchlsmth This line tells Python that PARTITION_SIZE inside the test_paged_attention function refers to the module-level (global) variable. It is needed because https://github.com/vllm-project/vllm/blob/9501972249ca7dca0704bceb4308163f30999a6d/tests/kernels/test_attention.py#L217C36-L219C49 reassigns PARTITION_SIZE to the value of PARTITION_SIZE_ROCM; without the global declaration, that assignment would make Python treat PARTITION_SIZE as a local variable.

PARTITION_SIZE_ROCM is a performance-tuned hyperparameter for the ROCm custom paged attention. That is why it differs from PARTITION_SIZE on other platforms.
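
A minimal, standalone illustration of the scoping behavior described above (names and values mirror the test but are simplified and are not the actual test code):

```python
PARTITION_SIZE = 512        # default partition size used on other platforms
PARTITION_SIZE_ROCM = 256   # ROCm-tuned value (MI300)
ON_ROCM = True              # stand-in for the real platform check

def without_global():
    # The assignment below makes Python treat PARTITION_SIZE as a local
    # variable for the whole function body, so reading it when ON_ROCM is
    # False raises UnboundLocalError.
    if ON_ROCM:
        PARTITION_SIZE = PARTITION_SIZE_ROCM
    return PARTITION_SIZE

def with_global():
    global PARTITION_SIZE   # refer to the module-level variable instead
    if ON_ROCM:
        PARTITION_SIZE = PARTITION_SIZE_ROCM
    return PARTITION_SIZE

print(with_global())        # 256 when ON_ROCM is True, otherwise 512
```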

@tlrmchlsmth (Collaborator) commented:

I'll take a look tomorrow

Labels: ci/build, documentation, frontend, rocm
6 participants