Skip to content

Commit 14b704d

Browse files
committed
fix diagonal bias
1 parent cfca2e8 commit 14b704d

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

csrc/flash_attn/src/mask.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ struct Mask {
153153
for (int mi = 0; mi < size<0>(tensor); ++mi) {
154154
// No causal, no local
155155
if constexpr (Has_alibi) {
156-
tensor(mi, make_coord(j, nj)) += alibi_slope; // * col_idx;
156+
tensor(mi, make_coord(j, nj)) += (col_idx == 0 ? 0 : alibi_slope); // alibi_slope * col_idx;
157157
}
158158
if constexpr (!Is_even_MN) {
159159
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
@@ -178,9 +178,10 @@ struct Mask {
178178
const int col_idx = col_idx_base + j;
179179
if constexpr (Has_alibi) {
180180
if constexpr (Is_causal) {
181-
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope; // * col_idx;
181+
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // alibi_slope * col_idx;
182+
182183
} else {
183-
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope; // * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
184+
tensor(make_coord(i, mi), make_coord(j, nj)) += (col_idx == row_idx ? 0 : alibi_slope); // -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
184185

185186
}
186187
}

0 commit comments

Comments
 (0)