
[Kernel] port sgl moe_align_block_size kernels #12574

Merged: 1 commit merged into vllm-project:main from the moe-align branch on Feb 3, 2025

Conversation

chenyang78 (Contributor) commented Jan 30, 2025

sgl_moe_align_block_size is based on:

sgl-project/sglang@ded9fcd

moe_align_block_size is based on:

sgl-project/sglang@ba5112f
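
For background, the alignment these kernels perform can be sketched in a few lines of PyTorch. This is an illustrative reference only, not the ported CUDA code; the function name and output conventions below are assumptions for the sketch: tokens are bucketed by expert and each expert's bucket is padded up to a multiple of `block_size`.

```python
import torch

def moe_align_block_size_ref(topk_ids: torch.Tensor, block_size: int,
                             num_experts: int):
    """Group flattened token->expert assignments so that each expert's
    tokens occupy whole blocks of `block_size`, padding with a sentinel."""
    flat = topk_ids.flatten()
    sentinel = flat.numel()  # any out-of-range index works as padding
    counts = torch.bincount(flat, minlength=num_experts)
    padded = (counts + block_size - 1) // block_size * block_size
    num_tokens_post_padded = int(padded.sum())

    sorted_token_ids = torch.full((num_tokens_post_padded,), sentinel,
                                  dtype=torch.int64)
    expert_ids = torch.empty(num_tokens_post_padded // block_size,
                             dtype=torch.int64)
    offset = 0
    for e in range(num_experts):
        idx = (flat == e).nonzero(as_tuple=True)[0]
        sorted_token_ids[offset:offset + idx.numel()] = idx
        expert_ids[offset // block_size:
                   (offset + int(padded[e])) // block_size] = e
        offset += int(padded[e])
    return sorted_token_ids, expert_ids, num_tokens_post_padded
```

The CUDA kernels in this PR compute the same grouping on the GPU so the fused MoE GEMM can process fixed-size, expert-homogeneous blocks.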


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@chenyang78 chenyang78 force-pushed the moe-align branch 2 times, most recently from 407a559 to ee02e4c Compare January 31, 2025 01:46
@@ -197,6 +197,72 @@ __global__ void moe_align_block_size_global_mem_kernel(
}
}

// temporarily adapted from
Contributor:

nit: what was changed? Just the function name? Everything else looks the same to me; if that is the case, we should just say "taken from" instead of "temporarily adapted from".

chenyang78 (Contributor, Author):

Fixed. Thanks.

@robertgshaw2-redhat added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 1, 2025
robertgshaw2-redhat (Collaborator) commented Feb 1, 2025

|        | Tokens/Sec | Output Tok/Sec | Requests/Sec |
|--------|-----------:|---------------:|-------------:|
| after  |    1132.56 |         849.42 |         0.28 |
| before |    1112.99 |         834.74 |         0.28 |
  • before
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             _fwd_grouped_kernel_stage1         0.00%       0.000us         0.00%       0.000us       0.000us        3.362s        56.46%        3.362s     551.093us          6100  
void cutlass::device_kernel<vllm::cutlass_3x_gemm_fp...         0.00%       0.000us         0.00%       0.000us       0.000us     740.156ms        12.43%     740.156ms      23.980us         30866  
                                       fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us     582.178ms         9.78%     582.178ms      49.691us         11716  
void vllm::moe::moe_align_block_size_kernel<int, uns...         0.00%       0.000us         0.00%       0.000us       0.000us     317.825ms         5.34%     317.825ms      54.255us          5858  
                                 _w8a8_block_fp8_matmul         0.00%       0.000us         0.00%       0.000us       0.000us     167.413ms         2.81%     167.413ms      27.173us          6161  
                            vllm::inplace_fused_experts         0.21%      13.429ms         0.36%      23.469ms     404.633us     126.093ms         2.12%     144.540ms       2.492ms            58  
  • after
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                             _fwd_grouped_kernel_stage1         0.00%       0.000us         0.00%       0.000us       0.000us        3.350s        58.79%        3.350s     549.214us          6100  
void cutlass::device_kernel<vllm::cutlass_3x_gemm_fp...         0.00%       0.000us         0.00%       0.000us       0.000us     734.721ms        12.89%     734.721ms      23.804us         30866  
                                       fused_moe_kernel         0.00%       0.000us         0.00%       0.000us       0.000us     577.712ms        10.14%     577.712ms      49.310us         11716  
                                 _w8a8_block_fp8_matmul         0.00%       0.000us         0.00%       0.000us       0.000us     166.231ms         2.92%     166.231ms      26.981us          6161  
                            vllm::inplace_fused_experts         0.22%      13.736ms         0.38%      23.483ms     404.874us     124.716ms         2.19%     139.846ms       2.411ms            58  
void vllm::cross_device_reduce_1stage<__nv_bfloat16,...         0.00%       0.000us         0.00%       0.000us       0.000us     106.168ms         1.86%     106.168ms       8.632us         12300  
void vllm::moe::sgl_moe_align_block_size_kernel<int>...         0.00%       0.000us         0.00%       0.000us       0.000us      65.403ms         1.15%      65.403ms      11.165us          5858
void at::native::sbtopk::gatherTopK<c10::BFloat16, u...         0.00%       0.000us         0.00%       0.000us       0.000us      62.487ms         1.10%      62.487ms       5.387us         11600 
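
The per-kernel tables above are PyTorch profiler output. The exact harness used for this PR is not shown; a minimal sketch of how such a table is produced (the model-step function is a placeholder) looks like this:

```python
from torch.profiler import profile, ProfilerActivity

def run_model_step():
    # Placeholder for one serving step of the model under test.
    pass

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    run_model_step()

# Sorting by CUDA self time surfaces kernels such as
# moe_align_block_size_kernel / sgl_moe_align_block_size_kernel.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))
```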

LM Eval:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match||0.9507|±  |0.0060|
|     |       |strict-match    |     5|exact_match||0.9484|±  |0.0061|
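
The GSM8K numbers above are in lm-evaluation-harness output format. A hedged sketch of reproducing such a check through the harness's `simple_evaluate` API follows; the model path and parallelism settings are placeholders, not taken from this PR:

```python
import lm_eval

# Evaluate GSM8K (5-shot) with vLLM as the inference backend.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args="pretrained=/path/to/model,tensor_parallel_size=8",
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])
```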

@@ -18,6 +18,9 @@

logger = init_logger(__name__)

enable_moe_align_block_size_triton = bool(
    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")))

Contributor:

nit: we should move ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON to envs.py

chenyang78 (Contributor, Author):

Done. Thanks.
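
For reference, a hedged sketch of what registering the flag in vllm/envs.py might look like, assuming the lambda-per-variable pattern used in that file; the exact variable name in the merged change may differ:

```python
import os

# Excerpt-style sketch; in vllm/envs.py each variable maps to a lambda
# that reads it from the environment.
environment_variables = {
    # Use the Triton moe_align_block_size implementation instead of the
    # CUDA kernels when set to "1".
    "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
    lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))),
}
```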

LucasWilkinson (Contributor) left a comment:

LGTM, left a couple nits

robertgshaw2-redhat (Collaborator):

LGTM. Thanks for the work. Can you address Lucas's comments in a follow up?

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) February 1, 2025 20:47
sgl_moe_align_block_size is based on:

sgl-project/sglang@ded9fcd

moe_align_block_size is based on:

sgl-project/sglang@ba5112f
Signed-off-by: Yang Chen <[email protected]>
auto-merge was automatically disabled February 2, 2025 10:50

Head branch was pushed to by a user without write access

chenyang78 (Contributor, Author):

> LGTM. Thanks for the work. Can you address Lucas's comments in a follow up?

Thanks for the review! Seems CI failed with some timeout error.


[2025-02-02T09:34:34Z] FAILED spec_decode/e2e/test_logprobs.py::test_logprobs_equality[-1-1-1-7-8-test_llm_kwargs0-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0] - requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: c5fa32cf-b0ce-4906-ba84-2d599927c376)')
[2025-02-02T09:34:34Z] FAILED spec_decode/e2e/test_logprobs.py::test_logprobs_equality[-1-1-1-7-8-test_llm_kwargs1-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0] - requests.exceptions.ReadTimeout: (ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 67406ecd-dcc0-4dec-aa50-9ac8d6a12f68)')

BTW, I addressed Lucas's comments in the PR. My new change disabled auto-merge. Could you help enable it again when you get a chance? Thanks!

@DarkLight1337 DarkLight1337 merged commit 95460fc into vllm-project:main Feb 3, 2025
70 checks passed
sahelib25 pushed a commit to krai/vllm that referenced this pull request Feb 3, 2025
yessenzhar pushed a commit to deepinfra/vllm that referenced this pull request Feb 3, 2025
fxmarty-amd pushed a commit to fxmarty-amd/vllm that referenced this pull request Feb 7, 2025
Labels
ready ONLY add when PR is ready to merge/full CI is needed