Skip to content

Commit

Permalink
try disabling Col_idx_only if alibi
Browse files Browse the repository at this point in the history
  • Loading branch information
timt51 committed Sep 7, 2024
1 parent 5996687 commit 3b267ff
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/mask.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct Mask {
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
// Do we need both row and column indices, or just column incides?
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
static constexpr bool Col_idx_only = !Has_alibi && !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Col_idx_only) {
Expand Down

0 comments on commit 3b267ff

Please sign in to comment.