diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 3cdfee5516a8b..21345f79a3d5b 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -347,6 +347,7 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, for (uint32_t k = 0; k < min(K * M, 32 * 1024); k += THRDS * WvPrGrp * A_CHUNK) { uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (k_in >= min(K * M, 32 * 1024)) break; ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; } __syncthreads();