From d0feea31c70f9540a8993c4e96103a03cd935416 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Sat, 8 Mar 2025 00:53:38 +0800 Subject: [PATCH] [Kernel] optimize performance of gptq marlin kernel when n is small (#14138) Signed-off-by: Jinzhen Lin --- csrc/quantization/gptq_marlin/gptq_marlin.cu | 62 ++++++++++++++----- csrc/torch_bindings.cpp | 3 +- tests/kernels/test_marlin_gemm.py | 16 +++-- vllm/_custom_ops.py | 5 +- vllm/envs.py | 5 ++ .../layers/quantization/utils/marlin_utils.py | 32 ++++++++++ 6 files changed, 99 insertions(+), 24 deletions(-) diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 7c33fea93d6ae..72627df24b9af 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -538,6 +538,7 @@ __global__ void Marlin( int prob_n, // output dimension n int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce bool use_fp32_reduce // whether to use fp32 global reduce ) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the @@ -1542,7 +1543,17 @@ __global__ void Marlin( i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh_red[c_sh_rd]; + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } c_gl_wr += c_gl_wr_delta; c_sh_rd += c_sh_rd_delta; } @@ -1644,7 +1655,7 @@ __global__ void Marlin( } cp_async_fence(); } else { - if (last) { + if (last || use_atomic_add) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } @@ -1664,7 +1675,7 @@ __global__ void Marlin( } } else { - if (last) { + if (last || use_atomic_add) { cp_async_wait<0>(); __syncthreads(); if (threadIdx.x / 32 < thread_n_blocks / 4) { @@ -1703,8 +1714,8 @@ __global__ void Marlin( } } - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice barrier_acquire(&locks[slice_col], slice_idx); if (use_fp32_reduce) { global_reduce_fp32(slice_idx == 0, last); @@ -1713,7 +1724,8 @@ __global__ void Marlin( } barrier_release(&locks[slice_col], last); } - if (last) // only the last block in a slice actually writes the result + if (last || use_atomic_add) + // only the last block in a slice actuallywrites the result write_result(); slice_row = 0; slice_col_par++; @@ -1768,7 +1780,8 @@ __global__ void Marlin( HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \ <<>>( \ A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ - num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \ + use_fp32_reduce); \ } \ } @@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) { + int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { if (has_zp) { TORCH_CHECK( q_type == 
vllm::kU4 || q_type == vllm::kU8, @@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& workspace, vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, + bool is_k_full, bool has_zp, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); if (has_zp) { @@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + torch::Tensor c; + if (use_atomic_add) { + c = torch::zeros({size_m, size_n}, options); + } else { + c = torch::empty({size_m, size_n}, options); + } + + torch::Tensor a_tmp; + bool has_act_order = g_idx.size(0) != 0; + if (has_act_order) { + a_tmp = torch::empty({size_m, size_k}, options); + } else { + a_tmp = torch::empty({0}, options); + } // Alloc C tmp buffer that is going to be used for the global reduce + torch::Tensor c_tmp; int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par); int reduce_n = size_n; auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device()); - if (!use_fp32_reduce) { + if (use_fp32_reduce) { + c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32); + } else { reduce_max_m = 0; reduce_n = 0; + c_tmp = torch::empty({0}, options_fp32); } - torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32); // thread_k: `k` size of a thread_tile in `weights` (can usually be left as // auto -1) @@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Detect groupsize and act_order int num_groups = -1; int group_size = -1; - bool has_act_order = g_idx.size(0) != 0; int rank = b_scales.sizes().size(); TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2"); @@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); + thread_k, thread_n, sms, marlin::max_par, use_atomic_add, + use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), @@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float); + thread_k, thread_n, sms, marlin::max_par, use_atomic_add, + use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fe7a674bb03c0..b06b12220793f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -272,7 +272,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " "int b_q_type, " 
"SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor", + "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, " + "bool is_zp_float) -> Tensor", {stride_tag}); // conditionally compiled so impl registration is in source file diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index b96aca06cdff3..c0cf5b099f993 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -34,6 +34,7 @@ ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] +USE_ATOMIC_ADD_OPTS = [False, True] USE_FP32_REDUCE_OPTS = [False, True] MARLIN_K_CHUNKS = [128] @@ -194,6 +195,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) +@pytest.mark.parametrize("use_atomic_add", USE_ATOMIC_ADD_OPTS) @pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) def test_gptq_marlin_gemm( k_chunk, @@ -203,6 +205,7 @@ def test_gptq_marlin_gemm( mnk_factors, act_order, is_k_full, + use_atomic_add, use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors @@ -228,12 +231,12 @@ def test_gptq_marlin_gemm( workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL) - opcheck( - torch.ops._C.gptq_marlin_gemm, - (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1], - a_input.shape[1], is_k_full, False, use_fp32_reduce, False), - test_utils=DEFAULT_OPCHECK_TEST_UTILS) + opcheck(torch.ops._C.gptq_marlin_gemm, + (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, + workspace.scratch, quant_type.id, a_input.shape[0], + b_weight.shape[1], a_input.shape[1], is_k_full, False, + use_atomic_add, use_fp32_reduce, False), + test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, @@ -249,6 +252,7 @@ def test_gptq_marlin_gemm( a_input.shape[1], is_k_full=is_k_full, has_zp=False, + use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 3c822028426ee..1f362a45aa7d3 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -301,6 +301,7 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, size_k: torch.SymInt, is_k_full: bool, has_zp: bool = False, + use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @@ -713,12 +714,14 @@ def gptq_marlin_gemm(a: torch.Tensor, size_k: int, is_k_full: bool, has_zp: bool = False, + use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, - has_zp, use_fp32_reduce, is_zp_float) + has_zp, use_atomic_add, + use_fp32_reduce, is_zp_float) # fp8 marlin diff --git a/vllm/envs.py b/vllm/envs.py index 2489affbcbd2f..187d28b2d6d30 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -95,6 +95,7 @@ VLLM_DP_SIZE: int = 1 VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 + VLLM_MARLIN_USE_ATOMIC_ADD: bool = False def get_default_cache_root(): @@ -630,6 +631,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # Whether to use S3 path for model 
loading in CI via RunAI Streamer "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", + + # Whether to use atomicAdd reduce in gptq/awq marlin kernel. + "VLLM_MARLIN_USE_ATOMIC_ADD": + lambda: os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1", } # end-env-vars-definition diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 80416c1bc6ebc..d1fb52ae09def 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -5,6 +5,7 @@ import numpy import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import LinearBase from vllm.platforms import current_platform @@ -290,6 +291,23 @@ def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return output +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD or device.type != "cuda": + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + return False + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + return max(m, 64) * n < 64 * 2048 and k >= 2048 + + def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, @@ -307,6 +325,12 @@ def apply_gptq_marlin_linear( reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition, ) + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + output = ops.gptq_marlin_gemm(reshaped_x, weight, weight_scale, @@ -320,6 +344,7 @@ def apply_gptq_marlin_linear( size_k=input_size_per_partition, is_k_full=is_k_full, has_zp=False, + use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False) @@ -345,6 +370,12 @@ def apply_awq_marlin_linear( reshaped_x = input.reshape(-1, input.shape[-1]) out_shape = input.shape[:-1] + (output_size_per_partition, ) + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + output = ops.gptq_marlin_gemm(reshaped_x, weight, weight_scale, @@ -358,6 +389,7 @@ def apply_awq_marlin_linear( size_k=input_size_per_partition, is_k_full=True, has_zp=True, + use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False)
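
Note on the dispatch heuristic added to marlin_utils.py above: the atomicAdd path is strictly opt-in (VLLM_MARLIN_USE_ATOMIC_ADD=1), CUDA-only, skipped for bfloat16 on sm8x devices (no native bfloat16 atomicAdd), and selected only when m*n is small and k is large. The following is a minimal standalone sketch of that rule, assuming only PyTorch; the helper name would_use_atomic_add and the example shapes are illustrative and not part of the patch.

import torch

def would_use_atomic_add(m: int, n: int, k: int,
                         device: torch.device,
                         dtype: torch.dtype,
                         env_enabled: bool) -> bool:
    """Standalone restatement of should_use_atomic_add_reduce() above."""
    # Opt-in only (VLLM_MARLIN_USE_ATOMIC_ADD=1) and CUDA-only.
    if not env_enabled or device.type != "cuda":
        return False
    # sm8x doesn't support atomicAdd on bfloat16 natively.
    major, _ = torch.cuda.get_device_capability(device)
    if major < 9 and dtype == torch.bfloat16:
        return False
    # atomicAdd only pays off when m*n is small and k is large, i.e. when
    # one output slice is split across several thread blocks.
    return max(m, 64) * n < 64 * 2048 and k >= 2048

if torch.cuda.is_available():
    dev = torch.device("cuda")
    # Decode-time shape: tiny m, modest n, large k -> atomicAdd path.
    print(would_use_atomic_add(1, 1024, 8192, dev, torch.float16, True))   # True
    # Large n -> keep the barrier-based global reduce.
    print(would_use_atomic_add(16, 8192, 8192, dev, torch.float16, True))  # False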
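
Note on the switch from torch::empty to torch::zeros for c in gptq_marlin_gemm: with use_atomic_add, every thread block in a slice accumulates its partial sum directly into C via atomicAdd instead of the last block writing the reduced result after a barrier, so the output buffer must start at zero. Below is a small PyTorch analogy of the two reduction schemes, using illustrative names only, not vLLM code.

import torch

# Partial results produced by the thread blocks that share one output slice.
partials = [torch.randn(4, 8) for _ in range(3)]

# Barrier/global-reduce style: partial sums are combined in a scratch buffer
# and only the last participant writes the final value, so the output may
# start uninitialized (torch::empty in the patch).
out_last_writer = torch.empty(4, 8)
out_last_writer.copy_(torch.stack(partials).sum(dim=0))

# atomicAdd style: every participant adds its partial sum in place, so the
# output must start zeroed (torch::zeros in the patch) or stale memory
# contents would leak into the result.
out_atomic = torch.zeros(4, 8)
for p in partials:
    out_atomic += p  # stand-in for the per-element atomicAdd in the kernel

assert torch.allclose(out_last_writer, out_atomic)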