From 2e6c08a14be88e663159f0d0c206e6621d00983d Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 4 Mar 2024 17:36:22 +0000 Subject: [PATCH] Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935) # Summary Updates FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5). The usual changes were then re-rolled on top of the updated kernel: changing how the dropout state is saved for backward, and removing the head_dim_pad, since padding would make the kernel mutate its input in place, which interacts badly with functionalization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935 Approved by: https://github.com/cpuhrsch --- CMakeLists.txt | 6 + .../native/transformers/cuda/attention.cu | 5 +- .../transformers/cuda/attention_backward.cu | 26 +- .../transformers/cuda/flash_attn/alibi.h | 74 ++ .../transformers/cuda/flash_attn/block_info.h | 4 +- .../transformers/cuda/flash_attn/dropout.h | 96 ++ .../transformers/cuda/flash_attn/flash.h | 30 +- .../cuda/flash_attn/flash_api.cpp | 400 +++++--- .../transformers/cuda/flash_attn/flash_api.h | 18 +- .../cuda/flash_attn/flash_bwd_kernel.h | 877 ++---------------- .../flash_attn/flash_bwd_launch_template.h | 314 +++---- .../flash_attn/flash_bwd_preprocess_kernel.h | 377 ++++++++ .../cuda/flash_attn/flash_fwd_kernel.h | 417 ++++----- .../flash_attn/flash_fwd_launch_template.h | 132 +-- .../cuda/flash_attn/kernel_traits.h | 107 +-- .../cuda/flash_attn/kernel_traits_sm90.h | 161 ---- .../kernels/flash_bwd_hdim128_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim128_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim160_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim160_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim192_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim192_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim224_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim224_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim256_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim256_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim32_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim32_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim64_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim64_fp16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim96_bf16_sm80.cu | 4 +- .../kernels/flash_bwd_hdim96_fp16_sm80.cu | 4 +- .../flash_attn/kernels/generate_kernels.py | 4 +- .../transformers/cuda/flash_attn/mask.h | 213 +++++ .../transformers/cuda/flash_attn/philox.cuh | 120 +-- .../transformers/cuda/flash_attn/rotary.h | 152 +++ .../transformers/cuda/flash_attn/softmax.h | 259 ++---- .../cuda/flash_attn/static_switch.h | 43 +- .../transformers/cuda/flash_attn/utils.h | 220 +---- .../native/transformers/cuda/sdp_utils.cpp | 18 +- .../transformers/hip/flash_attn/flash_api.hip | 18 +- test/test_transformers.py | 258 ++++-- 42 files changed, 2112 insertions(+), 2301 deletions(-) create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h delete mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/mask.h create mode 100644 aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h diff --git a/CMakeLists.txt b/CMakeLists.txt index b0679d28da91eb..9b8e683be8264f 100644 --- a/CMakeLists.txt +++ 
b/CMakeLists.txt @@ -744,6 +744,12 @@ cmake_dependent_option( Will be disabled if not supported by the platform" ON "USE_CUDA AND NOT MSVC" OFF) +# We are currently not using alibi attention for Flash, +# so we disable this feature by default. +# We don't currently document this feature because we don't +# expect users building from source to need it. +add_definitions(-DFLASHATTENTION_DISABLE_ALIBI) + # CAVEAT: Again, do not check USE_ROCM here # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't cmake_dependent_option( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 900defaa763660..55de97ad223a78 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -50,7 +50,6 @@ #include #include #include -#include #include #endif @@ -65,7 +64,6 @@ #include #include #include -#include #include #include @@ -852,6 +850,7 @@ _flash_attention_forward( // of the tensor. This is useful for kv cache scenarios but for now // we will not support in this PR. c10::optional seqused_k = c10::nullopt; + c10::optional alibi_slopes = c10::nullopt; // We are going to have two paths: // 1. The standard MHA path for dense tensors @@ -880,6 +879,7 @@ _flash_attention_forward( cumulative_sequence_length_q.value(), cumulative_sequence_length_k.value(), seqused_k, /*seqused_k*/ + alibi_slopes, /*alibi_slopes*/ max_seqlen_batch_q, max_seqlen_batch_k, dropout_p, @@ -905,6 +905,7 @@ _flash_attention_forward( key, value, out, + alibi_slopes, dropout_p, softmax_scale, is_causal, diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index cf8d543f12122c..c829f45a6f4a3f 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -1,3 +1,4 @@ +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -41,9 +42,8 @@ #include #include #endif -namespace at { -namespace native { +namespace at::native { std::tuple _flash_attention_backward( const Tensor& grad_out, @@ -74,6 +74,21 @@ std::tuple _flash_attention_backward( // The kernel computes irregardless we will drop for this functions return Tensor grad_softmax; + // Currently unused args: + c10::optional alibi_slopes{c10::nullopt}; + + bool determinisitic{false}; + auto& ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "Flash Attention defaults to a non-deterministic algorithm. 
", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } else { + determinisitic = true; + } + } + // We check the whether the cumulative_sequence_length_q is defined // in order to determine whether we are using varlen or dense forward if (cumulative_sequence_length_q.defined()) { @@ -90,6 +105,7 @@ std::tuple _flash_attention_backward( dv, cumulative_sequence_length_q, cumulative_sequence_length_k, + alibi_slopes, max_seqlen_batch_q, max_seqlen_batch_k, dropout_p, @@ -98,6 +114,7 @@ std::tuple _flash_attention_backward( is_causal, -1, /*window_size_left*/ -1, /*window_size_right*/ + determinisitic, philox_seed, philox_offset); return std::make_tuple(dQuery, dKey, dValue); @@ -113,11 +130,13 @@ std::tuple _flash_attention_backward( dq, dk, dv, + alibi_slopes, dropout_p, softmax_scale, is_causal, -1, /*window_size_left*/ -1, /*window_size_right*/ + determinisitic, philox_seed, philox_offset); return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue)); @@ -630,5 +649,4 @@ std::tuple _scaled_dot_product_e grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias); } -} // namespace native -} // namespace at +} // namespace at::native diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h b/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h new file mode 100644 index 00000000000000..311231432c7cfe --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h @@ -0,0 +1,74 @@ +#include + +#include + +#include +#include + +#include + +namespace pytorch_flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Alibi { + + const float alibi_slope; + const int max_seqlen_k, max_seqlen_q; + + __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) + : alibi_slope(alibi_slope) + , max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) { + }; + + + template + __forceinline__ __device__ void apply_alibi(Tensor &tensor, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } + } + 
+}; + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h b/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h index 3e05d7e7195e8c..bbaf6978002177 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h @@ -24,12 +24,12 @@ struct BlockInfo { } template - inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; } template - inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; } diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h b/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h new file mode 100644 index 00000000000000..8dc4b0b22bcc9d --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h @@ -0,0 +1,96 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +namespace pytorch_flash { + +using namespace cute; + +struct Dropout { + + const unsigned long long seed, offset; + const uint8_t p_dropout_in_uint8_t; + + __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, + const uint8_t p_dropout_in_uint8_t, + const int bid, const int hid, const int tid, const int nheads) + : seed(seed) + , offset(offset + (bid * nheads + hid) * 32 + tid % 32) + , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { + } + + template + __forceinline__ __device__ void apply_dropout(Tensor &tensor_, + int block_row_start, int block_col_start, int block_row_stride) { + // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) + Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout())); + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? 
-val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. + if (!encode_dropout_in_sign_bit + && (std::is_same::value || std::is_same::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } + } + +}; + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h index 23fa6584b9b564..9ce14cf6489ef2 100644 --- 
a/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash.h @@ -5,13 +5,15 @@ #pragma once #include -#include - -#include - -namespace pytorch_flash{ +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif +#include // For at::cuda::philox::unpack +namespace pytorch_flash { constexpr int TOTAL_DIM = 0; constexpr int H_DIM = 1; constexpr int D_DIM = 2; @@ -19,7 +21,7 @@ constexpr int D_DIM = 2; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Qkv_params { - using index_t = uint32_t; + using index_t = int64_t; // The QKV matrices. void *__restrict__ q_ptr; void *__restrict__ k_ptr; @@ -96,7 +98,12 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ rotary_sin_ptr; // The indices to index into the KV cache. - int *__restrict__ cache_batch_idx; + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; // The dropout probability (probability of keeping an activation). float p_dropout; @@ -126,6 +133,9 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved; int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -165,6 +175,9 @@ struct Flash_bwd_params : public Flash_fwd_params { // The pointer to the softmax d sum. void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -172,7 +185,6 @@ struct Flash_bwd_params : public Flash_fwd_params { template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure); - +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 07c9f7e547facd..8f6f7a9f357dc9 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -1,29 +1,5 @@ /****************************************************************************** - * Copyright (c) 2022, Tri Dao. - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS @@ -50,6 +26,7 @@ #include #include #include +#include #endif @@ -93,11 +70,11 @@ void set_params_fprop(Flash_fwd_params ¶ms, float p_dropout, float softmax_scale, int window_size_left, - int window_size_right) { + int window_size_right, + bool seqlenq_ngroups_swapped=false) { - // Reset the parameters should be equivalent + // Reset the parameters params = {}; - // memset(¶ms, 0, sizeof(params)); params.is_bf16 = q.dtype() == at::kBFloat16; @@ -121,6 +98,10 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.k_batch_stride = k.stride(0); params.v_batch_stride = v.stride(0); params.o_batch_stride = out.stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } } params.cu_seqlens_q = static_cast(cu_seqlens_q_d); @@ -159,6 +140,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.rp_dropout = 1.f / params.p_dropout; params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
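For readers new to the sliding-window parameters referenced in the comment above (causal is window_size_right == 0 with an unbounded left window): the sketch below is my reading of the 2.5.x local-attention semantics, with illustrative names only, where a negative window size means "unbounded" on that side.

```cpp
// Illustrative sketch (not kernel code) of which keys a query may attend to
// under the local-attention window, using the bottom-right alignment that
// FlashAttention applies when seqlen_q != seqlen_k.
bool key_is_visible(int q_idx, int k_idx, int seqlen_q, int seqlen_k,
                    int window_size_left, int window_size_right) {
  const int shift = seqlen_k - seqlen_q;  // align the last query with the last key
  if (window_size_left >= 0 && k_idx < q_idx + shift - window_size_left) return false;
  if (window_size_right >= 0 && k_idx > q_idx + shift + window_size_right) return false;
  return true;
}
```

This is also why the flash_api.cpp hunks later in the patch reset window_size_left / window_size_right to -1 whenever the window already covers the full key length.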
@@ -169,7 +153,16 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.window_size_left = window_size_left; params.window_size_right = window_size_right; + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); + #endif + params.is_seqlens_k_cumulative = true; + + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -202,7 +195,8 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float p_dropout, float softmax_scale, int window_size_left, - int window_size_right) { + int window_size_right, + bool deterministic) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -244,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params ¶ms, // Softmax sum params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; } void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { - FWD_HEADDIM_SWITCH(params.d, [&] { + HEADDIM_SWITCH(params.d, [&] { if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 run_mha_fwd_(params, stream); } else { @@ -300,16 +296,62 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n return 1; } +void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, + const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q, + const int head_size_rounded, const float p_dropout, + const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) { + + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. 
+ const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + params.num_splits = num_splits; + if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout + if (num_splits < 1) { + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + } + if (params.num_splits > 1) { + at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + } +} + +void set_params_alibi(Flash_fwd_params ¶ms, c10::optional &alibi_slopes_, int batch_size, int num_heads){ +#ifdef FLASHATTENTION_DISABLE_ALIBI + TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi."); + params.alibi_slopes_ptr = nullptr; +#else + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32"); + CHECK_DEVICE(alibi_slopes); + TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension"); + TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads})); + params.alibi_slopes_ptr = alibi_slopes.data_ptr(); + params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0; + } else { + params.alibi_slopes_ptr = nullptr; + } +#endif +} + // return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; std::tuple mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -350,12 +392,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && 
head_size_og % 8 == 0 && !alibi_slopes_.has_value(); at::Tensor temp_q = q; if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; @@ -369,9 +415,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); at::Tensor q_padded, k_padded, v_padded; - q_padded = temp_q; - k_padded = k; - v_padded = v; + q_padded = temp_q; + k_padded = k; + v_padded = v; at::Tensor out; if (out_.has_value()) { @@ -423,30 +469,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head window_size_left, window_size_right); - // This needs to match with run_mha_fwd_splitkv_dispatch - const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); - const int num_n_blocks = (seqlen_k + block_n - 1) / block_n; - // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. - // In any case we don't expect seqlen_q to be larger than 64 for inference. - const int num_m_blocks = (seqlen_q + 64 - 1) / 64; - params.num_splits = 1; - if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); - if (params.num_splits > 1) { - at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); - } + + set_params_splitkv(params, batch_size, num_heads, + head_size, seqlen_k, seqlen_q, + head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function - auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); at::Tensor seed_t, offset_t; if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. @@ -476,6 +509,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head } + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); @@ -501,18 +536,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- const int max_seqlen_q, + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - const bool is_causal, - const int window_size_left, + bool is_causal, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { - if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; @@ -544,17 +579,39 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const auto sizes = q.sizes(); - const int total_q = sizes[0]; const int batch_size = cu_seqlens_q.numel() - 1; - const int num_heads = sizes[1]; + int num_heads = sizes[1]; const int head_size_og = sizes[2]; const int total_k = k.size(0); const int num_heads_k = k.size(1); + + if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case + if (is_causal) { window_size_right = 0; } + + void *cu_seqlens_q_d = cu_seqlens_q.data_ptr(); + + // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case + // H/t Daniel Haziza + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + at::Tensor temp_q = q; + if (seqlenq_ngroups_swapped) { + const int ngroups = num_heads / num_heads_k; + temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); + max_seqlen_q = ngroups; + num_heads = num_heads_k; + cu_seqlens_q_d = nullptr; + } + + const int total_q = q.sizes()[0]; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!") + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + CHECK_SHAPE(q, total_q, num_heads, head_size_og); CHECK_SHAPE(k, total_k, num_heads_k, head_size_og); CHECK_SHAPE(v, total_k, num_heads_k, head_size_og); @@ -569,7 +626,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } at::Tensor q_padded, k_padded, v_padded; - q_padded = q; + q_padded = temp_q; k_padded = k; v_padded = v; @@ -619,7 +676,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q num_heads, num_heads_k, head_size, head_size_rounded, q_padded, k_padded, v_padded, out, - cu_seqlens_q.data_ptr(), + cu_seqlens_q_d, cu_seqlens_k.data_ptr(), seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, return_softmax ? 
p.data_ptr() : nullptr, @@ -627,9 +684,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + seqlenq_ngroups_swapped); + if (seqlenq_ngroups_swapped) { + // Only apply split-k for decoding + set_params_splitkv(params, batch_size, num_heads, + head_size, max_seqlen_k, max_seqlen_q, + head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); + } - // We want to checkpoint and save the RNG state for backward if dropout + // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function auto gen = at::get_generator_or_default(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); @@ -664,31 +728,33 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q } - auto stream = at::cuda::getCurrentCUDAStream().stream(); - run_mha_fwd(params, stream); + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + if (seqlenq_ngroups_swapped) { + std::array size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og}; + std::array size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og}; + out = out.reshape(size_before).transpose(1, 2).reshape(size_after); + q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); + softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1}); + } return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; } -void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { FP16_SWITCH(!params.is_bf16, [&] { - if (params.d <= 32) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 64) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 96) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 128) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 160) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 192) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 224) { - run_mha_bwd_(params, stream, configure); - } else if (params.d <= 256) { - run_mha_bwd_(params, stream, configure); - } + HEADDIM_SWITCH(params.d, [&] { + run_mha_bwd_(params, stream); + }); }); } @@ -702,14 +768,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif if (is_causal) { 
window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -756,8 +827,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192) { - TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); + if (head_size > 192 && (head_size <= 224 || is_dropout)) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -768,6 +839,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); @@ -803,8 +877,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si dv = at::empty_like(v); } - // const at::Tensor& dout_padded = dout; - // bool loop = seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; @@ -818,9 +890,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si at::Tensor dq_accum; at::Tensor dk_accum, dv_accum; if (loop) { - dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); - // dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); - // dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + if (!deterministic) { + dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } + // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; @@ -854,10 +931,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + deterministic); + params.dq_accum_split_stride = !deterministic ? 
0 : dq_accum.stride(0); auto launch = &run_mha_bwd; - // launch(params, stream, /*configure=*/true); at::PhiloxCudaState philox_args; if (is_dropout) { @@ -872,12 +950,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si } params.philox_args = philox_args; + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + if (seqlen_q > 0) { - launch(params, stream, /*configure=*/false); + launch(params, stream); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. - dk.zero_(); - dv.zero_(); + dk_expanded.zero_(); + dv_expanded.zero_(); softmax_d.zero_(); } @@ -901,17 +981,24 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + if (is_causal) { window_size_right = 0; } auto dprops = at::cuda::getCurrentDeviceProperties(); // bool is_sm75 = dprops->major == 7 && dprops->minor == 5; @@ -925,7 +1012,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size auto stream = at::cuda::getCurrentCUDAStream().stream(); auto q_dtype = q.dtype(); - TORCH_CHECK(q_dtype == at::kHalf|| q_dtype == at::kBFloat16, + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, "FlashAttention only support fp16 and bf16 data type"); if (q_dtype == at::kBFloat16) { TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer"); @@ -962,8 +1049,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!"); TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); - if (head_size > 192) { - TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800"); + if (head_size > 192 && (head_size <= 224 || is_dropout)) { + TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800"); } TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -974,6 +1061,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + if (window_size_left >= max_seqlen_k) { window_size_left = -1; } + if (window_size_right >= max_seqlen_k) { window_size_right = -1; } + CHECK_SHAPE(q, total_q, num_heads, head_size); CHECK_SHAPE(k, total_k, num_heads_k, head_size); CHECK_SHAPE(v, total_k, num_heads_k, head_size); @@ -1008,11 +1098,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x 
head_size TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, total_k, num_heads_k, head_size); } else { - dv = at::empty_like(k); + dv = at::empty_like(v); } - // const at::Tensor& dout_padded = dout; - // bool loop = max_seqlen_k > blocksize_c; // TODO: change later, for now set to true for simplicity bool loop = true; @@ -1033,7 +1121,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally // allowed to do. So we won't have to do any bound checking, and performance should stay the same. - dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + if (!deterministic) { + dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } else { + const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads); + dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + } } at::Tensor dk_expanded, dv_expanded; @@ -1072,10 +1165,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + deterministic); + params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); auto launch = &run_mha_bwd; - // launch(params, stream, /*configure=*/true); at::PhiloxCudaState philox_args; if (is_dropout) { @@ -1090,7 +1184,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size } params.philox_args = philox_args; - launch(params, stream, /*configure=*/false); + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); + + if (max_seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } // For MQA/GQA we need to sum dK and dV across the groups if (num_heads_k != num_heads) { @@ -1103,18 +1206,20 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size std::tuple mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size - const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size - const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size + const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. + const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table. 
c10::optional &k_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &v_, // batch_size x seqlen_knew x num_heads_k x head_size c10::optional &seqlens_k_, // batch_size c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional &cache_batch_idx_, // indices to index into the KV cache + c10::optional &block_table_, // batch_size x max_num_blocks_per_seq + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size const float softmax_scale, bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits @@ -1143,25 +1248,41 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + at::Tensor block_table; + const bool paged_KV = block_table_.has_value(); + if (paged_KV) { + TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx"); + block_table = block_table_.value(); + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + } + const auto sizes = q.sizes(); const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; const int head_size_og = sizes[3]; - const int seqlen_k = kcache.size(1); + + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); + const int num_blocks = !paged_KV ? 0 : kcache.size(0); + const int page_block_size = !paged_KV ? 1 : kcache.size(1); + TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256"); + const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); - const int batch_size_c = kcache.size(0); - TORCH_CHECK(batch_size > 0, "batch size must be positive"); + const int batch_size_c = !paged_KV ? 
kcache.size(0) : batch_size; + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case + // causal=true is the same as causal=false in this case + if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } if (is_causal) { window_size_right = 0; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); if (seqlenq_ngroups_swapped) { const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); @@ -1169,9 +1290,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he num_heads = num_heads_k; } + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); - CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); - CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + if (!paged_KV) { + CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og); + } else { + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og); + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + } at::Tensor q_padded, kcache_padded, vcache_padded; if (head_size_og % 8 != 0) {
- const int num_m_blocks = (seqlen_q + 64 - 1) / 64; - params.num_splits = num_splits; - if (num_splits < 1) { - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); - } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); - if (params.num_splits > 1) { - at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); - at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); + + set_params_splitkv(params, batch_size, num_heads, + head_size, seqlen_k, seqlen_q, + head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts); + + if (paged_KV) { + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); } + params.page_block_size = page_block_size; + + + set_params_alibi(params, alibi_slopes_, batch_size, num_heads); auto stream = at::cuda::getCurrentCUDAStream().stream(); - // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx - run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value()); + // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx, + // or paged KV cache + run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV); if (head_size_og % 8 != 0) { // out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)}); @@ -1352,6 +1479,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he } return {out, softmax_lse}; } + } // namespace pytorch_fmha #endif diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h index fd15d929e300be..2745b28dca29b8 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -12,10 +12,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_); @@ -28,13 +29,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. 
- const int max_seqlen_q, + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - const bool is_causal, - const int window_size_left, + bool is_causal, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_); @@ -50,11 +52,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); @@ -70,14 +74,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h index 9f2dc5ac388d1f..db817a0657ffcb 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h @@ -1,24 +1,23 @@ /*************************************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ #pragma once #include - #include -#include #include #include #include -#include #include #include #include #include -#include +#include +#include +#include namespace pytorch_flash { @@ -66,7 +65,8 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value; // Divide by 2 because right now we always use 2 for the ValLayout - constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; + constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; + constexpr int MMAStride_N = MMA_N * AtomShape_N * 2; auto t = make_tile(make_layout(Int{}), Layout, Int, _2>, // (8, 2, 2) or (8, 4, 2) Stride<_1, Int, _8> >{}); // (1, 64, 8) or (1, 32, 8) @@ -76,359 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, - Tensor &dP_sum, const int gdP_col_stride, const float scale) { - static_assert(Layout0::rank == 3, "Only support 3D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); - // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) - // The last coordinate is the "page". - Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), - make_layout(get<0>(do_.layout()), - get<2>(do_.layout())))); - Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); - Tensor do_fp32 = pytorch_flash::convert_type(do_reshaped); - Tensor o_fp32 = pytorch_flash::convert_type(o_reshaped); - #pragma unroll - for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { - float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); - #pragma unroll - for (int ni = 1; ni < size<1>(do_reshaped); ni++) { - dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); - } - pytorch_flash::SumOp sum_op; - dP_sum_cur = pytorch_flash::Allreduce::run(dP_sum_cur, sum_op) * scale; - if (threadIdx.x % THREADS_PER_ROW == 0) { - dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void compute_dot_do_o(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. 
- const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; - - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), - Shape>{}, Stride<_1>{}); - - typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; - auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); - // TODO: careful, we're zeroing out dQaccum with type float4, but when - // we do atomicAdds, we use type float. The layouts are different. Check this. - typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); - Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); - - Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); - - // Allocate predicate tensors for k - Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); - // Set predicates for k bounds - #pragma unroll - for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} - - Tensor tdOrdO = make_fragment_like(tdOgdO); - Tensor tdOrO = make_fragment_like(tdOgO); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM - ); - // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final - // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, - // so that (dP - dP_sum) is on the same scale. - dot_do_o(tdOrdO, tdOrO, dP_sum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); - if (Clear_dQaccum) { - // We're actually not zero'ing out all of dQaccum, but only the part that we're going to - // do atomicAdds on. 
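The dot(dO, O) preprocessing removed here moves to flash_bwd_preprocess_kernel.h (added by this patch). A host-side reference of what it produces, under the same scaling convention described in the comment above, is:

// Reference for dP_sum: dP_sum[i] = sum_d dO[i][d] * O[i][d], kept scaled down by the dropout
// factor so that (dP - dP_sum) stays on one scale; dQ and dK are rescaled by 1/p_dropout only
// at the end. Plain C++ sketch, not the warp-level kernel above.
void dot_do_o_reference(const float* dO, const float* O, float* dP_sum,
                        int rows, int head_dim, float p_dropout_scale) {
    for (int i = 0; i < rows; ++i) {
        float acc = 0.f;
        for (int d = 0; d < head_dim; ++d)
            acc += dO[i * head_dim + d] * O[i * head_dim + d];
        dP_sum[i] = acc * p_dropout_scale;   // the kernel passes params.p_dropout as this scale
    }
}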
- Tensor zero = make_fragment_like(tdQgdQaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void clear_dKVaccum(const Params ¶ms) { - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; - - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, Stride, _1>{}); - - typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); - Tensor zero = make_fragment_like(tdKgdKaccum); - clear(zero); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dQ from dQaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_k. -template -inline __device__ void convert_dQ(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q) return; - - const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) - + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), - Shape, Int>{}, - make_stride(params.h * params.d_rounded, _1{})); - - Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdQ{}); - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; - auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadQ tiled_mma_dq; - auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); - - Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); - cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { - acc_dq(i) = tdQrdQaccum(i) * params.scale_softmax_rp_dropout; - } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = pytorch_flash::convert_type(acc_dq); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - __syncthreads(); - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); - #pragma unroll - for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. -// This is used in the case where we want to parallelize the backward across seqlen_q. -template -inline __device__ void convert_dKV(const Params ¶ms) { - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - const int n_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - // The thread index. 
- const int tidx = threadIdx.x; - - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - - const BlockInfo binfo(params, bidb); - if (n_block * kBlockN >= binfo.actual_seqlen_k) return; - - const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) - + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; - const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) - + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; - const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded - + n_block * kBlockN) * params.d_rounded; - - Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), - Shape, Int>{}, - make_stride(params.dk_row_stride, _1{})); - Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), - Shape, Int>{}, - make_stride(params.dv_row_stride, _1{})); - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - - Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutdKV{}); - Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; - auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - - typename Kernel_traits::TiledMmadKV tiled_mma_dkv; - auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); - auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); - Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); - Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); - - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); - CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); - - Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); - Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); - cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); - #pragma unroll - for (int i = 0; i < size(acc_dk); ++i) { - acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; - } - #pragma unroll - for (int i = 0; i < size(acc_dv); ++i) { - acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; - } - // Convert acc_dk from fp32 to fp16 - Tensor rdK = 
pytorch_flash::convert_type(acc_dk); - Tensor rdV = pytorch_flash::convert_type(acc_dv); - Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); - cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); - __syncthreads(); - Tensor tdKrdK = make_tensor(shape(tdKgdK)); - Tensor tdVrdV = make_tensor(shape(tdVgdV)); - cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); - cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); - - Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); - #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -444,8 +92,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in constexpr int kBlockM = Kernel_traits::kBlockM; constexpr int kBlockN = Kernel_traits::kBlockN; constexpr int kHeadDim = Kernel_traits::kHeadDim; - // constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; + constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; + constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; constexpr bool Double_buffer = !Kernel_traits::No_double_buffer; const BlockInfo binfo(params, bidb); @@ -469,7 +117,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) - + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded + // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + + (!params.deterministic ? 
0 : blockIdx.x * params.dq_accum_split_stride); const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM; const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded @@ -718,7 +368,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in tdKsQt.data() = tdKsQt.data() + size(sQ); } - if (!Is_first && !Seq_parallel) { __syncthreads(); } + if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); } if (Kernel_traits::Is_V_in_regs) { // Clear the smem tiles to account for predicated off loads @@ -756,8 +406,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in #pragma unroll for (int mi = 0; mi < size(lse); ++mi) { const int row = get<0>(taccScS_row(mi)); - lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0; + lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. // Tensor tKrK = make_fragment_like(tKsK); // // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK); @@ -791,11 +445,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in auto seeds = at::cuda::philox::unpack(params.philox_args); unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + unsigned long long offset = std::get<1>(seeds); + pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); clear(acc_dv); clear(acc_dk); + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + pytorch_flash::Alibi alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q); + for (; m_block >= m_block_min; --m_block) { Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) clear(acc_s); @@ -819,6 +478,12 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } + + if (Has_alibi) { + alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, + m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); + } + // TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond // actual_seqlen_k, because acc_s would be some finite value for those indices. // In the end when we multiply with K to get dQ, the corresponding values of K would be 0, @@ -855,28 +520,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } } + // if (cute::thread(32, 0)) { print(scores); } // Compute the exponential value. 
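A note on the ALiBi hook added above: the slope read from alibi_slopes_ptr is divided by scale_softmax because the bias is added to the raw scores, which scale_apply_exp2 later multiplies by scale_softmax; dividing first cancels that out. The sketch below shows one standard causal ALiBi bias on a dense score tile; the real Alibi::apply_alibi works on the per-thread MMA fragment and also handles the seqlen_q/seqlen_k offset, so treat the indexing here as illustrative.

// One common causal ALiBi formulation: penalize each (query, key) pair by slope * distance.
// scores is a [rows][cols] tile; row/col offsets are the tile's position in the full matrix.
void apply_alibi_sketch(float* scores, int rows, int cols,
                        int row_offset, int col_offset, float alibi_slope) {
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const int query_idx = row_offset + r;
            const int key_idx   = col_offset + c;
            if (key_idx <= query_idx) {
                scores[r * cols + c] -= alibi_slope * float(query_idx - key_idx);
            }
        }
    }
}

The scale_apply_exp2 call that follows then exponentiates the biased, scaled scores.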
pytorch_flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - if (Is_dropout) { + if constexpr (Is_dropout) { int warp_id = tidx / 32; int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - Tensor scores_dropped = make_tensor(scores.data(), pytorch_flash::convert_layout_rowcol_Aregs(scores.layout())); - pytorch_flash::apply_dropout( - scores_dropped, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, AtomLayoutMS + dropout.template apply_dropout( + acc_s, block_row_idx, block_col_idx, AtomLayoutMS ); } // Convert scores from fp32 to fp16/bf16 Tensor rP = !Is_dropout - ? pytorch_flash::convert_type(scores) - : pytorch_flash::convert_type_relu(scores); - // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. - Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + ? pytorch_flash::convert_type(acc_s) + : pytorch_flash::convert_type_relu(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2) + // if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8. + Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); // if (cute::thread0()) { print(tPaP); } @@ -889,7 +553,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA clear(acc_dp); - // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), flash::convert_layout_acc_rowcol(acc_dp.layout())); + // Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout())); // #pragma unroll // for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) { // #pragma unroll @@ -953,9 +617,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // Layout p_l = tPrP.layout(); // Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l))); - // flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); + // pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); - // flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); + // pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); // if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); } @@ -1120,430 +784,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { - - using Element = typename Kernel_traits::Element; - using ElementAccum = typename Kernel_traits::ElementAccum; - using index_t = typename 
Kernel_traits::index_t; - - // Shared memory. - extern __shared__ char smem_[]; - - // The thread index. - const int tidx = threadIdx.x; - - constexpr int kBlockM = Kernel_traits::kBlockM; - constexpr int kBlockN = Kernel_traits::kBlockN; - constexpr int kHeadDim = Kernel_traits::kHeadDim; - // constexpr int kNWarps = Kernel_traits::kNWarps; - constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value; - constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP; - - const BlockInfo binfo(params, bidb); - if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; - - int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); - if (Is_causal) { - n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN)); - } - - // We iterate over the blocks in reverse order. This is because the last block is the only one - // that needs masking when we read K and V from global memory. Moreover, iterating in reverse - // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) - + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - // We'll advance gdKaccum and gdVaccum before the first write. 
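Context for the causal clamp in the (now removed) row-block path above: under a causal mask, the query block starting at m_block * kBlockM only ever attends to keys below (m_block + 1) * kBlockM, so the key-block loop can stop early. A scalar restatement of that bound:

// ceil_div and the causal upper bound on key blocks, as used above.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int causal_n_block_max(int seqlen_k, int kBlockN, int m_block, int kBlockM) {
    const int full = ceil_div(seqlen_k, kBlockN);
    const int causal_limit = ceil_div((m_block + 1) * kBlockM, kBlockN);
    return full < causal_limit ? full : causal_limit;
}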
- const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded - + n_block_max * kBlockN) * params.d_rounded; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - - // We assume that params.d == kHeadDim for now - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); - Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), - Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), - Shape, Int>{}, - Stride, _1>{}); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); - - Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), - typename Kernel_traits::SmemLayoutQdO{}); - Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - Tensor sdO = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutQdO{}); - Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{}); - Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(), - typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{}); - Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{}); - // Double buffer for sK - Tensor sV = make_tensor(sK.data() + 2 * size(sK), typename Kernel_traits::SmemLayoutKV{}); - Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{}); - Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{}); - Tensor sdS = make_tensor(sV.data() + size(sV), typename Kernel_traits::SmemLayoutPdS{}); - Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); - Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{}); - Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{}); - Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{}); - Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast(sdS.data().get())), - Shape>{}); - - typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; - auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; - auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; - auto gmem_thr_copy_dKVaccum = 
gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); - - Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); - Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); - Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO); - Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) - Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) - Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); - Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); - - typename Kernel_traits::TiledMmaSdP tiled_mma_sdp; - auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx); - Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K) - Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) - Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K) - Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K) - - typename Kernel_traits::TiledMmadKV tiled_mma_dkv; - auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx); - Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N) - Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N) - Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N) - Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N) - - typename Kernel_traits::TiledMmadQ tiled_mma_dq; - auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx); - Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N) - Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N) - - Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_M_SdP, MMA_K - - // - // Copy Atom retiling - // - - auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); - auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx); - Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ); - Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO); - - auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp); - auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx); - Tensor tSsK = smem_thr_copy_KV.partition_S(sK); - Tensor tdPsV = smem_thr_copy_KV.partition_S(sV); - - // Partition sP and sdS to match the accumulator partitioning - // This has to be tiled_mma_sdp, not tiled_mma_dkv - auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp); - auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx); - Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) - Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); - auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx); - Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt); - Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt); - - auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv); - auto smem_thr_copy_QdOt = 
smem_tiled_copy_QdOt.get_thread_slice(tidx); - Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt); - Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt); - - auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq); - auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx); - Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS); - - auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq); - auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx); - Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt); - - // - // PREDICATES - // - - // Construct identity layout for sQ and sK - Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - // Repeat the partitioning with identity layouts - Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - - // Allocate predicate tensors for k - Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); - - // Set predicates for k bounds - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } - #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } - } - - // Prologue - - Tensor tdOrdO = make_fragment_like(tdOgdO); - Tensor tdOrO = make_fragment_like(tdOgO); - - // TODO: Might need to exit early and write 0 to gdQ. - - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - pytorch_flash::copy( - gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - - Tensor tQrQ = make_fragment_like(tQgQ); - pytorch_flash::copy( - gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM - ); - - int n_block = n_block_max - 1; - if (n_block % 2 == 1) { - tKsK.data() = tKsK.data() + size(sK); - tSsK.data() = tSsK.data() + size(sK); - tdQsKt.data() = tdQsKt.data() + size(sK); - } - - pytorch_flash::copy( - gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - pytorch_flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN - ); - - Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N) - static_assert(decltype(size<0>(taccScS))::value == 4); - // Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices. - Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); - Tensor lse = make_tensor(Shape>{}); - #pragma unroll - for (int mi = 0; mi < size(lse); ++mi) { - const int row = get<0>(taccScS_row(mi)); - lse(mi) = row < binfo.actual_seqlen_q - m_block * kBlockM ? 
gLSE(row) : 0; - } - - cute::cp_async_fence(); - - Tensor dP_sum = make_fragment_like(lse); - cute::copy(tdOrdO, tdOsdO); - dot_do_o( - tdOrdO, tdOrO, sdPsum, - Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout - ); - __syncthreads(); - #pragma unroll - for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<0>(taccScS_row(mi))); } - - auto seeds = at::cuda::philox::unpack(params.philox_args); - unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; - - clear(acc_dq); - - for (; n_block >= 0; --n_block) { - Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_M_SdP, MMA_N) - clear(acc_s); - pytorch_flash::cp_async_wait<0>(); - __syncthreads(); - - pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); - - // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - // We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would - // be some finite value for those indices. In the end when we multiply with K to get dQ, - // the corresponding values of K would be 0, so the result would still be correct. - if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) { - pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, - binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)), - // binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, - AtomLayoutMS * 16); - } - // Compute the exponential value. - pytorch_flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); - if (Is_dropout) { - int warp_id = tidx / 32; - int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS; - // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 - static_assert(MMA_N_SdP % 2 == 0); - int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - Tensor scores_dropped = make_tensor(scores.data(), pytorch_flash::convert_layout_rowcol_Aregs(scores.layout())); - pytorch_flash::apply_dropout( - scores_dropped, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, AtomLayoutMS - ); - } - // Convert scores from fp32 to fp16/bf16 - Tensor rP = !Is_dropout - ? pytorch_flash::convert_type(scores) - : pytorch_flash::convert_type_relu(scores); - // Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8. 
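For reference, the scale_apply_exp2 call in both the old and new backward recomputes the forward probabilities from the raw scores plus the saved log-sum-exp, so P never has to be materialized in the forward. In exact arithmetic P_ij = exp(softmax_scale * S_ij - LSE_i); the kernel uses the base-2 form sketched below (scale_softmax_log2 = softmax_scale * log2(e)). Scalar sketch only; the kernel applies this on register fragments and handles the OOB-row LSE = inf case discussed earlier. These probabilities then feed dV = P^T dO and the pointwise dS = P * (dP - dP_sum) step visible in the surrounding loop.

#include <cmath>

// Recompute softmax probabilities from scores and the saved row-wise log-sum-exp.
void recompute_probs(const float* scores, const float* lse, float* probs,
                     int rows, int cols, float scale_softmax_log2) {
    const float kLog2e = 1.4426950408889634f;   // log2(e)
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            probs[r * cols + c] =
                exp2f(scores[r * cols + c] * scale_softmax_log2 - lse[r] * kLog2e);
}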
- Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); - Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); - - Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape, Int>{}); // (MMA=4, MMA_N, MMA_N) - CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA - CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA - - clear(acc_dp); - pytorch_flash::gemm(acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp, - smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); - - // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) - Tensor dS = make_tensor(acc_dp.data(), scores.layout()); - auto pointwise_mult = [](float p, float dp, float d) { - return p * (!Is_dropout || p >= 0 ? dp - d : d); - }; - #pragma unroll - for (int mi = 0; mi < size<0>(dS); ++mi) { - #pragma unroll - for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); - } - } - - Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout()); - // Convert dS from fp32 to fp16 - Tensor tdSrdS = pytorch_flash::convert_type(dS_reshaped); - Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N) - cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); - __syncthreads(); - - if (n_block > 0) { - // Double buffer for sK - const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); - tKsK.data() = tKsK.data() + sK_offset; - tSsK.data() = tSsK.data() + sK_offset; - // Advance gK, gV - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); - pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); - // This cp_async_fence needs to be in the if block, otherwise the synchronization - // isn't right and we get race conditions. - cute::cp_async_fence(); - } - - Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - clear(acc_dv); - pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); - // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(acc_dv); } - tdVgdVaccum.data() = tdVgdVaccum.data() + (-int(kBlockN * params.d_rounded)); - #pragma unroll - for (int i = 0; i < size(acc_dv); ++i) { atomicAdd(&tdVgdVaccum(i), acc_dv(i)); } - - __syncthreads(); - Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K - clear(acc_dk); - pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv, - smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt); - tdKgdKaccum.data() = tdKgdKaccum.data() + (-int(kBlockN * params.d_rounded)); - #pragma unroll - for (int i = 0; i < size(acc_dk); ++i) { atomicAdd(&tdKgdKaccum(i), acc_dk(i)); } - - pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq, - smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt); - // Double buffer for sK - tdQsKt.data() = tdQsKt.data() + (n_block % 2 == 0 ? 
size(sK) : -size(sK)); - - } - - // Epilogue - - #pragma unroll - for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } - // Convert acc_dq from fp32 to fp16 - Tensor rdQ = pytorch_flash::convert_type(acc_dq); - - Tensor sdQ = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutdQ{}); - - // Partition sdV and sdK to match the accumulator partitioning - auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); - auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) - Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) - - __syncthreads(); - cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); - - const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) - + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; - Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), - Shape, Int>{}, - make_stride(params.dq_row_stride, _1{})); - - typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; - auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); - Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) - Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); - - __syncthreads(); - - Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); - cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); - - Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); - Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); - if (!Is_even_K) { - #pragma unroll - for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } - } - // Clear_OOB_K must be false since we don't want to write zeros to gmem - pytorch_flash::copy( - gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM - ); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_dq_dk_dv(const Params ¶ms) { // The block index for the batch. @@ -1557,44 +798,32 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; if (n_block_max == 1) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } else { // Iterating backward from n_block_max - 1 to 0 might save 1 register - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block_max - 1); for (int n_block = n_block_max - 2; n_block > 0; n_block--) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } - compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); + compute_dq_dk_dv_1colblock(params, bidb, bidh, 0); } } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { - const int n_block = blockIdx.x; // The block index for the batch. const int bidb = blockIdx.y; // The block index for the head. 
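On the deterministic option this patch threads through the backward: by default many key-block CTAs atomicAdd into one fp32 dQ_accum buffer, so the floating-point addition order depends on scheduling. With params.deterministic set, the launch caps gridDim.x, each surviving CTA walks key blocks with a grid-stride loop (seen just below) and accumulates into its own dQ_accum slice offset by dq_accum_split_stride, and flash_bwd_convert_dq_kernel folds the nsplits slices together in a fixed order. A host-style sketch of that final combine, with illustrative names:

// Sum the per-split fp32 dQ accumulators in a fixed order, then apply the usual scale.
// In the real kernel this happens per tile inside convert_dQ(params, nsplits) and the
// result is cast to fp16/bf16.
void combine_dq_splits(const float* dq_accum, float* dq_out,
                       long long split_stride, long long num_elems,
                       int nsplits, float scale_softmax_rp_dropout) {
    for (long long i = 0; i < num_elems; ++i) {
        float acc = 0.f;
        for (int s = 0; s < nsplits; ++s)
            acc += dq_accum[s * split_stride + i];
        dq_out[i] = acc * scale_softmax_rp_dropout;
    }
}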
const int bidh = blockIdx.z; - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void compute_dq_dk_dv_seqq_parallel(const Params ¶ms) { - - const int m_block = blockIdx.x; - // The block index for the batch. - const int bidb = blockIdx.y; - // The block index for the head. - const int bidh = blockIdx.z; - - compute_dq_dk_dv_1rowblock(params, bidb, bidh, m_block); + // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace pytorch_flash +} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h index 5c65bbd5ced150..8644ccd88a69ce 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h @@ -1,4 +1,6 @@ -// Copyright (c) 2022, Tri Dao. +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ #pragma once @@ -6,58 +8,81 @@ #include #include +#include #include namespace pytorch_flash { -template -__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) { - pytorch_flash::compute_dot_do_o(params); -} +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif -template -__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) { - pytorch_flash::clear_dKVaccum(params); -} +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); -template -__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) { - pytorch_flash::compute_dq_dk_dv(params); +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params) + +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) { + #if defined(ARCH_SUPPORTS_FLASH) + pytorch_flash::compute_dq_dk_dv(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif } -template -__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { + #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - pytorch_flash::compute_dq_dk_dv_seqk_parallel(params); + pytorch_flash::compute_dq_dk_dv_seqk_parallel(params); #else - printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); + FLASH_UNSUPPORTED_ARCH #endif } -template -__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) { - pytorch_flash::compute_dq_dk_dv_seqq_parallel(params); +template +__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { + pytorch_flash::compute_dot_do_o(params); } template -__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) { - pytorch_flash::convert_dQ(params); +__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) { + pytorch_flash::clear_dKVaccum(params); } template -__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) { +__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { + pytorch_flash::convert_dQ(params, nsplits); +} + +template +__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { pytorch_flash::convert_dKV(params); } template -void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) { const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; dim3 grid_m(num_m_block, params.b, params.h); const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - dim3 grid_n(num_n_block, params.b, params.h); + int gridDimx = num_n_block; + if (params.deterministic) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h); + } + dim3 grid_n(gridDimx, params.b, params.h); - flash_bwd_dot_do_o_kernel<<>>(params); + if (!params.deterministic) { + flash_bwd_dot_do_o_kernel<<>>(params); + } else { + flash_bwd_dot_do_o_kernel<<>>(params); + } C10_CUDA_KERNEL_LAUNCH_CHECK(); // We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not @@ -66,21 +91,23 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool is_even_K = params.d == Kernel_traits::kHeadDim; constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && 
!params.is_causal, Is_local, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -91,58 +118,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, C10_CUDA_CHECK(cudaFuncSetAttribute( kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize)); } - kernel_dq<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); -} - -template -void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; - dim3 grid_n(num_n_block, params.b, params.h_k); - flash_bwd_clear_dkvaccum_kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; - dim3 grid_m(num_m_block, params.b, params.h); - // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check - // for cu_seqlens_k as well. - const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0; - const bool is_even_K = params.d == Kernel_traits::kHeadDim; - constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock; - // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - - auto kernel_dkv = &flash_bwd_convert_dkv_kernel; - if (Kernel_traits::kSmemKVSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize)); - } - kernel_dkv<<>>(params); + kernel_dq<<>>(params, !params.deterministic ? 
1 : gridDimx); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template -void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - if (configure) return; - run_flash_bwd_seqk_parallel(params, stream, configure); +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { +#ifndef FLASHATTENTION_DISABLE_BACKWARD + run_flash_bwd_seqk_parallel(params, stream); +#endif } template -void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; int device; cudaGetDevice(&device); @@ -152,21 +140,21 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const boo if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } else { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } } else { // 96 KB - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } }); } template -void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; int device; cudaGetDevice(&device); @@ -177,42 +165,41 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { // Changing AtomLayoutMdQ from 2 to 4 takes the same time - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); // This is slightly faster. We want to split M more so we need fewer registers to store LSE. 
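// Illustrative sketch (not from the patch; helper names are mine): the run_mha_bwd_hdim*
// dispatchers above pick a Flash_bwd_kernel_traits tiling from the device's opt-in
// shared-memory budget, which is what the max_smem_per_block query amounts to.
#include <cuda_runtime.h>
#include <cstdio>

static int query_max_optin_smem() {
  int device = 0, max_smem = 0;
  cudaGetDevice(&device);
  // The opt-in limit is the relevant one once the launcher raises
  // cudaFuncAttributeMaxDynamicSharedMemorySize on the kernel.
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
  return max_smem;
}

static void pick_bwd_tiling_sketch() {
  const int max_smem_per_block = query_max_optin_smem();
  if (max_smem_per_block >= 144 * 1024) {
    std::printf("enough shared memory for the larger backward tiling\n");
  } else {
    std::printf("fall back to the smaller-smem tiling\n");
  }
}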
if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); // This has a lot of register spilling - // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream); } else { // if (params.h == params.h_k) { - // run_flash_bwd, Is_dropout>(params, stream, configure); - run_flash_bwd, Is_dropout>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); // } else { - // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); // } } }); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); // M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); - // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); + // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream, configure); + // run_flash_bwd>(params, stream); } template -void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; int device; cudaGetDevice(&device); @@ -223,26 +210,22 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // if (params.h == params.h_k) { - if (max_smem_per_block >= 116 * 1024) { - if constexpr(!Is_dropout) { // 92KB - run_flash_bwd, Is_dropout>(params, stream, configure); - } else { // 116 KB - // This is faster for dropout since we don't have many registers to spare - run_flash_bwd, Is_dropout>(params, stream, configure); - } - } else { - run_flash_bwd, Is_dropout>(params, stream, configure); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if (max_smem_per_block >= 116 * 1024) { + if constexpr(!Is_dropout) { // 92KB + run_flash_bwd, Is_dropout>(params, stream); + } else { // 116 KB + // This is faster for dropout since we don't have many registers to spare + run_flash_bwd, Is_dropout>(params, stream); } - // } else { - // run_flash_bwd_seqq_parallel>(params, stream, configure); - // } + } else { + run_flash_bwd, Is_dropout>(params, stream); + } }); } template -void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; int device; cudaGetDevice(&device); @@ -253,35 +236,30 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // if (params.h == params.h_k) { - // run_flash_bwd>(params, 
stream, configure); - // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). - // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. - // run_flash_bwd>(params, stream, configure); - if (max_smem_per_block >= 144 * 1024) { - run_flash_bwd, Is_dropout>(params, stream, configure); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); - // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream, configure); - // run_flash_bwd_seqq_parallel, Is_dropout>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); - // run_flash_bwd, Is_dropout>(params, stream, configure); - } else { - // run_flash_bwd, Is_dropout>(params, stream, configure); - run_flash_bwd, Is_dropout>(params, stream, configure); - } - // run_flash_bwd>(params, stream, configure); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // run_flash_bwd>(params, stream); + // This is faster, in the case of sequence-parallel bwd (where we need fewer registers). + // Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why. + // run_flash_bwd>(params, stream); + if (max_smem_per_block >= 144 * 1024) { + run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd_seqk_parallel, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + // run_flash_bwd, Is_dropout>(params, stream); + } else { + // run_flash_bwd, Is_dropout>(params, stream); + run_flash_bwd, Is_dropout>(params, stream); + } + // run_flash_bwd>(params, stream); - // run_flash_bwd>(params, stream, configure); - // } else { - // run_flash_bwd_seqq_parallel>(params, stream, configure); - // } + // run_flash_bwd>(params, stream); }); } template -void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; int device; cudaGetDevice(&device); @@ -291,17 +269,17 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bo if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 116 * 1024) { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } else { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } }); } template -void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; int device; cudaGetDevice(&device); @@ -311,25 +289,25 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bo if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 136 * 1024) { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } else { - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); } }); } template -void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, 
cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 224; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_bwd, Is_dropout>(params, stream, configure); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_bwd, Is_dropout>(params, stream); }); } template -void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; cudaGetDevice(&device); @@ -339,14 +317,18 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bo if (status_ != cudaSuccess) { C10_CUDA_CHECK(status_); } - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { if (max_smem_per_block >= 176 * 1024) { // H100 - run_flash_bwd, Is_dropout>(params, stream, configure); - } else { // A100, we don't do double buffering to save smem - run_flash_bwd, Is_dropout>(params, stream, configure); + run_flash_bwd, Is_dropout>(params, stream); + } else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem + run_flash_bwd, Is_dropout>(params, stream); + } else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering. + if constexpr (!Is_dropout) { + run_flash_bwd, false>(params, stream); + } } }); } -}; // namespace pytorch_fmha +}; // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h new file mode 100644 index 00000000000000..7811984b7e61e0 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include + +#include +#include +#include + +namespace pytorch_flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, + Tensor &dP_sum, const int gdP_col_stride, const float scale) { + static_assert(Layout0::rank == 3, "Only support 3D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); + // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) + // The last coordinate is the "page". 
+ Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), + make_layout(get<0>(do_.layout()), + get<2>(do_.layout())))); + Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); + Tensor do_fp32 = pytorch_flash::convert_type(do_reshaped); + Tensor o_fp32 = pytorch_flash::convert_type(o_reshaped); + #pragma unroll + for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(do_reshaped); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + pytorch_flash::SumOp sum_op; + dP_sum_cur = pytorch_flash::Allreduce::run(dP_sum_cur, sum_op) * scale; + if (threadIdx.x % THREADS_PER_ROW == 0) { + dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + // TODO: careful, we're zeroing out dQaccum with type float4, but when + // we do atomicAdds, we use type float. The layouts are different. Check this. 
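// Scalar sketch (not kernel code; function names are mine) of what dot_do_o above
// computes per query row: dP_sum[m] = scale * sum_k dO[m,k] * O[m,k]. The device version
// works on register fragments and finishes each row with a warp-level Allreduce; the
// caller below passes params.p_dropout as `scale` so that (dP - dP_sum) stays on a
// single scale downstream.
#include <cstdio>
#include <vector>

static void dot_do_o_reference(const std::vector<float>& dO, const std::vector<float>& O,
                               std::vector<float>& dP_sum, int rows, int cols, float scale) {
  for (int m = 0; m < rows; ++m) {
    float acc = 0.f;
    for (int k = 0; k < cols; ++k) acc += dO[m * cols + k] * O[m * cols + k];
    dP_sum[m] = acc * scale;
  }
}

static void dot_do_o_reference_demo() {
  std::vector<float> dO{1.f, 2.f, 3.f, 4.f}, O{5.f, 6.f, 7.f, 8.f}, d(2);
  dot_do_o_reference(dO, O, d, /*rows=*/2, /*cols=*/2, /*scale=*/1.f);
  std::printf("%.1f %.1f\n", d[0], d[1]);  // 17.0 53.0
}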
+ typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); + + // Allocate predicate tensors for k + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); + // Set predicates for k bounds + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + pytorch_flash::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final + // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, + // so that (dP - dP_sum) is on the same scale. + dot_do_o(tdOrdO, tdOrO, dP_sum, + Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + if (Clear_dQaccum) { + // We're actually not zero'ing out all of dQaccum, but only the part that we're going to + // do atomicAdds on. + Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear_dKVaccum(const Params ¶ms) { + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + Tensor zero = make_fragment_like(tdKgdKaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dQ from dQaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_k. 
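// Sketch of the deterministic path serviced by convert_dQ below (illustrative only;
// split_stride and scale are stand-ins for dq_accum_split_stride and
// scale_softmax_rp_dropout). Each column split accumulates into its own dQaccum slice,
// and the slices are summed in a fixed order before the fp32 -> fp16/bf16 conversion,
// which is what makes the deterministic launch reproducible.
#include <cstdio>
#include <vector>

static void reduce_dq_splits(const std::vector<float>& dq_accum, std::vector<float>& dq,
                             int n_elems, int nsplits, int split_stride, float scale) {
  for (int i = 0; i < n_elems; ++i) {
    float acc = 0.f;
    for (int s = 0; s < nsplits; ++s) acc += dq_accum[s * split_stride + i];
    dq[i] = acc * scale;
  }
}

static void reduce_dq_splits_demo() {
  std::vector<float> accum{1.f, 2.f, 10.f, 20.f};  // nsplits = 2, two dQ elements per split
  std::vector<float> dq(2);
  reduce_dq_splits(accum, dq, 2, /*nsplits=*/2, /*split_stride=*/2, /*scale=*/1.f);
  std::printf("%.1f %.1f\n", dq[0], dq[1]);  // 11.0 22.0
}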
+template +inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + + Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQ{}); + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum)); + + Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum); + clear(acc_dq); + for (int s = 0; s < nsplits; ++s) { + cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum); + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); } + tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride; + } + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = pytorch_flash::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ 
= make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + pytorch_flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_q. +template +inline __device__ void convert_dKV(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + + n_block * kBlockN) * params.d_rounded; + + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + + Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdKV{}); + Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); 
// ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); + CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); + + Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); + Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { + acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; + } + #pragma unroll + for (int i = 0; i < size(acc_dv); ++i) { + acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; + } + // Convert acc_dk from fp32 to fp16 + Tensor rdK = pytorch_flash::convert_type(acc_dk); + Tensor rdV = pytorch_flash::convert_type(acc_dv); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); + cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); + + Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + pytorch_flash::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + pytorch_flash::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); +} + +} // namespace flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h index 844ba52a211a47..0386a07cc64fd6 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h @@ -1,23 +1,23 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. 
******************************************************************************/ #pragma once -#include #include -#include #include #include #include -#include + #include #include #include #include -#include +#include +#include +#include namespace pytorch_flash { @@ -25,57 +25,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum, - Tensor2 &acc_o, float softmax_scale_log2) { - if (Is_first) { - pytorch_flash::template reduce_max(scores, scores_max); - pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - pytorch_flash::reduce_sum(scores, scores_sum); - } else { - Tensor scores_max_prev = make_fragment_like(scores_max); - cute::copy(scores_max, scores_max_prev); - pytorch_flash::template reduce_max(scores, scores_max); - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - #pragma unroll - for (int mi = 0; mi < size(scores_max); ++mi) { - float scores_max_cur = !Check_inf - ? scores_max(mi) - : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); - float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); - scores_sum(mi) *= scores_scale; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } - } - pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); - Tensor scores_sum_cur = make_fragment_like(scores_sum); - pytorch_flash::reduce_sum(scores, scores_sum_cur); - #pragma unroll - for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void write_softmax_to_gmem( - Tensor const &tOrP, Tensor &tPgP, TiledCopy gmem_tiled_copy_P -) { - // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) - Layout l = tOrP.layout(); - Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); - CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{}); - CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP)); - #pragma unroll - for (int mi = 0; mi < size<1>(tPrP); ++mi) { - cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -93,6 +43,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; constexpr int kNWarps = Kernel_traits::kNWarps; + auto seed_offset = at::cuda::philox::unpack(params.philox_args); + pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + + // Save seed and offset for backward. If we don't have this here, the 0-th thread block might + // exit early and no one saves the rng state. 
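// Scalar reference (illustrative only; names are mine) for the online-softmax update that
// the removed softmax_rescale_o performed and that now lives in the Softmax struct from
// softmax.h: when a new block of scores arrives, the running row max/sum and the partial
// output are rescaled by exp(old_max - new_max) before the new block is folded in. The
// real kernel works in base 2 via scale_softmax_log2 and handles fully masked (-inf) rows
// separately.
#include <algorithm>
#include <cmath>
#include <vector>

static void online_softmax_step(const std::vector<float>& scores,  // new block of S for one row
                                std::vector<float>& probs,         // out: exp(S - row_max)
                                float& row_max, float& row_sum,
                                std::vector<float>& acc_o) {        // running unnormalized output
  float new_max = row_max;
  for (float s : scores) new_max = std::max(new_max, s);
  const float rescale = std::exp(row_max - new_max);
  row_sum *= rescale;
  for (float& o : acc_o) o *= rescale;
  for (size_t j = 0; j < scores.size(); ++j) {
    probs[j] = std::exp(scores[j] - new_max);
    row_sum += probs[j];
  }
  row_max = new_max;
  // The caller then does acc_o += probs @ V_block; the normalization by 1/row_sum and
  // the LSE = row_max * scale_softmax + log(row_sum) happen once in the epilogue.
}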
+ if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + if (params.philox_args.captured_) { + *params.seed = std::get<0>(seed_offset); + *params.extragraph_offset = std::get<1>(seed_offset); + } + } + const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -108,15 +71,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { - // Save seed and offset for backward. If we don't have this here, the 0-th thread block might - // exit early and no one saves the rng state. - if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - auto seeds = at::cuda::philox::unpack(params.philox_args); - if (params.philox_args.captured_) { - *params.seed = std::get<0>(seeds); - *params.extragraph_offset = std::get<1>(seeds); - } - } const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; @@ -191,8 +145,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); - typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; - auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); @@ -200,7 +152,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); - Tensor tPgP = gmem_thr_copy_P.partition_D(gP); typename Kernel_traits::TiledMma tiled_mma; auto thr_mma = tiled_mma.get_thread_slice(tidx); @@ -208,6 +159,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + Tensor tSgS = thr_mma.partition_C(gP); + Tensor acc_o = partition_fragment_C(tiled_mma, Shape, Int>{}); // MMA, MMA_M, MMA_K // @@ -228,10 +181,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - // // PREDICATES // @@ -274,16 +223,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue - Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } - // // Copy rmem to smem - // // copy(tQrQ, tQsQ); - // pytorch_flash::cp_async_wait<0>(); - // __syncthreads(); // // if 
(cute::thread(1, 0)) { print(tQsQ); } // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); // // if (cute::thread0()) { print(sQNoSwizzle); } @@ -313,16 +257,12 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); } - auto seeds = at::cuda::philox::unpack(params.philox_args); - if (params.philox_args.captured_) { - *params.seed = std::get<0>(seeds); - *params.extragraph_offset = std::get<1>(seeds); - } + clear(acc_o); - unsigned long long seed = std::get<0>(seeds); - unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32; + pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; - clear(acc_o); + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. @@ -360,37 +300,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi ); // if (cute::thread0()) { print(acc_s); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print_tensor(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. - if (!Is_causal && !Is_local) { - if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } - } else { - // Tensor caccS = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n) - // Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N) - // static_assert(decltype(size<0>(taccScS))::value == 4); - // // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices. - // Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0); - // Tensor idx_rowcol = make_tensor(taccScS.data(), pytorch_flash::convert_layout_acc_rowcol(taccScS.layout())); - // pytorch_flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM); - // Idk why it's get<1> and not get<0> of the stride. 
- // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); } - // I can't get the stride from idx_row - pytorch_flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - // m_block * kBlockM + get<0>(idx_row(0)), - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right - // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16 - // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16 - ); - // if (cute::thread0()) { print_tensor(scores); } - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); pytorch_flash::cp_async_wait<0>(); __syncthreads(); @@ -405,33 +317,31 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // TODO: when we have key_padding_mask we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); - - // Convert scores from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = pytorch_flash::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor tOrP_copy = make_fragment_like(tOrP); - cute::copy(tOrP, tOrP_copy); - pytorch_flash::apply_dropout( - tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps ); - pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); - tPgP.data() = tPgP.data() + (-kBlockN); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); } if (Is_dropout) { - pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); } - // if (cute::thread0()) { print(tOrP); } - pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
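// Scalar picture (illustrative only, helper names are mine) of what the Mask object used
// above decides for query row i and key column j. It folds the previously separate
// apply_mask / apply_mask_local paths plus the ALiBi bias into one routine, using FA2's
// bottom-right causal alignment diag = i + seqlen_k - seqlen_q; causal is the special case
// window_right == 0. The kernel pre-divides the slope by scale_softmax because the bias is
// added to the still-unscaled Q*K^T scores.
#include <cmath>
#include <limits>

static float mask_and_bias_sketch(float score, int i, int j,
                                  int seqlen_q, int seqlen_k,
                                  int window_left, int window_right,  // -1 means unbounded
                                  float alibi_slope) {
  const int diag = i + seqlen_k - seqlen_q;          // column on the causal diagonal for row i
  const bool oob = j >= seqlen_k;                    // key padding for an uneven last kBlockN
  const bool past_right = window_right >= 0 && j > diag + window_right;
  const bool past_left  = window_left  >= 0 && j < diag - window_left;
  if (oob || past_right || past_left) {
    return -std::numeric_limits<float>::infinity();
  }
  return score - alibi_slope * std::abs(diag - j);   // ALiBi: penalize distance from the diagonal
}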
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } // This check is at the end of the loop since we always have at least 1 iteration @@ -468,58 +378,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::cp_async_fence(); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - pytorch_flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right - ); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = pytorch_flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor rP = pytorch_flash::convert_type(acc_s); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor tOrP_copy = make_fragment_like(tOrP); - cute::copy(tOrP, tOrP_copy); - pytorch_flash::apply_dropout( - tOrP_copy, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout( + rP_drop, block_row_idx, block_col_idx, kNWarps ); - pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P); - tPgP.data() = tPgP.data() + (-kBlockN); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); } if (Is_dropout) { - pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, - block_row_idx, block_col_idx, kNWarps); + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); } - pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); + pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - Tensor lse = make_fragment_like(scores_sum); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? 
INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } - - // if (cute::thread0()) { print(acc_o_rowcol); } + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = pytorch_flash::convert_type(acc_o); @@ -585,7 +474,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename Kernel_traits::Element; @@ -673,10 +562,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; - const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? 
binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -730,11 +626,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); - // TODO: this might need to change if we change the mma instruction in SM70 - Tensor scores_max = make_tensor(Shape(acc_o)>>{}); - Tensor scores_sum = make_fragment_like(scores_max); - - // // PREDICATES // @@ -814,11 +705,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { pytorch_flash::copy_w_min_idx( tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN ); - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); if (params.rotary_dim == 0) { pytorch_flash::copy_w_min_idx( @@ -844,19 +736,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } } - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } } // Need this before we can read in K again, so that we'll see the updated K values. __syncthreads(); - if (n_block_max > n_block_copy_min) { - tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride; - tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride; - } + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; } // Read Q from gmem to smem, optionally apply rotary embedding. 
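// Sketch (illustrative only; function name is mine) of the paged-KV address math used
// above when block_table != nullptr: a logical key position is translated to
// page = token / page_block_size and slot = token % page_block_size, and the page id from
// block_table indexes the physical pool via the K/V batch stride. head_idx stands for the
// kv-head index (bidh / h_h_k_ratio in the kernel).
#include <cstdint>

static int64_t paged_k_offset(const int* block_table, int token_idx, int page_block_size,
                              int64_t k_batch_stride, int64_t k_row_stride,
                              int head_idx, int64_t k_head_stride) {
  const int page = token_idx / page_block_size;
  const int slot = token_idx - page * page_block_size;
  return int64_t(block_table[page]) * k_batch_stride
       + int64_t(slot) * k_row_stride
       + int64_t(head_idx) * k_head_stride;
}
// The in-kernel pointer advance from block n to n - 1 is the delta between two such
// offsets, which is why it mixes a block_table[] difference with a slot difference.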
- Tensor tQrQ = make_fragment_like(tQgQ); if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs pytorch_flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, @@ -907,6 +810,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons clear(acc_o); + pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + pytorch_flash::Mask mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + // For performance reason, we separate out two kinds of iterations: // those that need masking on S, and those that don't. // We need masking on S for the very last block when K and V has length not multiple of kBlockN. @@ -927,7 +835,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Advance gV if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads @@ -943,21 +859,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons ); // if (cute::thread0()) { print(acc_s); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - // if (cute::thread0()) { print(scores); } - // We don't put the masking before the matmul S = Q K^T because we don't clear sK - // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul - // can produce Inf / NaN. 
- if (!Is_causal && !Is_local) { - if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); } - } else { - pytorch_flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right - ); - } + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); pytorch_flash::cp_async_wait<0>(); __syncthreads(); @@ -966,7 +870,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons if (n_block > n_block_min) { // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. @@ -975,18 +887,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // We have key_padding_mask so we'll need to Check_inf masking_step == 0 - ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) - : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + ? softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } - // Convert scores from fp32 to fp16/bf16 - Tensor rP = pytorch_flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = pytorch_flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); - // if (cute::thread0()) { print(scores); } + pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // This check is at the end of the loop since we always have at least 1 iteration if (n_masking_steps > 1 && n_block <= n_block_min) { @@ -1002,7 +913,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons pytorch_flash::cp_async_wait<0>(); __syncthreads(); // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } pytorch_flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); @@ -1015,50 +934,38 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons __syncthreads(); if (n_block > n_block_min) { // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } pytorch_flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); } - // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) - Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); - if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) { - pytorch_flash::apply_mask_local( - scores, n_block * kBlockN, binfo.actual_seqlen_k, - m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, - binfo.actual_seqlen_q, kNWarps * 16, - params.window_size_left, params.window_size_right - ); - } - softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + mask.template apply_mask( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o(acc_s, acc_o, params.scale_softmax_log2); - Tensor rP = pytorch_flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. - Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs(rP.layout())); + Tensor rP = pytorch_flash::convert_type(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs(rP.layout())); - pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } // Epilogue - // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) - Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); - // if (cute::thread0()) { print(acc_o_rowcol); } - Tensor lse = make_fragment_like(scores_sum); - #pragma unroll - for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = scores_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; - lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum); - float scale = inv_sum; - #pragma unroll - for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } - } + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } - // if (cute::thread0()) { print(acc_o_rowcol); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning @@ -1135,7 +1042,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1151,12 +1058,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. 
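// The epilogue above now delegates to softmax.normalize_softmax_lse. A scalar
// reference of what that step computes for one output row, reconstructed from
// the loop it replaces (normalize_row is a sketch name; the kernel operates on
// register fragments, not plain arrays):
#include <cmath>
#include <limits>
inline float normalize_row(float* acc_row, int ncols, float row_max, float row_sum,
                           float softmax_scale, bool split) {
    const bool empty = (row_sum == 0.f) || (row_sum != row_sum);   // empty or NaN row
    const float inf  = std::numeric_limits<float>::infinity();
    // Log-sum-exp of the scaled scores for this row; split-KV partials use -inf
    // so an empty partial drops out when partial results are later combined.
    const float lse = empty ? (split ? -inf : inf)
                            : row_max * softmax_scale + std::log(row_sum);
    const float inv_sum = empty ? 1.f : 1.f / row_sum;
    for (int i = 0; i < ncols; ++i) { acc_row[i] *= inv_sum; }      // O_row /= sum(exp(scores))
    return lse;
}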
- pytorch_flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + pytorch_flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1165,7 +1072,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? gridDim.y : 1; - pytorch_flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + pytorch_flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1330,6 +1237,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { } } -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace flash +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h index d76eaa4450e4b8..fcc99686eb835c 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h @@ -12,27 +12,40 @@ namespace pytorch_flash { -template -__global__ void flash_fwd_kernel(Flash_fwd_params params) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - pytorch_flash::compute_attn(params); +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) 
\ +template \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + pytorch_flash::compute_attn(params); #else - printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); + FLASH_UNSUPPORTED_ARCH #endif } -template -__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - pytorch_flash::compute_attn_splitkv(params); +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + pytorch_flash::compute_attn_splitkv(params); #else - printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!"); + FLASH_UNSUPPORTED_ARCH #endif } -template -__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { static_assert(Log_max_splits >= 1); pytorch_flash::combine_attn_seqk_parallel(params); } @@ -52,27 +65,30 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
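// BOOL_SWITCH and the specialized EVENK_SWITCH / LOCAL_SWITCH / ALIBI_SWITCH /
// DROPOUT_SWITCH used in these launchers (defined in static_switch.h) lift a
// runtime bool into a constexpr template parameter, which is why every extra
// switch multiplies the number of instantiated kernels. A sketch of the
// pattern, assuming the PyTorch copy follows the upstream flash-attention
// shape (SKETCH_BOOL_SWITCH is a placeholder name):
#define SKETCH_BOOL_SWITCH(COND, CONST_NAME, ...)     \
  [&] {                                               \
    if (COND) {                                       \
      constexpr static bool CONST_NAME = true;        \
      return __VA_ARGS__();                           \
    } else {                                          \
      constexpr static bool CONST_NAME = false;       \
      return __VA_ARGS__();                           \
    }                                                 \
  }()
// The feature-specific variants are assumed to collapse to a single branch when
// a feature is compiled out (for example, alibi support), trading flexibility
// for fewer template instantiations and shorter build times.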
+ // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -90,22 +106,24 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_K = params.d == Kernel_traits::kHeadDim; BOOL_SWITCH(params.is_causal, Is_causal, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH(params.num_splits > 1, Split, [&] { BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - // auto kernel = &flash_fwd_splitkv_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -118,7 +136,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If headdim is divisible by 64, then we set kBlockM = 8, etc. constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 
8 : 16); dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); - BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { if (params.num_splits <= 2) { flash_fwd_splitkv_combine_kernel<<>>(params); } else if (params.num_splits <= 4) { @@ -152,7 +170,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); @@ -162,7 +180,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower @@ -186,7 +204,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), if (is_sm8x) { @@ -212,7 +230,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -249,7 +267,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, H100, 128 x 32 is the fastest. 
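// Several of these launches request more dynamic shared memory than the 48 KB
// default, which requires an explicit per-kernel opt-in. A standalone CUDA
// sketch of that pattern (demo_kernel, launch_with_large_smem, and smem_bytes
// are placeholder names, not part of this file):
#include <cuda_runtime.h>
__global__ void demo_kernel(float* out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[0] = smem[0]; }
}
inline void launch_with_large_smem(float* out, cudaStream_t stream) {
    const int smem_bytes = 96 * 1024;  // e.g. a 96 KB tile, above the 48 KB default
    if (smem_bytes >= 48 * 1024) {
        // Without raising this attribute, launches asking for > 48 KB of dynamic
        // shared memory are rejected at launch time.
        cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    demo_kernel<<<1, 128, smem_bytes, stream>>>(out);
}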
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), @@ -277,7 +295,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if constexpr(!Is_dropout) { run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -305,7 +323,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_block = %d\n", max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB run_flash_fwd, Is_dropout, Is_causal>(params, stream); @@ -336,7 +354,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { BOOL_SWITCH(params.is_causal, Is_causal, [&] { // For A100, we want to run with 128 x 64 (128KB smem). // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. @@ -353,4 +371,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -}; // namespace pytorch_fmha +}; // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h index 875701e6cf2be9..ef1c3b91c94b03 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits.h @@ -1,5 +1,5 @@ /****************************************************************************** - * Copyright (c) 2023, Tri Dao. + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once @@ -26,7 +26,7 @@ struct Flash_kernel_traits { #endif using ElementAccum = float; - using index_t = uint32_t; + using index_t = int64_t; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 using MMA_Atom_Arch = std::conditional_t< @@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base { SmemLayoutAtomQ{}, Shape, Int>{})); - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? 
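// The kernel_traits change below replaces the hand-written "transposed atom"
// layouts with a composition of the existing swizzled smem layout and a
// row-major layout of the transposed shape. A tiny self-contained CuTe
// illustration of why that composition yields a transposed view of the same
// memory (plain layouts, no swizzle; the 8x4 shape is arbitrary):
#include <cstdio>
#include <cute/tensor.hpp>
int main() {
    using namespace cute;
    auto layout_NK = make_layout(Shape<_8, _4>{});  // compact column-major (N, K): (n, k) -> n + 8*k
    auto layout_KN = composition(layout_NK,
                                 make_layout(Shape<_4, _8>{}, GenRowMajor{}));
    // layout_KN has shape (K, N) and maps (k, n) -> 8*k + n, i.e. the same
    // offsets as layout_NK read through swapped coordinates -- a transpose
    // without moving data. With a swizzled layout the swizzle carries through.
    print(layout_NK); printf("\n");
    print(layout_KN); printf("\n");
    return 0;
}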
- using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomVtransposedNoSwizzle{}, - Shape, Int>{})); - // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); using SmemLayoutAtomO = decltype( composition(Swizzle{}, @@ -116,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base { using SmemCopyAtomO = Copy_Atom; using SmemCopyAtomOaccum = Copy_Atom; - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); @@ -149,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base { make_tiled_copy(Copy_Atom{}, GmemLayoutAtom{}, Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store using GmemLayoutAtomOaccum = std::conditional_t< kBlockKSmem == 32, @@ -244,26 +223,18 @@ struct Flash_bwd_kernel_traits : public Base { SmemLayoutAtomKV{}, make_shape(Int{}, Int{}))); - using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomKtransposed = decltype( - composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); - using SmemLayoutKtransposed = decltype(tile_to_shape( - SmemLayoutAtomKtransposed{}, - make_shape(Int{}, Int{}))); - // Maybe the KtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomKtransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); // TODO: generalize to other values of kBlockN // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 // static constexpr int kPBlockN = kBlockN; - static_assert(kBlockN >= 64); + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. - static constexpr int kPBlockN = 64; + static constexpr int kPBlockN = kBlockN >= 64 ? 
64 : 32; static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); static constexpr int kSwizzlePdS = 3; @@ -274,30 +245,15 @@ struct Flash_bwd_kernel_traits : public Base { using SmemLayoutPdS = decltype(tile_to_shape( SmemLayoutAtomPdS{}, make_shape(Int{}, Int{}))); - using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomPdStransposed = decltype( - composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); - using SmemLayoutPdStransposed = decltype(tile_to_shape( - SmemLayoutAtomPdStransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomPdStransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + using SmemCopyAtomPdS = Copy_Atom; - using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, - Stride<_1, Int>>; - using SmemLayoutAtomQdOtransposed = decltype( - composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); - using SmemLayoutQdOtransposed = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposed{}, - make_shape(Int{}, Int{}))); - using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( - SmemLayoutAtomQdOtransposedNoSwizzle{}, - make_shape(Int{}, Int{}))); - // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); using SmemLayoutAtomdKV = decltype( composition(Swizzle{}, @@ -317,16 +273,12 @@ struct Flash_bwd_kernel_traits : public Base { make_shape(Int{}, Int{}))); using SmemCopyAtomdQ = Copy_Atom; - static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); - static constexpr int kSmemPCount = size(SmemLayoutPdS{}); - static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); - static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); - static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); - static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) @@ -335,9 +287,6 @@ struct Flash_bwd_kernel_traits : public Base { + (!Is_V_in_regs ? 
kSmemKVSize + kSmemdSSize + kSmemPSize : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); - static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 - + kSmemdSSize + kSmemPSize; - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h b/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h deleted file mode 100644 index 01ea212b452c47..00000000000000 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernel_traits_sm90.h +++ /dev/null @@ -1,161 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include - -namespace pytorch_flash{ - -using namespace cute; - -template -struct Flash_kernel_traits_sm90 { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using Element = elem_type; - static constexpr bool Has_cp_async = true; -#else - using Element = cutlass::half_t; - static constexpr bool Has_cp_async = false; -#endif - - using ElementAccum = float; - using index_t = uint32_t; - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - using MMA_Atom_Arch = std::conditional_t< - std::is_same_v, - MMA_Atom, - MMA_Atom - >; - using ValLayoutMNK = Layout>; -#else - using MMA_Atom_Arch = MMA_Atom; - using ValLayoutMNK = Layout>; -#endif - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#else - using SmemCopyAtom = Copy_Atom; - using SmemCopyAtomTransposed = Copy_Atom; -#endif -}; - -template > -struct Flash_fwd_kernel_traits : public Base { - using Element = typename Base::Element; - using ElementAccum = typename Base::ElementAccum; - using index_t = typename Base::index_t; - static constexpr bool Has_cp_async = Base::Has_cp_async; - using SmemCopyAtom = typename Base::SmemCopyAtom; - using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; - - static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; - static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; - - // The number of threads. - static constexpr int kNWarps = kNWarps_; - static constexpr int kNThreads = kNWarps * 32; - - static constexpr int kBlockM = kBlockM_; - static constexpr int kBlockN = kBlockN_; - static constexpr int kHeadDim = kHeadDim_; - static_assert(kHeadDim % 32 == 0); - static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; - static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); - static constexpr int kSwizzle = kBlockKSmem == 32 ? 
2 : 3; - - using TiledMma = TiledMMA< - typename Base::MMA_Atom_Arch, - Layout,_1,_1>>, // 4x1x1 or 8x1x1 thread group - typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM - - using SmemLayoutAtomQ = decltype( - composition(Swizzle{}, - // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 - Layout>, - Stride, _1>>{})); - using SmemLayoutQ = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutKV = decltype(tile_to_shape( - SmemLayoutAtomQ{}, - Shape, Int>{})); - - using SmemLayoutAtomVtransposed = decltype( - composition(Swizzle{}, - // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 - Layout, Int>, - Stride<_1, Int>>{})); - using SmemLayoutVtransposed = decltype(tile_to_shape( - SmemLayoutAtomVtransposed{}, - Shape, Int>{})); - // Maybe the VtransposeNoSwizzle just needs to have the right shape - // And the strides don't matter? - using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); - - using SmemLayoutAtomO = decltype( - composition(Swizzle{}, - Layout, Int>, - Stride, _1>>{})); - using SmemLayoutO = decltype(tile_to_shape( - SmemLayoutAtomO{}, - Shape, Int>{})); - using SmemCopyAtomO = Copy_Atom; - - static constexpr int kSmemQCount = size(SmemLayoutQ{}); - static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2; - static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); - static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); - static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; - - static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); - static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); - // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. - // For example, for d=128, smem is split into 2 "pages", each page takes care of columns - // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, - // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, - // to the same banks. - static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); - using GmemLayoutAtom = Layout, Int>, - Stride, _1>>; - - // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading - // from the same address by the same threadblock. This is slightly faster. 
- using Gmem_copy_struct = std::conditional_t< - Has_cp_async, - SM80_CP_ASYNC_CACHEGLOBAL, - DefaultCopy - >; - using GmemTiledCopyQKV = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per read - using GmemTiledCopyO = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtom{}, - Layout>{})); // Val layout, 8 vals per store - static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; - static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); - using GmemLayoutAtomP = Layout, Int>, - Stride, _1>>; - - using GmemTiledCopyP = decltype( - make_tiled_copy(Copy_Atom{}, - GmemLayoutAtomP{}, - Layout>{})); // Val layout, 8 vals per store - -}; -} // namespace pytorch_flash -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu index 247b359b052199..63a80c4d2062fc 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim128(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu index 54ba9b1d016578..720f54343a4693 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim128_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim128(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu index 351df04f7bd8b3..04aa184a6f78c1 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim160(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim160(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu index 057023e3be16ac..979082162997ad 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu +++ 
b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim160_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim160(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim160(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu index f772b3c75a4d52..76ac4426f0390e 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim192(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu index 91deb5f3e88e5a..d0a05f597219c4 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim192_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim192(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim192(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu index bf11ee849e1bc3..14ce1a9a450fc6 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim224(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim224(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu index 59a062829d468b..259c84cf8cdaaf 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim224_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim224(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim224(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu 
b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu index 48150fabcd61f5..1767b60f7908bb 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim256(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu index f24074782bf7da..6381904f7b5b72 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim256_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim256(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu index 8724f83e900719..bd47a37e7f6e36 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim32(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu index aca37f6dfa07e7..ae046260c3706f 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim32_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim32(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim32(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu index ce1c12768d75bf..42314aac9d2a2d 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim64(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) 
{ + run_mha_bwd_hdim64(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu index 5f901a7b3243f0..616c784f7524ca 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim64_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim64(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu index a0dc45eea3c887..6eccc4f455ad04 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_bf16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim96(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu index 083828ee67f9b3..54e455b81a36db 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/flash_bwd_hdim96_fp16_sm80.cu @@ -8,7 +8,7 @@ namespace pytorch_flash{ template<> -void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) { - run_mha_bwd_hdim96(params, stream, configure); +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim96(params, stream); } } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py index ee97a6a73cc050..ca1fe27f94903e 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py @@ -27,8 +27,8 @@ KERNEL_IMPL_TEMPLATE_BWD = """ template<> -void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {{ - run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure); +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); }} """ diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h new file mode 100644 index 00000000000000..9cee154fbbd50e --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/mask.h @@ -0,0 +1,213 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +namespace pytorch_flash { + +using namespace cute; + +template +__forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +__forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +__forceinline__ __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor &tensor, Tensor const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) +{ + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < 
size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const float alibi_slope; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template + __forceinline__ __device__ void apply_mask(Tensor &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? 
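// Scalar reference for the per-element rule that the fragment loops below
// implement, written as a host-side sketch (masked_score is a sketch name;
// the causal path assumes window_size_right == 0, matching the
// "causal == local with (infinity, 0) window" note above):
#include <algorithm>
#include <cmath>
#include <limits>
inline float masked_score(float score, int row_idx, int col_idx,
                          int max_seqlen_q, int max_seqlen_k,
                          int window_size_left, int window_size_right,
                          bool causal, bool local, float alibi_slope) {
    const int shift = max_seqlen_k - max_seqlen_q;  // aligns q and k on their last tokens
    if (alibi_slope != 0.f) {                       // slope is 0 when alibi is disabled
        score += causal ? alibi_slope * col_idx
                        : -alibi_slope * std::abs(row_idx + shift - col_idx);
    }
    const int right = std::min(max_seqlen_k,
                               row_idx + 1 + shift + (causal ? 0 : window_size_right));
    const int left  = local ? std::max(0, row_idx + shift - window_size_left) : 0;
    const bool masked = (causal || local) ? (col_idx >= right || col_idx < left)
                                          : (col_idx >= max_seqlen_k);  // key-padding only
    return masked ? -std::numeric_limits<float>::infinity() : score;
}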
+ static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + +}; + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh index 472d6b211f052c..bed362bdd0c8ea 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh @@ -11,7 +11,7 @@ struct ull2 { unsigned long long y; }; -inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { uint2 *res; unsigned long long tmp; asm ("mul.wide.u32 %0, %1, %2;\n\t" @@ -21,7 +21,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { return *res; } -inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { constexpr unsigned long kPhiloxSA = 0xD2511F53; constexpr unsigned long kPhiloxSB = 0xCD9E8D57; uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); @@ -30,7 +30,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { return ret; } -inline __device__ uint4 philox(unsigned long long seed, +__forceinline__ 
__device__ uint4 philox(unsigned long long seed, unsigned long long subsequence, unsigned long long offset) { constexpr unsigned long kPhilox10A = 0x9E3779B9; @@ -51,117 +51,3 @@ inline __device__ uint4 philox(unsigned long long seed, } } // namespace flash - -namespace { - -class Philox { -public: - __device__ inline Philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) - : STATE(0) - , seed_(seed) - , offset_(offset) - , key(reinterpret_cast(seed)) { - //key.x = (unsigned int)seed; - //key.y = (unsigned int)(seed >> 32); - //counter = make_uint4(0, 0, 0, 0); - //counter.z = (unsigned int)(subsequence); - //counter.w = (unsigned int)(subsequence >> 32); - //STATE = 0; - //incr_n(offset / 4); - - // key = reinterpret_cast(seed); - ull2 * tmp = reinterpret_cast(&counter); - tmp->x = offset / 4; - tmp->y = subsequence; - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w); - // } - } - __device__ inline uint4 operator()() { - // // if (STATE == 0) { - // uint4 counter_ = counter; - // uint2 key_ = key; - // // 7-round philox - // #pragma unroll - // for (int i = 0; i < 6; i++) { - // counter_ = pytorch_flash::philox_single_round(counter_, key_); - // key_.x += (kPhilox10A); - // key_.y += (kPhilox10B); - // } - // // output = philox_single_round(counter_, key_); - // uint4 output = pytorch_flash::philox_single_round(counter_, key_); - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w); - // // } - // incr(); - // // } - // // return a float4 directly - // // unsigned long ret; - // // switch(STATE) { - // // case 0: ret = output.x; break; - // // case 1: ret = output.y; break; - // // case 2: ret = output.z; break; - // // case 3: ret = output.w; break; - // //} - // // STATE = (STATE + 1) % 4; - // return output; - return pytorch_flash::philox(seed_, offset_, offset_); - } - -private: - unsigned long long offset_, seed_; - struct ull2 { - uint64_t x; - uint64_t y; - }; - uint4 counter; - // uint4 output; - const uint2 key; - unsigned int STATE; - __device__ inline void incr_n(unsigned long long n) { - unsigned int nlo = (unsigned int)(n); - unsigned int nhi = (unsigned int)(n >> 32); - counter.x += nlo; - if (counter.x < nlo) - nhi++; - counter.y += nhi; - if (nhi <= counter.y) - return; - if (++counter.z) - return; - ++counter.w; - } - - __device__ uint4 incr128 (uint4 ctr) - { - uint4 res; - asm ("add.cc.u32 %0, %4, %8;\n\t" - "addc.cc.u32 %1, %5, %9;\n\t" - "addc.cc.u32 %2, %6, %10;\n\t" - "addc.u32 %3, %7, %11;\n\t" - : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) - : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w), - "n"(1), "n"(0), "n"(0), "n"(0)); - return res; - } - - __device__ inline void incr() { - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - counter = incr128(counter); - // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w); - // } - } - - static const unsigned long kPhilox10A = 0x9E3779B9; - static const unsigned long kPhilox10B = 0xBB67AE85; - // static const unsigned long kPhiloxSA = 
0xD2511F53; - // static const unsigned long kPhiloxSB = 0xCD9E8D57; -}; - -} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h b/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h new file mode 100644 index 00000000000000..12dc1746c80878 --- /dev/null +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace pytorch_flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int 
max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h index 239a8114b68b7b..dec2065ec400a6 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h @@ -1,34 +1,15 @@ /****************************************************************************** - * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. 
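
The two rotary copy helpers above (`copy_rotary_interleaved` and `copy_rotary_contiguous`) apply rotary embeddings in the interleaved-pair and split-half layouts respectively. For readers who want a host-side reference, below is a minimal PyTorch sketch of the same math; the tensor names and the frequency schedule are illustrative only, not taken from the patch.

```python
# Host-side reference for the two rotary layouts the kernel copies implement.
# A minimal sketch, assuming x has shape (seqlen, rotary_dim) and cos/sin have
# shape (seqlen, rotary_dim // 2); names are illustrative only.
import torch

def rotary_interleaved(x, cos, sin):
    # Pairs (x[2i], x[2i+1]) are rotated together, as in copy_rotary_interleaved.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rotary_contiguous(x, cos, sin):
    # The first and second halves are rotated against each other
    # (split-half layout), as in copy_rotary_contiguous.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

if __name__ == "__main__":
    seqlen, rotary_dim = 4, 8
    pos = torch.arange(seqlen, dtype=torch.float32)[:, None]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    angles = pos * inv_freq                      # (seqlen, rotary_dim // 2)
    cos, sin = angles.cos(), angles.sin()
    x = torch.randn(seqlen, rotary_dim)
    print(rotary_interleaved(x, cos, sin).shape, rotary_contiguous(x, cos, sin).shape)
```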
- * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * + * Copyright (c) 2024, Tri Dao. ******************************************************************************/ #pragma once #include -#include + +#include + +#include + #include #include @@ -39,7 +20,7 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// template -__device__ inline void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); @@ -54,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor const &tensor, Te } template -__device__ inline void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { CUTE_STATIC_ASSERT_V(size(dst) == size(src)); #pragma unroll for (int i = 0; i < size(dst); i++){ @@ -63,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor &dst, Tensor -__device__ inline void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { thread_reduce_(tensor, summary, op); quad_allreduce_(summary, summary, op); } template -__device__ inline void reduce_max(Tensor const& tensor, Tensor &max){ +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ MaxOp max_op; reduce_(tensor, max, max_op); } -template -__device__ inline void reduce_sum(Tensor const& tensor, Tensor &sum){ +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ SumOp sum_op; - reduce_(tensor, sum, sum_op); + thread_reduce_(tensor, sum, sum_op); } // Apply the exp to all the elements. 
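
`scale_apply_exp2` below exponentiates with `exp2f` rather than `expf`, folding `log2(e)` into the softmax scale so the hardware base-2 exponential can be used. A small self-contained check of the identity this relies on (inferred from the surrounding code, not part of the patch itself):

```python
# exp(scale * (x - max)) == 2 ** (scale * log2(e) * x - scale * log2(e) * max)
# A numerical check of the identity behind scale_apply_exp2.
import math
import torch

x = torch.randn(4, 8)
scale = 1.0 / math.sqrt(8)                 # a typical softmax scale
scale_log2 = scale * math.log2(math.e)     # analogous to softmax_scale_log2 in the kernel

row_max = x.max(dim=-1, keepdim=True).values
via_exp = torch.exp(scale * (x - row_max))
via_exp2 = torch.exp2(scale_log2 * x - scale_log2 * row_max)
assert torch.allclose(via_exp, via_exp2, atol=1e-6)
print("exp2 form matches exp form")
```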
template -inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -104,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor &tensor, Tensor // Apply the exp to all the elements. template -inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); @@ -134,171 +115,67 @@ inline __device__ void max_scale_exp2_sum(Tensor &tensor, Tens } } -template -inline __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, - const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; - #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= max_seqlen_k) { - // Without the "make_coord" we get wrong results - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - tensor(mi, make_coord(j, nj)) = -INFINITY; - } - } - } - } -} +//////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void apply_mask_local(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride, - const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout::rank == 2, "Only support 2D Tensor"); - const int lane_id = threadIdx.x % 32; - // const int row_idx_offset = row_idx_offset_ + lane_id / 4; - const int row_idx_offset = row_idx_offset_; - const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; - #pragma unroll - for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { - const int row_idx_base = row_idx_offset + mi * warp_row_stride; - #pragma unroll - for (int i = 0; i < size<0, 0>(tensor); ++i) { - const int row_idx = row_idx_base + i * 8; - const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); - const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + pytorch_flash::template reduce_max(scores, 
row_max); + pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + pytorch_flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + pytorch_flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll - for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { - const int col_idx_base = col_idx_offset + nj * 8; + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; #pragma unroll - for (int j = 0; j < size<1, 0>(tensor); ++j) { - const int col_idx = col_idx_base + j; - if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { - tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; - } - } + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } } - // if (cute::thread0()) { - // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); - // print(tensor(make_coord(i, mi), _)); - // // print(tensor(_, j + nj * size<1, 0>(tensor))); - // } + pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + pytorch_flash::reduce_sum(scores, row_sum); } - } -} - -template -inline __device__ void apply_mask_causal(Tensor &tensor, const int col_idx_offset_, - const int max_seqlen_k, const int row_idx_offset_, - const int max_seqlen_q, const int warp_row_stride) { - // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 - apply_mask_local(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_, - max_seqlen_q, warp_row_stride, -1, 0); -} + }; -template -inline __device__ void apply_mask_causal_w_idx( - Tensor &tensor, Tensor const &idx_rowcol, - const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) -{ - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 2, "Only support 2D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); - CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); - #pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll - for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { - if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { - tensor(mi, ni) = -INFINITY; - } + for (int mi = 0; mi < size<0>(acc_o_rowcol); 
++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } } - // if (cute::thread0()) { - // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); - // print(tensor(_, make_coord(j, ni))); - // // print(tensor(_, j + ni * size<1, 0>(tensor))); - // } - } -} - -template -inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, - unsigned long long seed, unsigned long long offset, - int block_row_start, int block_col_start, - int block_row_stride) { - // tensor has shape (8, MMA_M, MMA_N / 2) - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + return lse; }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
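
The new `Softmax` struct above carries `row_max` and `row_sum` across key blocks: `softmax_rescale_o` rescales the partially accumulated output whenever the running max grows, and `normalize_softmax_lse` turns the running sum into the final normalizer and log-sum-exp. The sketch below shows the same streaming recurrence on a single row; it is a conceptual host-side reference, not the kernel code.

```python
# Streaming softmax over key blocks: keep a running max and a running sum of
# exponentials, and recover the log-sum-exp (LSE) at the end. The same
# exp(m - m_new) factor that rescales `s` here is what rescales the partial
# output accumulator in the kernel. Conceptual sketch only.
import torch

def streaming_lse(score_blocks):
    m = torch.tensor(float("-inf"))   # running row max
    s = torch.tensor(0.0)             # running sum of exp(scores - m)
    for blk in score_blocks:
        m_new = torch.maximum(m, blk.max())
        s = s * torch.exp(m - m_new) + torch.exp(blk - m_new).sum()
        m = m_new
    return m + torch.log(s)           # LSE of the whole row

scores = torch.randn(64)
blocks = torch.split(scores, 16)
assert torch.allclose(streaming_lse(blocks), torch.logsumexp(scores, dim=0), atol=1e-5)
print("streaming LSE matches torch.logsumexp")
```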
- if (!encode_dropout_in_sign_bit - && (std::is_same::value || std::is_same::value)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } -} +}; } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h index 4aa8474028868d..ca12fa171bf989 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/static_switch.h @@ -14,6 +14,7 @@ /// some_function(...); /// }); /// ``` + #define BOOL_SWITCH(COND, CONST_NAME, ...) \ [&] { \ if (COND) { \ @@ -25,6 +26,46 @@ } \ }() +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + #define FP16_SWITCH(COND, ...) \ [&] { \ if (COND) { \ @@ -36,7 +77,7 @@ } \ }() -#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ +#define HEADDIM_SWITCH(HEADDIM, ...) 
\ [&] { \ if (HEADDIM <= 32) { \ constexpr static int kHeadDim = 32; \ diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h index fc791b0b2107ea..2c8add318366ae 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/utils.h @@ -22,16 +22,17 @@ #include #include - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace pytorch_flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + template -inline __device__ uint32_t relu2(const uint32_t x); +__forceinline__ __device__ uint32_t relu2(const uint32_t x); template<> -inline __device__ uint32_t relu2(const uint32_t x) { +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 @@ -49,7 +50,7 @@ inline __device__ uint32_t relu2(const uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template<> -inline __device__ uint32_t relu2(const uint32_t x) { +__forceinline__ __device__ uint32_t relu2(const uint32_t x) { uint32_t res; const uint32_t zero = 0u; asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); @@ -62,10 +63,10 @@ inline __device__ uint32_t relu2(const uint32_t x) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 template -inline __device__ uint32_t convert_relu2(const float2 x); +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); template<> -inline __device__ uint32_t convert_relu2(const float2 x) { +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -74,7 +75,7 @@ inline __device__ uint32_t convert_relu2(const float2 x) { } template<> -inline __device__ uint32_t convert_relu2(const float2 x) { +__forceinline__ __device__ uint32_t convert_relu2(const float2 x) { uint32_t res; const uint32_t a = reinterpret_cast(x.x); const uint32_t b = reinterpret_cast(x.y); @@ -88,20 +89,20 @@ inline __device__ uint32_t convert_relu2(const float2 x) { template struct MaxOp { -__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? 
x : y; } }; template <> struct MaxOp { // This is slightly faster -__device__ inline float operator()(float const &x, float const &y) { return max(x, y); } +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// template struct SumOp { -__device__ inline T operator()(T const & x, T const & y) { return x + y; } +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -110,7 +111,7 @@ template struct Allreduce { static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); template - static __device__ inline T run(T x, Operator &op) { + static __device__ __forceinline__ T run(T x, Operator &op) { constexpr int OFFSET = THREADS / 2; x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); return Allreduce::run(x, op); @@ -122,7 +123,7 @@ struct Allreduce { template<> struct Allreduce<2> { template -static __device__ inline T run(T x, Operator &op) { +static __device__ __forceinline__ T run(T x, Operator &op) { x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); return x; } @@ -134,7 +135,7 @@ template -inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, +__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { @@ -161,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 template -inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, - ThrCopy smem_thr_copy_B) { +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K @@ -183,42 +184,48 @@ inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB // Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) template -inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { static_assert(decltype(size<0>(acc_layout))::value == 4); static_assert(decltype(rank(acc_layout))::value == 3); auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) - // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting - // "int_tuple.hpp(74): error: conversion to inaccessible base class" - // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); - return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l))); + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); }; //////////////////////////////////////////////////////////////////////////////////////////////////// -// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) -// if using m16n8k16, 
or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. template -inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { using X = Underscore; - static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); - static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); static_assert(mma_shape_K == 8 || mma_shape_K == 16); - constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; - auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) - // TD [2023-08-13]: Same error as above on Cutlass 3.2 - // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), - // get<0, 1>(l), - // get<1, 1, 1>(l)); - return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))), - get<1>(get<0>(l)), - get<1>(get<1>(get<1>(l)))); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); }; //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ auto convert_type(Tensor const &tensor) { +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { using From_type = typename Engine::value_type; constexpr int numel = decltype(size(tensor))::value; cutlass::NumericArrayConverter convert_op; @@ -230,7 +237,7 @@ inline __device__ auto convert_type(Tensor const &tensor) { //////////////////////////////////////////////////////////////////////////////////////////////////// template -inline __device__ void relu_(Tensor &tensor) { +__forceinline__ __device__ void relu_(Tensor &tensor) { constexpr int numel = decltype(size(tensor))::value; static_assert(numel % 2 == 0); using value_t = typename Engine::value_type; @@ -246,7 +253,7 @@ inline __device__ void relu_(Tensor &tensor) { // On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction template -inline __device__ auto convert_type_relu(Tensor const &tensor) { +__forceinline__ __device__ auto convert_type_relu(Tensor const &tensor) { using From_type = typename Engine::value_type; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); @@ -288,7 +295,7 @@ void cp_async_wait() { template -inline __device__ void copy(TiledCopy tiled_copy, Tensor const &S, +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const 
&predicate_K, const int max_MN=0) { CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); @@ -357,7 +364,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor const template -inline __device__ void copy_w_min_idx(Tensor const &S, +__forceinline__ __device__ void copy_w_min_idx(Tensor const &S, Tensor &D, Tensor const &identity_MN, Tensor const &predicate_K, const int max_MN=0, const int min_MN=0) { @@ -384,137 +391,4 @@ inline __device__ void copy_w_min_idx(Tensor const &S, //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - 
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? 
-sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 421bc83ebed432..e2ea560b6afc6d 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -242,7 +242,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) return true; } -bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90( +bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89( sdp_params const& params, bool debug) { // Flash Attention will raise an error in the backward pass if the head_dim @@ -252,11 +252,19 @@ bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90( auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm86_or_sm89 = check_sm_version(dprops); bool is_head_dim_gt192 = params.query.sym_size(-1) > 192; - if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt192) { + bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224; + bool is_dropout = params.dropout > 0.0; + // head_dim size in (192, 224] is not supported on sm86 and sm89 + bool cond1 = is_head_dim_gt192 && is_head_dim_lte224; + // head_dim size > 224 and is_dropout is not supported on sm86 and sm89 + bool cond2 = params.query.sym_size(-1) > 224 && is_dropout; + if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) { if (debug) { TORCH_WARN( - "Flash attention currently doesn't support training with head_dim greater than 192 on gpu architectures in the range[sm86, sm89].", - "Attempting to run with head_dim: ", + "Flash attention currently doesn't support training with head_dim ∈ (192, 224] or " + "(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].", + "Attempting to run with dropout set to: ", params.dropout, + "and head_dim: ", params.query.sym_size(-1), " on a sm ", dprops->major, ".", dprops->minor, " gpu."); } @@ -467,7 +475,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) { check_for_attn_mask, check_head_dim_size_flash, check_flash_attention_hardware_support, - check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90, + check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89, check_flash_causal_non_square_seqlens, check_dtypes_low_precision); for (auto& constraint : general_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 61999bc706c693..24eebee7a75ab5 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -106,10 +106,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size 
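
The renamed `check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89` above relaxes the old blanket head_dim > 192 restriction: on sm86/sm89 the backward pass is now rejected only for head_dim in (192, 224], or for head_dim > 224 combined with dropout. A small predicate restating that condition (the function name here is hypothetical, used only for illustration):

```python
# Restating the sm86/sm89 backward-pass constraint from the check above:
# reject only head_dim in (192, 224], or head_dim > 224 together with dropout.
def flash_bwd_unsupported_on_sm86_89(head_dim: int, dropout_p: float,
                                     requires_grad: bool) -> bool:
    cond1 = 192 < head_dim <= 224
    cond2 = head_dim > 224 and dropout_p > 0.0
    return requires_grad and (cond1 or cond2)

assert flash_bwd_unsupported_on_sm86_89(204, 0.0, True)       # (192, 224]: rejected
assert not flash_bwd_unsupported_on_sm86_89(256, 0.0, True)   # > 224, no dropout: allowed
assert flash_bwd_unsupported_on_sm86_89(256, 0.2, True)       # > 224 with dropout: rejected
```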
c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, const float softmax_scale, bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -311,13 +312,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. - const int max_seqlen_q, + c10::optional &alibi_slopes_, // num_heads or b x num_heads + int max_seqlen_q, const int max_seqlen_k, const float p_dropout, const float softmax_scale, const bool zero_tensors, - const bool is_causal, - const int window_size_left, + bool is_causal, + int window_size_left, int window_size_right, const bool return_softmax, c10::optional gen_) { @@ -343,11 +345,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads const float p_dropout, // probability to drop const float softmax_scale, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { check_gpu_arch(); @@ -630,14 +634,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 + c10::optional &alibi_slopes_, // num_heads or b x num_heads const int max_seqlen_q, const int max_seqlen_k, // max sequence length to choose the kernel const float p_dropout, // probability to drop const float softmax_scale, const bool zero_tensors, const bool is_causal, - const int window_size_left, + int window_size_left, int window_size_right, + const bool deterministic, const at::Tensor philox_seed, const at::Tensor philox_offset) { TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm"); diff --git a/test/test_transformers.py b/test/test_transformers.py index af14b06c21048d..e752ba1fa41131 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1328,13 +1328,17 @@ class TestSDPAFailureModes(NNTestCase): _do_cuda_non_default_stream = True @onlyCUDA - @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, - "Does not support fused SDPA or not SM86+ hardware") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice, + "Does not support fused SDPA or not SM86+ hardware", + ) @parametrize("head_dim", [193, 204, 256]) - def test_flash_backward_failure_sm86plus(self, device, head_dim: int): + @parametrize("dropout_p", [0.0, 0.2]) + def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype) - # See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h + # See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in + # 
pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h size = (2, 2, 4, head_dim) q, k, v = make_tensor(size), make_tensor(size), make_tensor(size) @@ -1351,8 +1355,15 @@ def test_flash_backward_failure_sm86plus(self, device, head_dim: int): q = make_tensor(size, requires_grad=True) k = make_tensor(size, requires_grad=True) v = make_tensor(size, requires_grad=True) - self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( - q, k, v, None, 0.0, False)) + if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0): + self.assertRaises( + RuntimeError, + lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, None, dropout_p, False + ), + ) + else: + flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False) @onlyCUDA def test_dispatch_fails_no_backend(self, device): @@ -1589,7 +1600,6 @@ def test_nested_fails_on_padding_head_dim(self, device): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, False)) - @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device, "Current platform does not support fused SDPA or is an SM80+ device.") @@ -1670,37 +1680,35 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, is_causal=True)) -def _get_block_size(device, head_dim, is_causal): +def _get_block_size_n(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel - # Mask is only interesting when we are setting dropout - is_dropout = True assert head_dim <= 256 major, minor = torch.cuda.get_device_capability(device) is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) is_sm80 = major == 8 and minor == 0 is_sm90 = major == 9 and minor == 0 if head_dim <= 32: - return 128, 128 + return 128 if head_dim <= 64: - return (128, 128) if not is_dropout else (128, 64) + return 128 if not is_dropout else 64 elif head_dim <= 96: - return (64, 64) if (is_sm8x and is_causal) else (128, 64) + return 64 elif head_dim <= 128: if is_sm8x: - return (64, 64) if (not is_dropout and is_causal) else (128, 32) + return 64 if (not is_dropout and is_causal) else 32 else: - return 128, (64 if not is_dropout else 32) + return 64 if not is_dropout else 32 elif head_dim <= 160: if is_sm8x: - return (128, 64) if not is_causal else (64, 64) + return 64 else: - return 128, 32 + return 32 elif head_dim <= 192: - return (128, 64) if not is_dropout else (64, 64) + return 64 elif head_dim <= 224: - return (128, 64) if (is_sm80 or is_sm90) else (64, 64) + return 64 elif head_dim <= 256: - return (128, 64) if is_sm80 else (64, 64) + return 64 def pad_last_dim(input_tensor, alignment_size, slice: bool = False): @@ -1963,7 +1971,114 @@ class TestSDPACudaOnly(NNTestCase): _do_cuda_memory_leak_check = True _do_cuda_non_default_stream = True - def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False): + # TODO USED FOR TESTING THE SCORES, e.g. 
testing ALIBI we don't need this now + def normalize_flash_attn_S( + self, + attn_unnorm, + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + is_dropout=False, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + scale=None, + ): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k, v: (batch_size, seqlen_k, nheads, head_dim) + key_padding_mask: (batch_size, seqlen_q) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + Output: + softmax_lse: (batch_size, nheads, seqlen_q) + softmax_max: (batch_size, nheads, seqlen_q) + """ + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if causal: + window_size = (window_size[0], 0) + q, k, v = q.float(), k.float(), v.float() + _, seqlen_q, _, head_dim = q.shape + seqlen_k = k.shape[1] + b = q.shape[0] + from torch.nn.attention.bias import _calculate_scale + scale = _calculate_scale(head_dim, scale) + scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1)) + if key_padding_mask is not None: + scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf")) + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = self.construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias.to(dtype=scores.dtype) + block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) + scores_block = scores.split(block_size_n, dim=-1) + lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) + lse = torch.logsumexp(lse_block, dim=-1) + # lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf + # so that when we do torch.exp(m - lse), we get 0.0 instead of NaN. 
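
As the comment above notes, a fully masked row yields `lse == -inf`, so the later `exp(m - lse)` factor would be NaN; substituting `+inf` (done just below) makes that factor 0 instead. A standalone illustration of the failure mode being avoided, not part of the test:

```python
# Why fully masked rows replace lse == -inf with +inf before exp(m - lse).
import torch

scores = torch.full((4,), float("-inf"))     # a fully masked row of attention scores
lse = torch.logsumexp(scores, dim=0)         # -inf
m = scores.max()                             # -inf
print(torch.exp(m - lse))                    # nan: -inf - (-inf) is undefined
lse = torch.tensor(float("inf"))             # the substitution applied in the helper
print(torch.exp(m - lse))                    # 0.0, so the masked row contributes nothing
```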
+ lse[lse == float("-inf")] = float("inf") + scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) + cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) + attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) + attn_norm = torch.cat( + [ + a * (torch.exp(m - lse)).unsqueeze(-1) + for a, m in zip(attn_unnorm_block, cummax_block) + ], + dim=-1, + ) + if query_padding_mask is not None: + attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0) + # attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return attn_norm.to(dtype=attn_unnorm.dtype) + + def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): + # row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = ( + seqlen_k + if key_padding_mask is None + else key_padding_mask.sum(-1).view(-1, 1, 1, 1) + # else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else query_padding_mask.sum(-1).view(-1, 1, 1, 1) + # else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + def convert_flash_attn_S_to_softmax( + self, + S, + seqlen_q, + seqlen_k, + query_padding_mask, + key_padding_mask, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + ): """FlashAttention stores the S matrix in a different way. 
Arguments: S: (batch_size, nheads, seqlen_q, seqlen_k) @@ -1972,53 +2087,45 @@ def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mas """ if TEST_WITH_ROCM: return S - - b, h, seqlen_q, seqlen_k = S.shape - warps_n = 4 - blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal) - nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m - nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n - mmas_n = (blocksize_n + 16 - 1) // 16 - - # Reshape S using PyTorch native functions - S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n) - S_flat = S_flat.permute(0, 1, 2, 4, 3, 5) - S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n)) - S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2) - S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11) - S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) * - warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2)) + b = S.shape[0] if causal: - causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1) - S_converted.masked_fill_(causal_mask, 0.0) + window_size = (window_size[0], 0) + seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] + S_converted = S + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = self.construct_local_mask( + seqlen_q, + seqlen_k, + window_size, + query_padding_mask, + key_padding_mask, + S.device, + ) + local_mask = F.pad( + local_mask, + (0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), + value=True, + ) + S_converted = S_converted.masked_fill(local_mask, 0.0) + # Need to zero out things not in attention_mask in case S was initialized with random values # and some of those values aren't overwritten. 
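
Both the kernel-side mask (the `col_idx_limit_left` / `col_idx_limit_right` bounds in mask.h earlier in this patch) and the test helper `construct_local_mask` use the same index arithmetic for sliding-window attention, with causal masking as the special case of an infinite left window and a right window of 0. A host-side sketch of that arithmetic; the helper below is illustrative, not the test code:

```python
# Sliding-window (local) attention mask; causal attention is the special case
# window = (infinite left, 0 right). A sketch of the index arithmetic only.
import torch

def local_mask(seqlen_q, seqlen_k, window_left, window_right):
    row = torch.arange(seqlen_q).view(-1, 1)
    col = torch.arange(seqlen_k)
    # Right-align queries against keys when seqlen_q != seqlen_k,
    # matching the row_idx + seqlen_k - seqlen_q offset in the kernel.
    shift = seqlen_k - seqlen_q
    upper = col > row + shift + window_right        # too far to the right
    if window_left < 0:                              # -1 means an infinite left window
        return upper
    lower = col < row + shift - window_left          # too far to the left
    return upper | lower

causal = local_mask(4, 6, -1, 0)
reference = torch.ones(4, 6).triu(diagonal=3).bool()  # col > row + (6 - 4)
assert torch.equal(causal, reference)
print(causal.int())
```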
- seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q + seqlen_q_og = ( + query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded + ) if query_padding_mask is not None: - if seqlen_q_og < seqlen_q: - query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og)) - else: - query_padding_mask = query_padding_mask[:, :seqlen_q] - q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1) - S_converted = S_converted.masked_fill(q_mask_fill, 0.0) + query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) + # S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0) seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k if key_padding_mask is not None: - if seqlen_k_og < seqlen_k: - key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og)) - else: - key_padding_mask = key_padding_mask[:, :seqlen_k] - k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1]) - S_converted = S_converted.masked_fill(k_mask_fill, 0.0) - if seqlen_q_og < seqlen_q: - S_converted = S_converted[:, :, :seqlen_q_og, :] - else: - S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q)) - if seqlen_k_og < seqlen_k: - S_converted = S_converted[:, :, :, :seqlen_k_og] - else: - S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k)) - return S_converted + key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) + S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0) + # S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) + S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) + return S_converted[:, :, :seqlen_q, :seqlen_k] @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) @@ -2370,28 +2477,29 @@ def test_sdp_choice_with_determinism(self, device, warn_only): query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape) with use_deterministic_algorithims(True, warn_only=warn_only): - # Note that this should swith to a testing version with we remove old context manager with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]): assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value - @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA") + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") + @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("warn_only", [True, False]) - def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only): + def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel): batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64 shape = SdpaShape(batch_size, num_heads, seq_len, head_dim) - make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True) + make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True) query, key, value = 
make_tensor(shape), make_tensor(shape), make_tensor(shape) + kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention" warning_context = ( self.assertWarnsRegex( UserWarning, - "Memory Efficient attention defaults to a non-deterministic algorithm.", + f"{kernel_name} defaults to a non-deterministic algorithm.", ) if warn_only else contextlib.nullcontext() ) with use_deterministic_algorithims(True, warn_only=warn_only): - with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]): + with sdpa_kernel(backends=[fused_kernel]): with warning_context: torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward() @@ -2710,8 +2818,6 @@ def is_power_of_2(n): is_dropout = dropout_p > 0.0 if not is_dropout: - # Problem: We pad sizes in the composite region of the top level SDPA. But we need the - # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]): out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale) with sdpa_kernel(backends=[SDPBackend.MATH]): @@ -2722,6 +2828,8 @@ def is_power_of_2(n): out_lp_ref = F.scaled_dot_product_attention( query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale) else: + # Problem: We pad sizes in the composite region of the top level SDPA. But we need the + # Debug mask when have dropout. So I am going to manualy pad up here when testing dropout q_padded, q_og_size = pad_last_dim(query, 8) k_padded, k_og_size = pad_last_dim(key, 8) v_padded, v_og_size = pad_last_dim(value, 8) @@ -2740,9 +2848,14 @@ def is_power_of_2(n): batch_size, seq_len_k, device=device, dtype=torch.bool) softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, + dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, causal=is_causal)[:, :, :seq_len_q, :seq_len_k] dropout_mask = softmax_mask >= 0 + # attn_unnorm = softmax_mask.abs() + # attn = self.normalize_flash_attn_S(attn_unnorm, q_padded, + # k_padded, v_padded, query_padding_mask, + # key_padding_mask, None, True, is_causal, scale=scale) + # High Precision Math Reference out_ref = torch.ops.aten._scaled_dot_product_attention_math( query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0] @@ -2823,7 +2936,8 @@ def get_dropout_mask(output, fused_kernel, batch_size, n_heads, q_len, kv_len, d batch_size, seq_len_k, device=device, dtype=torch.bool) softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal) + dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, + causal=is_causal)[:, :, :seq_len_q, :seq_len_k] dropout_mask = softmax_mask >= 0 return dropout_mask @@ -3178,7 +3292,7 @@ def rand_nt(sequence_list, num_heads, head_dim): key_padding_mask = key_padding_mask.to("cuda") softmax_mask = self.convert_flash_attn_S_to_softmax( - dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal) + dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal) dropout_mask = softmax_mask >= 0 nt_stack = [] for tensor_component in range(batch_size):
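
Several of the new arguments threaded through this patch (`alibi_slopes`, `Has_alibi`, `ALIBI_SWITCH`) plumb support for ALiBi, which biases attention scores by a per-head slope times the query/key distance instead of adding positional embeddings. The reference below follows the commonly used ALiBi formulation and is background material, not code from this PR.

```python
# Reference ALiBi bias: per-head geometric slopes and a bias proportional to
# the key/query distance, added to the attention scores before softmax.
# Background sketch under the usual ALiBi conventions; not PR code.
import torch

def alibi_slopes(num_heads):
    # Standard geometric schedule, valid for power-of-two head counts.
    start = 2 ** (-8.0 / num_heads)
    return torch.tensor([start ** (i + 1) for i in range(num_heads)])

def alibi_bias(num_heads, seqlen_q, seqlen_k):
    slopes = alibi_slopes(num_heads)                               # (H,)
    q_pos = torch.arange(seqlen_q).view(-1, 1) + (seqlen_k - seqlen_q)
    k_pos = torch.arange(seqlen_k).view(1, -1)
    distance = (q_pos - k_pos).abs()                               # (Sq, Sk)
    return -slopes.view(-1, 1, 1) * distance                       # (H, Sq, Sk)

scores = torch.randn(8, 4, 6)                                      # (H, Sq, Sk)
biased = scores + alibi_bias(8, 4, 6)
print(biased.shape)
```

Note that the kernel's causal path in mask.h applies the bias as `alibi_slope * col_idx` rather than the symmetric absolute-distance form used on the non-causal path; within the causal region the two differ only by a per-row constant, which the softmax cancels.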