diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 13c2688af..0f1f3763a 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -139,7 +139,7 @@ struct Mask { // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); // Do we need both row and column indices, or just column incides? - static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + static constexpr bool Col_idx_only = !Has_alibi && !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; if constexpr (Col_idx_only) {