diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 361ea2380..c21933e45 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -153,7 +153,7 @@ struct Mask { for (int mi = 0; mi < size<0>(tensor); ++mi) { // No causal, no local if constexpr (Has_alibi) { - tensor(mi, make_coord(j, nj)) += alibi_slope; // * col_idx; + tensor(mi, make_coord(j, nj)) += (col_idx == 0 ? 0 : alibi_slope); // alibi_slope * col_idx; } if constexpr (!Is_even_MN) { if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } @@ -178,9 +178,10 @@ struct Mask { const int col_idx = col_idx_base + j; if constexpr (Has_alibi) { if constexpr (Is_causal) { - tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope; // * col_idx; + tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // alibi_slope * col_idx; + } else { - tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope; // * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); } }