[v1][Bugfix] Add extra_keys to block_hash for prefix caching #12603
Conversation
Signed-off-by: Chen Zhang <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can trigger it in one of the supported ways.
LGTM. Thanks for catching this
commit 89003c4 Author: Chen Zhang <[email protected]> Date: Sat Feb 1 05:13:04 2025 +0800 [v1][Bugfix] Add extra_keys to block_hash for prefix caching (vllm-project#12603)
This PR adds extra keys to the block hash so that two blocks with the same token IDs but different extra_keys in their parent blocks hash to different values. For example, it generates different hash values for the second block of the following two requests:
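The two requests below are taken from the PR description: they share the same six prompt token IDs and the same multimodal placeholder positions, and differ only in the content hash of the first multimodal item (`hash1` vs. `hash3`). Without extra keys in the block hash, their second blocks would collide in the prefix cache; with this change they hash differently.

```python
request1 = make_request(
    request_id=0,
    prompt_token_ids=[_ for _ in range(6)],
    mm_positions=[{"offset": 0, "length": 3}, {"offset": 3, "length": 3}],
    mm_hashes=["hash1", "hash2"],
)
request2 = make_request(
    request_id=1,
    prompt_token_ids=[_ for _ in range(6)],
    mm_positions=[{"offset": 0, "length": 3}, {"offset": 3, "length": 3}],
    # Identical to request1 except for the first multimodal hash.
    mm_hashes=["hash3", "hash2"],
)
```

To illustrate why the parent block's extra keys matter, here is a minimal sketch of chained block hashing; it is not vLLM's actual implementation, and `hash_block` and its signature are hypothetical. Each block's hash folds in its parent's hash, its own token IDs, and any extra keys (e.g. hashes of multimodal inputs covering those tokens), so a difference in an ancestor's extra keys propagates to every descendant block hash.

```python
from typing import Iterable, Optional, Tuple


def hash_block(parent_hash: Optional[int],
               token_ids: Iterable[int],
               extra_keys: Tuple[str, ...] = ()) -> int:
    # Chain the parent block's hash with this block's token IDs and extra keys.
    return hash((parent_hash, tuple(token_ids), extra_keys))


# First blocks (tokens 0-2) differ only in their multimodal extra key.
parent1 = hash_block(None, range(3), ("hash1",))
parent2 = hash_block(None, range(3), ("hash3",))

# Second blocks (tokens 3-5) have identical tokens and extra keys ("hash2"),
# but inherit different parent hashes, so they no longer collide.
block1 = hash_block(parent1, range(3, 6), ("hash2",))
block2 = hash_block(parent2, range(3, 6), ("hash2",))
assert block1 != block2
```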