From aae08249acca69060d0a8220cab920e00520932c Mon Sep 17 00:00:00 2001 From: alexm-nm <59768536+alexm-nm@users.noreply.github.com> Date: Wed, 24 Apr 2024 13:35:01 -0400 Subject: [PATCH] [Bugfix] Fix marlin kernel crash on H100 (#4218) This PR addresses the Marlin kernel H100 crash that was reported here: neuralmagic#187. The reason for the crash was the inline PTX assembly that introduced the async_copy with streaming behavior. The solution is to use the more standard PTX for async_copy (without the fractional L2 policy for "evict_first"). There is no performance difference between standard async_copy PTX and the previous one. --- .../quantization/marlin/marlin_cuda_kernel.cu | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/csrc/quantization/marlin/marlin_cuda_kernel.cu b/csrc/quantization/marlin/marlin_cuda_kernel.cu index cf1b0afdec8b4..002a70001885d 100644 --- a/csrc/quantization/marlin/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/marlin_cuda_kernel.cu @@ -67,20 +67,13 @@ __device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, "r"(smem), "l"(glob_ptr), "n"(BYTES)); } -// Asynchronous global->shared copy with a cache hint indicating that the values -// may be evicted immediately; used for quantized weights B, which are only -// accessed precisely once and should thus not pollute the L2 cache which we -// need for inputs A and outputs C. -__device__ inline void cp_async4_stream(void *smem_ptr, const void *glob_ptr) { +// Asynchronous global->shared copy +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { const int BYTES = 16; uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .b64 p;\n" - " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" - " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); + asm volatile("{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(BYTES)); } // Async copy fence. @@ -448,14 +441,14 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk int4 *sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { - cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); B_ptr[i] += b_gl_rd_delta_o; } // Only fetch scales if this tile starts a new group if (group_blocks != -1 && pipe % (group_blocks / thread_k_blocks) == 0) { int4 *sh_s_stage = sh_s + s_sh_stage * pipe; if (s_sh_wr_pred) - cp_async4_stream(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); + cp_async4(&sh_s_stage[s_sh_wr], &s[s_gl_rd]); s_gl_rd += s_gl_rd_delta; } } @@ -750,7 +743,7 @@ Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk // write-out if (group_blocks == -1 && last) { if (s_sh_wr_pred) - cp_async4_stream(&sh_s[s_sh_wr], &s[s_gl_rd]); + cp_async4(&sh_s[s_sh_wr], &s[s_gl_rd]); cp_async_fence(); } thread_block_reduce();