Skip to content

Commit

Permalink
fix bh bias diagonal handling
Browse files: browse the repository at this point in the history
  • Loading branch information
timt51 committed Sep 8, 2024
1 parent 3b267ff commit b9ec215
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions csrc/flash_attn/src/mask.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -153,7 +153,7 @@ struct Mask {
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// No causal, no local
if constexpr (Has_alibi) {
tensor(mi, make_coord(j, nj)) += 0; // alibi_slope * col_idx;
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
if constexpr (!Is_even_MN) {
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
Expand All @@ -178,10 +178,10 @@ struct Mask {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // alibi_slope * col_idx;
tensor(make_coord(i, mi), make_coord(j, nj)) += ((col_idx == (col_idx_limit_right - 1)) ? 0 : alibi_slope);

} else {
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);

}
}
Expand Down

0 comments on commit b9ec215

Please sign in to comment.