Skip to content

Commit

Permalink
fix diagonal bias
Browse files Browse the repository at this point in the history
  • Loading branch information
timt51 committed Sep 7, 2024
1 parent cfca2e8 commit 14b704d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions csrc/flash_attn/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ struct Mask {
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// No causal, no local
if constexpr (Has_alibi) {
tensor(mi, make_coord(j, nj)) += alibi_slope; // * col_idx;
tensor(mi, make_coord(j, nj)) += (col_idx == 0 ? 0 : alibi_slope); // alibi_slope * col_idx;
}
if constexpr (!Is_even_MN) {
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
Expand All @@ -178,9 +178,10 @@ struct Mask {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope; // * col_idx;
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // alibi_slope * col_idx;

} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope; // * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);

}
}
Expand Down

0 comments on commit 14b704d

Please sign in to comment.