Update flash_attention kernel from 2.3.6 to 2.5.5 (pytorch#118935)
# Summary
Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5).

The usual changes were then re-applied on top of the updated kernel: changing how the dropout state is saved for backward, and removing the head_dim padding, since padding would make the kernel mutate its input in place, which interacts badly with functionalization.
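Not part of the diff, just a hedged sketch of the user-facing path this update exercises: scaled_dot_product_attention routed to the flash backend with dropout enabled, which is where the reworked dropout-state handling and backward matter. The shapes and dropout probability are arbitrary choices for illustration.

```python
# Minimal smoke-test sketch (assumes a CUDA build with flash attention enabled).
import torch
import torch.nn.functional as F

# (batch, num_heads, seq_len, head_dim) in fp16, as the flash kernel expects.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16,
                       requires_grad=True) for _ in range(3))

# Restrict SDPA to the flash backend so the updated kernel is the one exercised.
with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                    enable_math=False,
                                    enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.1, is_causal=True)

out.sum().backward()  # runs the flash backward, which consumes the saved dropout state
```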

Pull Request resolved: pytorch#118935
Approved by: https://github.com/cpuhrsch
drisspg authored and pytorchmergebot committed Mar 4, 2024
1 parent d49864f commit 2e6c08a
Showing 42 changed files with 2,112 additions and 2,301 deletions.
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -744,6 +744,12 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF)

# We are currently not using alibi attention for Flash,
# so we disable this feature by default.
# We don't currently document this feature because we don't
# suspect users building from source will need it.
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)

# CAVEAT: Again, do not check USE_ROCM here
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
cmake_dependent_option(
5 changes: 3 additions & 2 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -50,7 +50,6 @@
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/narrow_native.h>
#include <ATen/ops/zeros.h>
#endif

@@ -65,7 +64,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

@@ -852,6 +850,7 @@ _flash_attention_forward(
// of the tensor. This is useful for kv cache scenarios but for now
// we will not support it in this PR.
c10::optional<Tensor> seqused_k = c10::nullopt;
c10::optional<Tensor> alibi_slopes = c10::nullopt;

// We are going to have two paths:
// 1. The standard MHA path for dense tensors
@@ -880,6 +879,7 @@
cumulative_sequence_length_q.value(),
cumulative_sequence_length_k.value(),
seqused_k, /*seqused_k*/
alibi_slopes, /*alibi_slopes*/
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -905,6 +905,7 @@
key,
value,
out,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
26 changes: 22 additions & 4 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -1,3 +1,4 @@
#include <string_view>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint>
#include <type_traits>
@@ -41,9 +42,8 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#endif
namespace at {

namespace native {
namespace at::native {

std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
const Tensor& grad_out,
@@ -74,6 +74,21 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// The kernel computes it regardless; we will drop it for this function's return
Tensor grad_softmax;

// Currently unused args:
c10::optional<at::Tensor> alibi_slopes{c10::nullopt};

bool deterministic{false};
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"Flash Attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
} else {
deterministic = true;
}
}

// We check whether cumulative_sequence_length_q is defined
// in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) {
@@ -90,6 +105,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dv,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
alibi_slopes,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -98,6 +114,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(dQuery, dKey, dValue);
@@ -113,11 +130,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
deterministic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
@@ -630,5 +649,4 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
}

} // namespace native
} // namespace at
} // namespace at::native
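The determinism handling added above is driven by the global deterministic-algorithms flag. A hedged sketch of the two user-visible modes it distinguishes, using the public Python API (the kernel-level numerics themselves are out of scope here):

```python
import torch

# Warn-only: the flash backward keeps its default non-deterministic algorithm
# and the TORCH_WARN_ONCE above fires instead.
torch.use_deterministic_algorithms(True, warn_only=True)

# Strict: the backward sets the `deterministic` flag that is threaded through
# to the kernel in the code above.
torch.use_deterministic_algorithms(True, warn_only=False)
```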
74 changes: 74 additions & 0 deletions aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h
@@ -0,0 +1,74 @@
#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal>
struct Alibi {

const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;

__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
};


template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}

};

} // namespace pytorch_flash
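For readability, a hedged PyTorch reference of the bias that Alibi::apply_alibi applies per attention head; `alibi_bias` and its arguments are illustrative names, not kernel API. In the causal branch the kernel adds `slope * col_idx`, which differs from the usual `-slope * distance` form only by a per-row constant that the softmax cancels.

```python
import torch

def alibi_bias(slope: float, max_seqlen_q: int, max_seqlen_k: int, is_causal: bool):
    row_idx = torch.arange(max_seqlen_q).unsqueeze(1)  # query positions
    col_idx = torch.arange(max_seqlen_k).unsqueeze(0)  # key positions
    if is_causal:
        # Same bias vector added to every row; softmax-equivalent to -slope * distance.
        return (slope * col_idx).expand(max_seqlen_q, max_seqlen_k).to(torch.float32)
    # Bias depends on the right-aligned distance between row and column.
    return -slope * (row_idx + max_seqlen_k - max_seqlen_q - col_idx).abs().to(torch.float32)
```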
aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h
@@ -24,12 +24,12 @@ struct BlockInfo {
}

template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

96 changes: 96 additions & 0 deletions aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h
@@ -0,0 +1,96 @@
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

struct Dropout {

const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;

__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}

template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}

};

} // namespace pytorch_flash
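A hedged reference for the thresholding Dropout::apply_dropout performs: each element is kept iff its 8-bit Philox draw is at most the 8-bit keep threshold, and dropped elements are either zeroed or sign-flipped when the mask is encoded in the sign bit. `apply_dropout_ref` is an illustrative name, and torch.randint stands in for the in-kernel Philox stream.

```python
import torch

def apply_dropout_ref(x: torch.Tensor, keep_threshold_uint8: int,
                      encode_in_sign_bit: bool = False) -> torch.Tensor:
    # Draws in [0, 255], mirroring the 8-bit random bytes used by the kernel.
    rnd = torch.randint(0, 256, x.shape, dtype=torch.int16, device=x.device)
    keep = rnd <= keep_threshold_uint8
    if encode_in_sign_bit:
        # Dropped elements are negated rather than zeroed, so the mask can be
        # recovered from the sign later (mirrors encode_dropout_in_sign_bit).
        return torch.where(keep, x, -x)
    return torch.where(keep, x, torch.zeros_like(x))
```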
30 changes: 21 additions & 9 deletions aten/src/ATen/native/transformers/cuda/flash_attn/flash.h
@@ -5,21 +5,23 @@
#pragma once

#include <cuda.h>
#include <vector>

#include <ATen/cuda/PhiloxUtils.cuh>

namespace pytorch_flash{

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
@@ -96,7 +98,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;

// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -126,6 +133,9 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;

int num_splits; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -165,14 +175,16 @@ struct Flash_bwd_params : public Flash_fwd_params {

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

} // namespace pytorch_flash
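The new Flash_fwd_params fields (block_table, block_table_batch_stride, page_block_size) describe a paged KV cache. A hedged, kernel-agnostic sketch of the mapping they encode, for a single sequence; the cache layout and the helper name are assumptions for illustration only.

```python
import torch

def gather_paged_kv(kv_cache: torch.Tensor, block_table: torch.Tensor,
                    seqlen_k: int, page_block_size: int) -> torch.Tensor:
    # kv_cache:    [num_physical_pages, page_block_size, num_heads, head_dim]
    # block_table: [num_logical_pages] -> physical page index for this sequence
    num_pages = (seqlen_k + page_block_size - 1) // page_block_size
    pages = [kv_cache[block_table[p]] for p in range(num_pages)]
    # Concatenate the logical pages and trim to the true sequence length.
    return torch.cat(pages, dim=0)[:seqlen_k]  # [seqlen_k, num_heads, head_dim]
```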
