diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index fd68cec12..596b69516 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -100,7 +100,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi #pragma unroll for (int m = 0; m < size<1>(tOgO); ++m) { const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = -INFINITY; } } return; } @@ -545,7 +545,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons #pragma unroll for (int m = 0; m < size<1>(tOgOaccum); ++m) { const int row = get<0>(tOcO(0, m, 0)); - if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = -INFINITY; } } return; } @@ -1141,9 +1141,7 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } SumOp sum_op; lse_sum = Allreduce::run(lse_sum, sum_op); - // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise - // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. - ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + ElementAccum lse_logsum = logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } // Store the scales exp(lse - lse_logsum) in shared memory. @@ -1151,7 +1149,9 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; const int col = tidx / kRowsPerLoadTranspose; - if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + // For the case where all local lse == -INFINITY, we want to explicitly set sLSE to 0. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = (lse_sum == 0.f || lse_sum != lse_sum) ? 0 : expf(lse_accum(l) - lse_logsum); } } __syncthreads();