From b8c4928dd793eb98246e8ad06ff7d8892cde79d2 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Fri, 18 Apr 2025 05:43:27 -0700 Subject: [PATCH 01/20] Add seq_len_kv_cache --- .../collective/xe_flash_attn_epilogue.hpp | 6 +++--- .../collective/xe_flash_attn_mma.hpp | 4 ++-- .../kernel/xe_flash_attn_gemm.hpp | 10 ++++----- .../pvc_flash_attn_runner.hpp | 21 +++++++++++-------- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_epilogue.hpp index bae76f2c5a..72ec76e21d 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_epilogue.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_epilogue.hpp @@ -116,7 +116,7 @@ class CollectiveEpilogueAttention static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args, [[maybe_unused]] void *workspace) { - auto [batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape; + auto [batch, num_heads, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; auto tensorO = make_tensor(make_gmem_ptr(static_cast(args.ptr_O)), make_layout(make_shape(seq_len_qo, head_size_vo, batch * num_heads), @@ -179,7 +179,7 @@ class CollectiveEpilogueAttention(problem_shape).cumulative_length; int offset_o = num_heads * head_size_vo * qo_cumulative_length[l_coord]; diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index 384cc0e762..0fe5b2e469 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -181,7 +181,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, void *workspace) { (void)workspace; - auto [batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape; + auto [batch, num_heads, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads), args.dQ)); auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads), args.dK)); @@ -308,7 +308,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, if constexpr (!is_var_len) { return params; } else { - auto [batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape; + auto [batch, num_heads, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; auto qo_cumulative_length = get<2>(problem_shape).cumulative_length; auto kv_cumulative_length = get<3>(problem_shape).cumulative_length; diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 6e05bb15c4..13f0d16aa9 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -53,7 +53,7 @@ class GemmUniversalAttention { // using ProblemShape = ProblemShape_; - static_assert(rank(ProblemShape{}) == 6, "ProblemShape{} should be "); + static_assert(rank(ProblemShape{}) == 7, "ProblemShape{} should be "); // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; @@ -190,7 +190,7 @@ class GemmUniversalAttention { static dim3 get_block_shape() { 
return dim3(MaxThreadsPerBlock, 1, 1); } CUTLASS_DEVICE - Shape get_logical_problem_shape(ProblemShape const& problem_shape, int const& batch) { + Shape get_logical_problem_shape(ProblemShape const& problem_shape, int const& batch) { if constexpr (is_var_len) { return cutlass::fmha::collective::apply_variable_length(problem_shape, batch); } else { @@ -207,8 +207,8 @@ class GemmUniversalAttention { // Separate out problem shape for convenience auto batch = get<0>(params.problem_shape); auto num_heads = get<1>(params.problem_shape); - auto head_size_qk = get<4>(params.problem_shape); - auto head_size_vo = get<5>(params.problem_shape); + auto head_size_qk = get<5>(params.problem_shape); + auto head_size_vo = get<6>(params.problem_shape); // Preconditions static_assert(cute::rank(StrideQ{}) == 3, "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * num_heads]."); static_assert(cute::rank(StrideK{}) == 3, "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * num_heads]."); @@ -241,7 +241,7 @@ class GemmUniversalAttention { // logical_problem_shape = [batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] auto logical_problem_shape = get_logical_problem_shape(params.problem_shape, batch_coord); - auto [seq_len_qo, seq_len_kv] = select<2, 3>(logical_problem_shape); + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = select<2, 3, 4>(logical_problem_shape); // Calculate the seq_len_idx (blk_m_coord * get<0>(WorkgroupTileShape{})) and check if it is still // within bounds of the actual seq_len_qo (get<2>(logical_problem_shape)). diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 40d7adb016..98dd529779 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -62,12 +62,12 @@ struct Options { bool varlen = false; std::string scheduler; - int batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo, iterations; + int batch, num_heads, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo, iterations; float softmax_scale; Options() : help(false), error(false), is_causal(false), varlen(false), batch(32), num_heads(16), seq_len_qo(512), head_size_qk(128), - seq_len_kv(512), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} + seq_len_kv(512), seq_len_kv_cache(0), head_size_vo(128), iterations(100), softmax_scale(1.f), scheduler("Individual") {} // Parses the command line void parse(int argc, char const **args) { @@ -92,6 +92,7 @@ struct Options { cmd.get_cmd_line_argument("num_heads", num_heads, 16); cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 512); cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, seq_len_qo); + cmd.get_cmd_line_argument("seq_len_kv_cache", seq_len_kv_cache, 0); cmd.get_cmd_line_argument("head_size_vo", head_size_vo, 128); cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); cmd.get_cmd_line_argument("iterations", iterations, 100); @@ -112,6 +113,7 @@ struct Options { << " --num_heads= Sets the Number of Attention Heads of the Multi-Head Self Attention module\n" << " --seq_len_qo= Sets the Sequence length of the Query input in Multi-Head Self Attention module\n" << " --seq_len_kv= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" + << " --seq_len_kv_cache= Sets the Sequence length of the Key-Value pair in Multi-Head Self Attention module\n" << " 
--head_size_qk= Sets the Attention Head dimension of the 1st Matrix Multiplication in Multi-Head Self Attention module\n" << " --head_size_vo= Sets the Attention Head dimension of the 2nd Matrix Multiplication in Multi-Head Self Attention module\n" << " --iterations= Iterations\n\n"; @@ -181,7 +183,7 @@ template struct ExampleRunner { get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; } - auto [batch, num_heads, head_size_qk, head_size_vo] = cute::select<0,1,4,5>(problem_size); + auto [batch, num_heads, head_size_qk, head_size_vo] = cute::select<0,1,5,6>(problem_size); int seq_len_qo, seq_len_kv; int offset_q = 0; @@ -369,8 +371,8 @@ template struct ExampleRunner { get<2>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q}; get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv}; - get<4>(problem_size_for_launch) = get<4>(problem_size); get<5>(problem_size_for_launch) = get<5>(problem_size); + get<6>(problem_size_for_launch) = get<6>(problem_size); get<0>(problem_size_for_launch) = get<0>(problem_size); get<1>(problem_size_for_launch) = get<1>(problem_size); @@ -380,7 +382,7 @@ template struct ExampleRunner { /// Initialize operands to be used in the GEMM and reference GEMM ProblemShapeType initialize(const Options &options) { auto problem_shape_in = - cute::make_tuple(options.batch, options.num_heads, options.seq_len_qo, options.seq_len_kv, options.head_size_qk, options.head_size_vo); + cute::make_tuple(options.batch, options.num_heads, options.seq_len_qo, options.seq_len_kv, options.seq_len_kv_cache, options.head_size_qk, options.head_size_vo); ProblemShapeType problem_shape; decltype(problem_shape_in) problem_size; @@ -395,7 +397,7 @@ template struct ExampleRunner { problem_shape = problem_shape_in; } - auto [batch, num_heads, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_size; + auto [batch, num_heads, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_size; stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, batch * num_heads)); stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, batch * num_heads)); @@ -519,7 +521,8 @@ template struct ExampleRunner { double gbps_pv = 2.0 * options.batch * options.num_heads * (options.seq_len_kv * options.seq_len_qo + options.seq_len_qo * options.head_size_vo); double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); std::cout << "Batch: " << options.batch << "\tNumHeads: " << options.num_heads << "\tSeq Length QO: " << options.seq_len_qo - << "\tSeq Length KV: " << options.seq_len_kv << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo + << "\tSeq Length KV: " << options.seq_len_kv << "\tSeq Length KV Cache: " << options.seq_len_kv_cache + << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? 
"true" : "false") << "\t Scheduler: " << options.scheduler; printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); @@ -562,9 +565,9 @@ template struct FMHAConfig GmemTiledCopyStore>; using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::CollectiveSoftmaxEpilogue; - using ProblemShapeRegular = cute::tuple; + using ProblemShapeRegular = cute::tuple; using namespace cutlass::fmha::collective; - using ProblemShapeVarlen = cute::tuple; + using ProblemShapeVarlen = cute::tuple; using ProblemShapeType = std::conditional_t; // Mainloop From fcde8431e1eb858875c4175beb0f45a076d42413 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Fri, 18 Apr 2025 07:43:39 -0700 Subject: [PATCH 02/20] Add KV cache and update verify kernel --- .../collective/xe_flash_attn_mma.hpp | 35 ++++++- .../pvc_flash_attn_runner.hpp | 91 ++++++++++++++++--- 2 files changed, 110 insertions(+), 16 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index 0fe5b2e469..d87783f85c 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -163,12 +163,18 @@ struct CollectiveMmaAttention, ProblemShapeType_, StrideK dK; ElementV const *ptr_V; StrideV dV; + ElementK const* ptr_K_cache; + StrideK dK_cache; + ElementV const* ptr_V_cache; + StrideV dV_cache; }; struct Params { XE_Copy_Q gmem_tiled_copy_q; XE_Copy_K gmem_tiled_copy_k; XE_Copy_V gmem_tiled_copy_v; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; }; // @@ -186,11 +192,16 @@ struct CollectiveMmaAttention, ProblemShapeType_, auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads), args.dQ)); auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads), args.dK)); auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads), args.dV)); + auto tensorK_cache = make_tensor(make_gmem_ptr(args.ptr_K_cache), make_layout(make_shape(seq_len_kv_cache, head_size_qk, batch * num_heads), args.dK_cache)); + auto tensorV_cache = make_tensor(make_gmem_ptr(args.ptr_V_cache), make_layout(make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads), args.dV_cache)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; - - return Params{copyQ, copyK, copyV}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; } template @@ -316,6 +327,8 @@ struct CollectiveMmaAttention, ProblemShapeType_, int offset_q = num_heads * head_size_qk * qo_cumulative_length[l_coord]; int offset_k = num_heads * head_size_qk * kv_cumulative_length[l_coord]; int offset_v = num_heads * head_size_vo * kv_cumulative_length[l_coord]; + int offset_k_cache = num_heads * head_size_qk * seq_len_kv_cache; + int offset_v_cache = num_heads * head_size_vo * seq_len_kv_cache; auto q_traits = static_cast(params.gmem_tiled_copy_q); ElementQ* q_ptr = (ElementQ*)q_traits.base_ptr; @@ -326,6 +339,12 @@ struct CollectiveMmaAttention, ProblemShapeType_, auto v_traits = static_cast(params.gmem_tiled_copy_v); ElementV* v_ptr = (ElementV*)v_traits.base_ptr; + auto 
k_traits_cache = static_cast(params.gmem_tiled_copy_k_cache); + ElementK* k_cache_ptr = (ElementK*)k_traits_cache.base_ptr; + + auto v_traits_cache = static_cast(params.gmem_tiled_copy_v_cache); + ElementV* v_cache_ptr = (ElementV*)v_traits_cache.base_ptr; + auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk, num_heads); StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); @@ -335,15 +354,25 @@ struct CollectiveMmaAttention, ProblemShapeType_, auto shape_v = make_shape(head_size_vo, static_cast(seq_len_kv), num_heads); StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); + auto shape_k_cache = make_shape(static_cast(seq_len_kv_cache), head_size_qk, num_heads); + StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + + auto shape_v_cache = make_shape(head_size_vo, static_cast(seq_len_kv_cache), num_heads); + StrideV stride_v_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k)); auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v)); + auto tensorK_cache = make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, shape_k_cache)); + auto tensorV_cache = make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; } } }; diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 98dd529779..8e813cfda2 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -156,12 +156,16 @@ template struct ExampleRunner { StrideQ stride_Q; StrideK stride_K; StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; StrideO stride_O; uint64_t seed = 0; cutlass::DeviceAllocation block_Q; cutlass::DeviceAllocation block_K; cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_K_cache; + cutlass::DeviceAllocation block_V_cache; cutlass::DeviceAllocation block_O; cutlass::DeviceAllocation block_ref_O; @@ -174,7 +178,7 @@ template struct ExampleRunner { // Methods // - bool verify(ProblemShapeType problem_size, bool is_causal) { + bool verify(ProblemShapeType problem_size, bool is_causal, bool use_kv_cache) { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<2>(problem_size)); @@ -184,11 +188,13 @@ template struct ExampleRunner { } auto [batch, num_heads, head_size_qk, head_size_vo] = cute::select<0,1,5,6>(problem_size); - int seq_len_qo, seq_len_kv; + int seq_len_qo, seq_len_kv, seq_len_kv_cache; int offset_q = 0; int offset_k = 0; int offset_v = 0; + int offset_k_cache = 0; + int offset_v_cache = 0; int offset_o = 0; // loop over the batch dimension to compute the output // to avoid the risk of running out of device memory @@ -197,19 +203,62 @@ template struct ExampleRunner { auto logical_problem_shape = cutlass::fmha::collective::apply_variable_length(problem_size, b); 
seq_len_qo = get<2>(logical_problem_shape); seq_len_kv = get<3>(logical_problem_shape); + seq_len_kv_cache = get<4>(logical_problem_shape); } else { seq_len_qo = get<2>(problem_size); seq_len_kv = get<3>(problem_size); + seq_len_kv_cache = get<4>(problem_size); } + int seq_len_kv_total = use_kv_cache ? (seq_len_kv_cache + seq_len_kv) : seq_len_kv; for (int h = 0; h < num_heads; h++) { cutlass::DeviceAllocation block_S; block_S.reset(seq_len_qo * seq_len_kv); + ElementK* k_ptr; + ElementV* v_ptr; + + if (use_kv_cache) { + cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total); + cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo); + + // Concatenate K_cache and K_new + syclcompat::memcpy( + block_K_concat.get(), + block_K_cache.get() + offset_k_cache, + seq_len_kv_cache * head_size_qk + ); + syclcompat::memcpy( + block_K_concat.get() + seq_len_kv_cache * head_size_qk, + block_K.get() + offset_k, + seq_len_kv * head_size_qk + ); + + // Concatenate V_cache and V_new + syclcompat::memcpy( + block_V_concat.get(), + block_V_cache.get() + offset_v_cache, + seq_len_kv_cache * head_size_vo + ); + syclcompat::memcpy( + block_V_concat.get() + seq_len_kv_cache * head_size_vo, + block_V.get() + offset_v, + seq_len_kv * head_size_vo + ); + syclcompat::wait(); + + k_ptr = block_K_concat.get() + offset_k_cache; + v_ptr = block_V_concat.get() + offset_v_cache; + } + else { + k_ptr = block_K.get() + offset_k; + v_ptr = block_V.get() + offset_v; + } + cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); - cutlass::TensorRef ref_K(block_K.get() + offset_k, LayoutK::packed({head_size_qk, seq_len_kv})); - cutlass::TensorRef ref_V(block_V.get() + offset_v, LayoutV::packed({seq_len_kv, head_size_vo})); - cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total })); + cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({ seq_len_kv_total, head_size_vo})); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo})); cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, @@ -217,9 +266,9 @@ template struct ExampleRunner { 0.f, ref_S, ref_S, ElementAccumulator(0), 1, // batch_count seq_len_qo * head_size_qk, // batch_stride_Q - seq_len_kv * head_size_qk, // batch_stride_K - seq_len_qo * seq_len_kv, // batch_stride_S - seq_len_qo * seq_len_kv // batch_stride_S + seq_len_kv_total * head_size_qk, // batch_stride_K + seq_len_qo * seq_len_kv_total, // batch_stride_S + seq_len_qo * seq_len_kv_total // batch_stride_S ); syclcompat::wait(); @@ -295,8 +344,8 @@ template struct ExampleRunner { cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, 0.f, ref_O, ref_O, ElementAccumulator(0), 1, // batch_count - seq_len_qo * seq_len_kv, // batch_stride_P - seq_len_kv * head_size_vo, // batch_stride_V + seq_len_qo * seq_len_kv_total, // batch_stride_P + seq_len_kv_total * head_size_vo, // batch_stride_V seq_len_qo * head_size_vo, // batch_stride_O seq_len_qo * head_size_vo // batch_stride_O ); @@ -308,6 +357,8 @@ template struct ExampleRunner { offset_q += seq_len_qo * head_size_qk; offset_k += seq_len_kv * head_size_qk; offset_v += seq_len_kv * head_size_vo; + offset_k_cache += seq_len_kv_cache * head_size_qk; + offset_v_cache += seq_len_kv_cache * 
head_size_vo; offset_o += seq_len_qo * head_size_vo; } } @@ -324,6 +375,7 @@ template struct ExampleRunner { template auto initialize_varlen(const ProblemShape& problem_size, const bool VarlenSame = true) { int num_batches = get<0>(problem_size); + int seq_len_kv_cache = get<4>(problem_size); // generate Q as --b times // gaussian (--Q, --Q / 2) sampled positive @@ -366,11 +418,13 @@ template struct ExampleRunner { get<0>(problem_size_for_init) = 1; get<2>(problem_size_for_init) = total_seqlen_q; get<3>(problem_size_for_init) = total_seqlen_kv; + get<4>(problem_size_for_init) = seq_len_kv_cache; ProblemShapeType problem_size_for_launch; get<2>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q}; get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv}; + get<4>(problem_size_for_launch) = get<4>(problem_size); get<5>(problem_size_for_launch) = get<5>(problem_size); get<6>(problem_size_for_launch) = get<6>(problem_size); get<0>(problem_size_for_launch) = get<0>(problem_size); @@ -402,17 +456,23 @@ template struct ExampleRunner { stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, batch * num_heads)); stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, batch * num_heads)); stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, batch * num_heads)); + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, head_size_qk, batch * num_heads)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads)); stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, batch * num_heads)); block_Q.reset(batch * num_heads * seq_len_qo * head_size_qk); block_K.reset(batch * num_heads * seq_len_kv * head_size_qk); block_V.reset(batch * num_heads * seq_len_kv * head_size_vo); + block_K_cache.reset(batch * num_heads * seq_len_kv_cache * head_size_qk); + block_V_cache.reset(batch * num_heads * seq_len_kv_cache * head_size_vo); block_O.reset(batch * num_heads * seq_len_qo * head_size_vo); block_ref_O.reset(batch * num_heads * seq_len_qo * head_size_vo); initialize_block(block_Q, seed + 2023); initialize_block(block_K, seed + 2022); initialize_block(block_V, seed + 2021); + initialize_block(block_K_cache, seed + 2024); + initialize_block(block_V_cache, seed + 2025); if (!cumulative_seqlen_q.empty()) { device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size()); @@ -470,7 +530,11 @@ template struct ExampleRunner { typename GemmKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, problem_size, - {block_Q.get(), stride_Q, block_K.get(), stride_K, block_V.get(), stride_V}, + {block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_K_cache.get(), stride_K_cache, + block_V_cache.get(), stride_V_cache}, {options.softmax_scale}, {block_O.get(), stride_O}, hw_info}; @@ -498,7 +562,8 @@ template struct ExampleRunner { syclcompat::wait(); // Verify that the result is correct - bool passed = verify(problem_size, options.is_causal); + bool use_kv_cache = options.seq_len_kv_cache > 0; + bool passed = verify(problem_size, options.is_causal, use_kv_cache); std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; if (!passed) { @@ -567,7 +632,7 @@ template struct FMHAConfig using ProblemShapeRegular = cute::tuple; using namespace cutlass::fmha::collective; - using ProblemShapeVarlen = cute::tuple; + using ProblemShapeVarlen = cute::tuple; using ProblemShapeType = std::conditional_t; // Mainloop From 359dfa554241c782a5cc4ff16c92a30725a858e3 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Fri, 18 Apr 2025 09:18:37 -0700 Subject: [PATCH 03/20] Update mmaQK, mmaPV to handle KV cache, new --- .../collective/xe_flash_attn_mma.hpp | 20 +++++---- .../kernel/xe_flash_attn_gemm.hpp | 41 ++++++++++++++----- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index d87783f85c..05ee080178 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -206,11 +206,13 @@ struct CollectiveMmaAttention, ProblemShapeType_, template CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, FragSrc const &frag_src, - int const &k_tile_count, Params const ¶ms) { + int const &k_tile_count, Params const ¶ms, bool is_KV_cache) { + + auto& gmem_tiled_copy_k = is_KV_cache ? params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k; int thread_idx = static_cast(ThreadIdxX()); auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); - auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx); + auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx); // Instantiate the MMA object TiledMmaK tiled_mma_k; TiledMmaQVO tiled_mma_q; @@ -227,7 +229,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, // Create fragments // TODO(Codeplay): fix this, this is probably not general Tensor tCrQ = make_tensor(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape()))); - Tensor tCrK = make_tensor(make_fragment_layout(params.gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); + Tensor tCrK = make_tensor(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape()))); // Retile registers for copies Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); @@ -265,14 +267,16 @@ struct CollectiveMmaAttention, ProblemShapeType_, for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { copy(params.gmem_tiled_copy_q, tQgQ(_,_,_,k_tile), tQrQ); - copy(params.gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK); + copy(gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK); cute::gemm(tiled_mma_q, accum, tCrQ, tCrK, frag_src); } } template CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV, - FragSrc const &frag_src, Params const ¶ms) { + FragSrc const &frag_src, Params const ¶ms, bool is_KV_cache) { + + auto& gmem_tiled_copy_v = is_KV_cache ? 
params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v; int thread_idx = static_cast(ThreadIdxX()); // Instantiate the MMA object @@ -281,10 +285,10 @@ struct CollectiveMmaAttention, ProblemShapeType_, auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); Tensor tCgV = thread_mma.partition_B(gV); - Tensor tCrV = make_tensor(make_fragment_layout(params.gmem_tiled_copy_v, tCgV.shape())); + Tensor tCrV = make_tensor(make_fragment_layout(gmem_tiled_copy_v, tCgV.shape())); // Partition the copying of A and B tiles across the threads - auto gmem_thr_copy_V = params.gmem_tiled_copy_v.get_slice(thread_idx); + auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); @@ -310,7 +314,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, // // Mainloop // - copy(params.gmem_tiled_copy_v, tVgV, tVrV); + copy(gmem_tiled_copy_v, tVgV, tVrV); cute::gemm(tiled_mma, accum, tPr, tCrV, frag_src); } diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 13f0d16aa9..001ed9efd6 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -252,13 +252,20 @@ class GemmUniversalAttention { Tensor mQ_mkl = cute::get_pvc_tensor(make_shape(seq_len_qo, head_size_qk, (is_var_len ? 1 : batch) * num_heads)); //(m,k,l) Tensor mK_nkl = cute::get_pvc_tensor(make_shape(seq_len_kv, head_size_qk, (is_var_len ? 1 : batch) * num_heads)); //(n,k,l) Tensor mV_nkl = cute::get_pvc_tensor(make_shape(head_size_vo, seq_len_kv, (is_var_len ? 1 : batch) * num_heads)); //(n,k,l) + Tensor mK_cache_nkl = cute::get_pvc_tensor(make_shape(seq_len_kv_cache, head_size_qk, (is_var_len ? 1 : batch) * num_heads)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_pvc_tensor(make_shape(head_size_vo, seq_len_kv_cache, (is_var_len ? 1 : batch) * num_heads)); // (n_cache,k,l) + Tensor mQ_mk = mQ_mkl(_, _, blk_l_coord); // (m,k) Tensor mK_nk = mK_nkl(_, _, blk_l_coord); // (n,k) Tensor mV_nk = mV_nkl(_, _, blk_l_coord); // (n,k) + Tensor mK_cache_nk = mK_cache_nkl(_, _, blk_l_coord); // (n_cache, k) + Tensor mV_cache_nk = mV_cache_nkl(_, _, blk_l_coord); // (n_cache, k) auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{}); auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _ , _), Step{}); auto gV = local_tile(mV_nk, TileShapePV{}, make_coord(_, blk_n_coord, _), Step{}); + auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step{}); + auto gV_cache = local_tile(mV_cache_nk, TileShapePV{}, make_coord(_, blk_n_coord, _), Step{}); const int seq_coord = cute::min(seq_len_qo, blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M); const int l_coord = blk_l_coord; @@ -266,8 +273,10 @@ class GemmUniversalAttention { const int causal_seq_len = cute::min(seq_len_kv, seq_coord) + QK_SG_M; const int non_causal_seq_len = seq_len_kv; - const int nblock_limit = CausalMask ? cute::ceil_div(causal_seq_len, QK_BLK_N) + const int nblock_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N); + const int nblock_new = CausalMask ? 
cute::ceil_div(causal_seq_len, QK_BLK_N) : cute::ceil_div(non_causal_seq_len, QK_BLK_N); + int nblock_limit = nblock_cache + nblock_new; auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, batch_coord); @@ -313,22 +322,31 @@ class GemmUniversalAttention { // MAIN LOOP: loop over K and V, perform fused attention + online softmax for (int nblock = 0; nblock < nblock_limit - static_cast(CausalMask); nblock++) { barrier_arrive(barrier_scope); - // 1) Load K (performed inside mmaQK) + + bool is_KV_cache = (nblock < nblock_cache); + int local_block = is_KV_cache ? nblock : (nblock - nblock_cache); + + // 1) Load KV (performed inside mmaQK) + auto gK_ = is_KV_cache ? gK_cache(_, _, local_block, _) : gK(_, _, local_block, _); + auto gV_ = is_KV_cache ? gV_cache(_, _, local_block) : gV(_, _, local_block); + // 2) Create Tensor S Tensor tSr = make_tensor(Shape, Int, Int>{}); clear(tSr); // 3) Perform GEMM S = Q*K - collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + collective_mma.mmaQK(tSr, gQ, gK_, tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, is_KV_cache); // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock)); + // 4) Fused softmax CollectiveSoftmaxEpilogue softmax(params.softmax); softmax(nblock == 0, tSr, max_reg, sum_reg, out_reg); - collective_mma.mmaPV(out_reg, tSr, gV(_, _ , nblock), out_reg, mainloop_params); + // 5) Perform GEMM O = S*V + collective_mma.mmaPV(out_reg, tSr, gV_, out_reg, mainloop_params, is_KV_cache); // Prefetch the next K tile // there is no need to gaurd it with if statememt as prefetch will ignore out of bound reading @@ -346,13 +364,13 @@ class GemmUniversalAttention { Tensor tSr = make_tensor(Shape, Int, Int>{}); clear(tSr); // 3) Perform GEMM S = Q*K - collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_limit - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params); + collective_mma.mmaQK(tSr, gQ, gK(_, _, nblock_new - 1, _), tSr, ceil_div(head_size_qk, QK_BLK_K), mainloop_params, false); // we only need one block ahead, there is enough gap to prefetch it while doing softmax. 
because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense - prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock_limit - 1)); + prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock_new - 1)); // mask the elements of each tile where j > i const int item_id = thread_idx % SubgroupSize; - int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + int col_idx = item_id + (nblock_new - 1) * QK_BLK_N; CUTLASS_PRAGMA_UNROLL for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 CUTLASS_PRAGMA_UNROLL @@ -360,16 +378,17 @@ class GemmUniversalAttention { int row_idx = m * Vec + seq_coord; CUTLASS_PRAGMA_UNROLL for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if (col_idx > row_idx) - tSr(row, m, n) = -INFINITY; + if (col_idx > row_idx) { + tSr(row, m, n) = -INFINITY; + } } } } CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); + softmax((nblock_new - 1) == 0, tSr, max_reg, sum_reg, out_reg); - collective_mma.mmaPV(out_reg, tSr, gV(_, _ , nblock_limit - 1), out_reg, mainloop_params); + collective_mma.mmaPV(out_reg, tSr, gV(_, _ , nblock_new - 1), out_reg, mainloop_params, false); } auto epilogue_params = CollectiveEpilogue::template get_updated_copies(params.epilogue, params.problem_shape, batch_coord); From 180574ec7a0e8e9e0f4b7adce08c460dfd10ea0d Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Sat, 19 Apr 2025 04:05:08 -0700 Subject: [PATCH 04/20] Correct verify kernel with KV cache --- .../pvc_flash_attn_runner.hpp | 50 ++++++++++--------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 8e813cfda2..aabf555801 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -213,7 +213,7 @@ template struct ExampleRunner { for (int h = 0; h < num_heads; h++) { cutlass::DeviceAllocation block_S; - block_S.reset(seq_len_qo * seq_len_kv); + block_S.reset(seq_len_qo * seq_len_kv_total); ElementK* k_ptr; ElementV* v_ptr; @@ -222,7 +222,7 @@ template struct ExampleRunner { cutlass::DeviceAllocation block_K_concat(head_size_qk * seq_len_kv_total); cutlass::DeviceAllocation block_V_concat(seq_len_kv_total * head_size_vo); - // Concatenate K_cache and K_new + // Concatenate K_cache and K syclcompat::memcpy( block_K_concat.get(), block_K_cache.get() + offset_k_cache, @@ -234,7 +234,7 @@ template struct ExampleRunner { seq_len_kv * head_size_qk ); - // Concatenate V_cache and V_new + // Concatenate V_cache and V syclcompat::memcpy( block_V_concat.get(), block_V_cache.get() + offset_v_cache, @@ -247,8 +247,8 @@ template struct ExampleRunner { ); syclcompat::wait(); - k_ptr = block_K_concat.get() + offset_k_cache; - v_ptr = block_V_concat.get() + offset_v_cache; + k_ptr = block_K_concat.get(); + v_ptr = block_V_concat.get(); } else { k_ptr = block_K.get() + offset_k; @@ -256,12 +256,12 @@ template struct ExampleRunner { } cutlass::TensorRef ref_Q(block_Q.get() + offset_q, LayoutQ::packed({seq_len_qo, head_size_qk})); - cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total })); - cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({ seq_len_kv_total, head_size_vo})); + cutlass::TensorRef ref_K(k_ptr, LayoutK::packed({head_size_qk, seq_len_kv_total})); + cutlass::TensorRef ref_V(v_ptr, LayoutV::packed({seq_len_kv_total, 
head_size_vo})); cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); cutlass::TensorRef ref_O(block_ref_O.get() + offset_o, LayoutO::packed({seq_len_qo, head_size_vo})); - cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv, head_size_qk}, 1.f, ref_Q, + cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, 1.f, ref_Q, cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, 0.f, ref_S, ref_S, ElementAccumulator(0), 1, // batch_count @@ -280,23 +280,25 @@ template struct ExampleRunner { // delete this memory as it is no longer needed block_S.reset(); + // apply causal mask to S if (is_causal) { - // apply mask to S - for (int row = 0; row < seq_len_qo; row++) { - for (int col = 0; col < seq_len_kv; col++) { - if (col > row) - host_S[col + row * seq_len_kv] = -INFINITY; + for (int row = 0; row < seq_len_qo; row++) { + int start_col = use_kv_cache ? seq_len_kv_cache : 0; + for (int col = start_col; col < seq_len_kv_total; col++) { + if (col > row + start_col) { + host_S[col + row * seq_len_kv_total] = -INFINITY; + } + } } - } } // compute max element per row of S std::vector max_vec(seq_len_qo, -INFINITY); for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int max_idx = row; max_vec[max_idx] = host_S[idx++]; - for (int col = 1; col < seq_len_kv; col++, idx++) { + for (int col = 1; col < seq_len_kv_total; col++, idx++) { if (max_vec[max_idx] < host_S[idx]) max_vec[max_idx] = host_S[idx]; } @@ -304,9 +306,9 @@ template struct ExampleRunner { // compute exp of S for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int max_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / sqrt(static_cast((head_size_qk)))); } } @@ -314,16 +316,16 @@ template struct ExampleRunner { // compute sum per row of S std::vector sum_vec(seq_len_qo, ElementOutput{0}); for (int row = 0; row < seq_len_qo; row++) { - int idx = row * seq_len_kv; + int idx = row * seq_len_kv_total; int sum_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { sum_vec[sum_idx] += host_S[idx]; } // scale each row with the sum to compute softmax - idx = row * seq_len_kv; + idx = row * seq_len_kv_total; sum_idx = row; - for (int col = 0; col < seq_len_kv; col++, idx++) { + for (int col = 0; col < seq_len_kv_total; col++, idx++) { host_S[idx] /= sum_vec[sum_idx]; } } @@ -338,9 +340,9 @@ template struct ExampleRunner { syclcompat::memcpy(block_P.get(), host_P.data(), host_P.size()); syclcompat::wait(); - cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv})); + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); - cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv}, 1.f, ref_P, + cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, 1.f, ref_P, cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, 0.f, ref_O, ref_O, ElementAccumulator(0), 1, // batch_count From cac41c2b0ea89bd93f71f8788eafa04e93afe149 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Sat, 19 Apr 2025 07:10:32 -0700 Subject: [PATCH 05/20] Fix causal mask when KV cache --- 
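Note: with a KV cache the cached keys/values always precede every query position, so cached columns are never masked; only columns of the newly appended KV block enter the causal comparison. Below is a minimal standalone C++ sketch of that rule, mirroring the host-side reference check added in PATCH 04 — the helper name and the row-major float buffer are illustrative only and not part of this patch, while seq_len_qo/seq_len_kv/seq_len_kv_cache follow the patch's naming.

#include <cmath>
#include <vector>

// Mask scores S of shape [seq_len_qo x (seq_len_kv_cache + seq_len_kv)], stored row-major.
// Cached columns [0, seq_len_kv_cache) stay visible; a new column j (relative to the start
// of the new KV block) is masked for query row i whenever j > i.
inline void apply_causal_mask_with_cache(std::vector<float> &S, int seq_len_qo,
                                         int seq_len_kv, int seq_len_kv_cache) {
  int seq_len_kv_total = seq_len_kv_cache + seq_len_kv;
  for (int row = 0; row < seq_len_qo; ++row) {
    for (int col = seq_len_kv_cache; col < seq_len_kv_total; ++col) {
      if (col - seq_len_kv_cache > row) {
        S[row * seq_len_kv_total + col] = -INFINITY;
      }
    }
  }
}

This is also why, in the kernel below, only the last block of the new KV sequence needs the masking path, while blocks drawn from the cache are processed without it.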
applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 001ed9efd6..e341f8d658 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -277,6 +277,7 @@ class GemmUniversalAttention { const int nblock_new = CausalMask ? cute::ceil_div(causal_seq_len, QK_BLK_N) : cute::ceil_div(non_causal_seq_len, QK_BLK_N); int nblock_limit = nblock_cache + nblock_new; + bool is_first_block = (seq_len_kv_cache == 0) && ((nblock_new - 1) == 0); auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, batch_coord); @@ -386,7 +387,7 @@ class GemmUniversalAttention { } CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax((nblock_new - 1) == 0, tSr, max_reg, sum_reg, out_reg); + softmax(is_first_block, tSr, max_reg, sum_reg, out_reg); collective_mma.mmaPV(out_reg, tSr, gV(_, _ , nblock_new - 1), out_reg, mainloop_params, false); } From 9403e6d520150e92808ba14062f497dae803b6e3 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Mon, 21 Apr 2025 13:59:47 -0700 Subject: [PATCH 06/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp Co-authored-by: Mehdi Goli --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index e341f8d658..15da976933 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -277,7 +277,6 @@ class GemmUniversalAttention { const int nblock_new = CausalMask ? 
cute::ceil_div(causal_seq_len, QK_BLK_N) : cute::ceil_div(non_causal_seq_len, QK_BLK_N); int nblock_limit = nblock_cache + nblock_new; - bool is_first_block = (seq_len_kv_cache == 0) && ((nblock_new - 1) == 0); auto mainloop_params = CollectiveMainloop::get_updated_copies(params.mainloop, params.problem_shape, batch_coord); From f413f9cf8be238175371ecad62208386e932dc30 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Mon, 21 Apr 2025 13:59:57 -0700 Subject: [PATCH 07/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp Co-authored-by: Mehdi Goli --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 15da976933..a17dbbe682 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -386,7 +386,7 @@ class GemmUniversalAttention { } CollectiveSoftmaxEpilogue softmax(params.softmax); - softmax(is_first_block, tSr, max_reg, sum_reg, out_reg); + softmax((nblock_limit - 1) == 0, tSr, max_reg, sum_reg, out_reg); collective_mma.mmaPV(out_reg, tSr, gV(_, _ , nblock_new - 1), out_reg, mainloop_params, false); } From 9c02255dcdc4d45894cd04e978135d2258b1fd3c Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Mon, 21 Apr 2025 14:06:43 -0700 Subject: [PATCH 08/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp Co-authored-by: Mehdi Goli --- .../flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index a17dbbe682..71aefa1a60 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -324,11 +324,10 @@ class GemmUniversalAttention { barrier_arrive(barrier_scope); bool is_KV_cache = (nblock < nblock_cache); - int local_block = is_KV_cache ? nblock : (nblock - nblock_cache); - + // 1) Load KV (performed inside mmaQK) - auto gK_ = is_KV_cache ? gK_cache(_, _, local_block, _) : gK(_, _, local_block, _); - auto gV_ = is_KV_cache ? gV_cache(_, _, local_block) : gV(_, _, local_block); + auto gK_ = is_KV_cache ? gK_cache(_, _, nblock, _) : gK(_, _, nblock - nblock_cache, _); + auto gV_ = is_KV_cache ? gV_cache(_, _, nblock) : gV(_, _, nblock - nblock_cache); // 2) Create Tensor S Tensor tSr = make_tensor(Shape, Int, Int>{}); From 748815c167d9128f5ac6c06edcbb087ad0e87c27 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Mon, 21 Apr 2025 15:47:53 -0700 Subject: [PATCH 09/20] Minor update --- .../sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 5d7fdbd3ed..1b1e00489e 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -213,7 +213,7 @@ template struct ExampleRunner { seq_len_kv = get<4>(problem_size); seq_len_kv_cache = get<5>(problem_size); } - int seq_len_kv_total = use_kv_cache ? 
(seq_len_kv_cache + seq_len_kv) : seq_len_kv; + int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; int kv_group_update=1; for (int h = 0; h < num_heads_q; h++) { @@ -391,7 +391,7 @@ template struct ExampleRunner { template auto initialize_varlen(const ProblemShape& problem_size, const bool VarlenSame = true) { int num_batches = get<0>(problem_size); - int seq_len_kv_cache = get<4>(problem_size); + int seq_len_kv_cache = get<5>(problem_size); // generate Q as --b times // gaussian (--Q, --Q / 2) sampled positive From 72edfc05d4b0c2b5a2ffc4afb7f3b6c6236c68b6 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Tue, 22 Apr 2025 16:04:32 +0800 Subject: [PATCH 10/20] Update flops, gbps calculation --- .../sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 1b1e00489e..15a8979cde 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -602,11 +602,11 @@ template struct ExampleRunner { syclcompat::wait(); double cute_time = timer.seconds() / options.iterations; - double flops_qk = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.seq_len_kv * options.head_size_qk; - double flops_pv = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo * options.seq_len_kv; + double flops_qk = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * (options.seq_len_kv + options.seq_len_kv_cache) * options.head_size_qk; + double flops_pv = 2.0 * options.batch * options.num_heads_q * options.seq_len_qo * options.head_size_vo * (options.seq_len_kv + options.seq_len_kv_cache); double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time; - double gbps_qk = 2.0 * options.batch * options.num_heads_q * (options.seq_len_qo * options.head_size_qk + options.seq_len_kv * options.head_size_qk); - double gbps_pv = 2.0 * options.batch * options.num_heads_q * (options.seq_len_kv * options.seq_len_qo + options.seq_len_qo * options.head_size_vo); + double gbps_qk = 2.0 * options.batch * options.num_heads_q * (options.seq_len_qo * options.head_size_qk + (options.seq_len_kv + options.seq_len_kv_cache) * options.head_size_qk); + double gbps_pv = 2.0 * options.batch * options.num_heads_q * ((options.seq_len_kv + options.seq_len_kv_cache) * options.seq_len_qo + options.seq_len_qo * options.head_size_vo); double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo << "\tSeq Length KV: " << options.seq_len_kv << "\tSeq Length KV Cache: " << options.seq_len_kv_cache From 6735cfd5aa1cae734caf658b91a8a55f29087354 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Tue, 22 Apr 2025 01:09:39 -0700 Subject: [PATCH 11/20] Fix verify when num_heads_kv != num_heads_q --- .../sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 15a8979cde..471c674041 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -371,10 +371,10 @@ 
template struct ExampleRunner { if(kv_group_update % q_group_size==0) { offset_k += seq_len_kv * head_size_qk; offset_v += seq_len_kv * head_size_vo; + offset_k_cache += seq_len_kv_cache * head_size_qk; + offset_v_cache += seq_len_kv_cache * head_size_vo; } kv_group_update++; - offset_k_cache += seq_len_kv_cache * head_size_qk; - offset_v_cache += seq_len_kv_cache * head_size_vo; offset_o += seq_len_qo * head_size_vo; } } From c9d11eae83e6946286744857fa90e026baf4378e Mon Sep 17 00:00:00 2001 From: mehdi-goli Date: Tue, 22 Apr 2025 12:59:17 +0100 Subject: [PATCH 12/20] Fixing the index for launch --- applications/flash_attention_v2/kernel/tile_scheduler.hpp | 8 ++++---- .../sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/applications/flash_attention_v2/kernel/tile_scheduler.hpp b/applications/flash_attention_v2/kernel/tile_scheduler.hpp index c02f87d173..af8831fd89 100644 --- a/applications/flash_attention_v2/kernel/tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/tile_scheduler.hpp @@ -59,8 +59,8 @@ struct XeFlashIndividualTileScheduler { ProblemSize const& problem_size, KernelHardwareInfo hw_info, TileShape const& tile_shape) { using namespace cute; - // problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] - dim3 grid(size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))), + // problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache head_size_qk, head_size_vo] + dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), size(shape<0>(problem_size) * shape<1>(problem_size))); return Params{ grid, {shape<1>(problem_size)} }; @@ -127,8 +127,8 @@ struct XeFlashPersistentTileScheduler { CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); hw_info.sm_count = sm_count; - // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] - int num_head_size_blocks = size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))); + // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] + int num_head_size_blocks = size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))); int num_seq_len_blocks = size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); int num_blocks = num_seq_len_blocks * num_head_size_blocks * size(shape<0>(problem_size) * shape<1>(problem_size)); diff --git a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp index 471c674041..ecbd310edd 100644 --- a/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp +++ b/examples/sycl/06_pvc_flash_attention/pvc_flash_attn_runner.hpp @@ -442,7 +442,7 @@ template struct ExampleRunner { get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv}; get<5>(problem_size_for_launch) = get<5>(problem_size); get<6>(problem_size_for_launch) = get<6>(problem_size); - get<7>(problem_size_for_launch) = get<6>(problem_size); + get<7>(problem_size_for_launch) = get<7>(problem_size); get<0>(problem_size_for_launch) = get<0>(problem_size); get<1>(problem_size_for_launch) = get<1>(problem_size); get<2>(problem_size_for_launch) = get<2>(problem_size); From e7505c8f99384cd966ce9e6c302f78c6c8696b1c Mon Sep 17 00:00:00 2001 From: mehdi-goli Date: 
Tue, 22 Apr 2025 18:00:34 +0100 Subject: [PATCH 13/20] Adding prefetch to the extend version. Fixing the stride for the variable lengh --- .../collective/xe_flash_attn_mma.hpp | 21 +++++++++++-------- .../kernel/xe_flash_attn_gemm.hpp | 15 +++++++++---- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index ba34d13137..05321cd4fb 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -154,7 +154,8 @@ struct CollectiveMmaAttention, ProblemShapeType_, using atom_load_V = Copy_Atom; using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout{}, val_layout_load_V{})); - + using TensorK = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideK{})); //(m, k) + using TensorV = decltype(make_tensor(make_gmem_ptr(static_cast(nullptr)), make_shape(0,0,0), StrideV{})); //(n, k) // Host side kernel arguments struct Arguments { ElementQ const *ptr_Q; @@ -175,6 +176,8 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_V gmem_tiled_copy_v; XE_Copy_K gmem_tiled_copy_k_cache; XE_Copy_V gmem_tiled_copy_v_cache; + TensorK gmem_tensor_k_cache; + TensorV gmem_tensor_v_cache; }; // @@ -201,7 +204,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, tensorV_cache}; } template @@ -335,19 +338,19 @@ struct CollectiveMmaAttention, ProblemShapeType_, int offset_v_cache = num_heads_kv * head_size_vo * seq_len_kv_cache; auto q_traits = static_cast(params.gmem_tiled_copy_q); - ElementQ* q_ptr = (ElementQ*)q_traits.base_ptr; + const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr; auto k_traits = static_cast(params.gmem_tiled_copy_k); - ElementK* k_ptr = (ElementK*)k_traits.base_ptr; + const ElementK* k_ptr = (const ElementK*)k_traits.base_ptr; auto v_traits = static_cast(params.gmem_tiled_copy_v); - ElementV* v_ptr = (ElementV*)v_traits.base_ptr; + const ElementV* v_ptr = (const ElementV*)v_traits.base_ptr; auto k_traits_cache = static_cast(params.gmem_tiled_copy_k_cache); - ElementK* k_cache_ptr = (ElementK*)k_traits_cache.base_ptr; + const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr; auto v_traits_cache = static_cast(params.gmem_tiled_copy_v_cache); - ElementV* v_cache_ptr = (ElementV*)v_traits_cache.base_ptr; + const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr; auto shape_q = make_shape(static_cast(seq_len_qo), head_size_qk, num_heads_q); StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); @@ -367,7 +370,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q)); auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k)); auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v)); - auto tensorK_cache = make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, shape_k_cache)); + auto tensorK_cache = 
make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache)); auto tensorV_cache = make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache)); XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; @@ -376,7 +379,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, tensorV_cache}; } } }; diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index c9300cb181..5ee242eacf 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -295,16 +295,21 @@ class GemmUniversalAttention { auto pQgQ = thr_prefetch_Q.partition_S(gQ); auto pKgK = thr_prefetch_K.partition_S(gK); auto pVgV = thr_prefetch_V.partition_S(gV); - + // assuming the copy function is the same otherwise this need to have its own tile_prefetch + auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); + auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); + using Pefetch_K = decltype(tiled_prefetch_k); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<3>(pQgQ); i++) { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } + Pefetch_K prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k: Pefetch_K{tiled_prefetch_k.with(mainloop_params.gmem_tensor_k_cache)}; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < DispatchPolicy::Stages; i++) { + // The headsize for both cached and non-cached version is the same CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<4>(pKgK); j++) { - prefetch(tiled_prefetch_k, pKgK(_, _, _ , i, j)); + prefetch(prefetch_K, pKgK(_, _, _ , i, j)); } } @@ -344,7 +349,8 @@ class GemmUniversalAttention { // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense - prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock)); + (is_KV_cache) ? prefetch(tiled_prefetch_v.with(mainloop_params.gmem_tensor_v_cache), pVgV_cache(_, _, _ , nblock)) + : prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock - nblock_cache)); // 4) Fused softmax CollectiveSoftmaxEpilogue softmax(params.softmax); @@ -355,9 +361,10 @@ class GemmUniversalAttention { // Prefetch the next K tile // there is no need to gaurd it with if statememt as prefetch will ignore out of bound reading + prefetch_K = (nblock + DispatchPolicy::Stages < nblock_cache) ? 
Pefetch_K{tiled_prefetch_k.with(mainloop_params.gmem_tensor_k_cache)}: tiled_prefetch_k; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<4>(pKgK); j++) { - prefetch(tiled_prefetch_k, pKgK(_, _, _, nblock + DispatchPolicy::Stages, j)); + prefetch(prefetch_K, pKgK_cache(_, _, _, nblock + DispatchPolicy::Stages, j)); } barrier_wait(barrier_scope); } From e66634249b6caf1bf8faea659273b9a787f932de Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 24 Apr 2025 03:10:01 -0700 Subject: [PATCH 14/20] Update applications/flash_attention_v2/kernel/tile_scheduler.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tadej Ciglarič --- applications/flash_attention_v2/kernel/tile_scheduler.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/tile_scheduler.hpp b/applications/flash_attention_v2/kernel/tile_scheduler.hpp index af8831fd89..4d8dd2655a 100644 --- a/applications/flash_attention_v2/kernel/tile_scheduler.hpp +++ b/applications/flash_attention_v2/kernel/tile_scheduler.hpp @@ -59,7 +59,7 @@ struct XeFlashIndividualTileScheduler { ProblemSize const& problem_size, KernelHardwareInfo hw_info, TileShape const& tile_shape) { using namespace cute; - // problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache head_size_qk, head_size_vo] + // problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), size(shape<0>(problem_size) * shape<1>(problem_size))); From 31e55800463c0ef28accf7ca09d9fa8e6413ea7f Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 24 Apr 2025 03:11:01 -0700 Subject: [PATCH 15/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tadej Ciglarič --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 5ee242eacf..fae0b27d42 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -334,7 +334,7 @@ class GemmUniversalAttention { for (int nblock = 0; nblock < nblock_limit - static_cast(CausalMask); nblock++) { barrier_arrive(barrier_scope); - bool is_KV_cache = (nblock < nblock_cache); + bool is_KV_cache = nblock < nblock_cache; // 1) Load KV (performed inside mmaQK) auto gK_ = is_KV_cache ? 
gK_cache(_, _, nblock, _) : gK(_, _, nblock - nblock_cache, _); From 21056aa6caeaeba80e3bcecdfa88bf47647685c2 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 24 Apr 2025 03:11:29 -0700 Subject: [PATCH 16/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tadej Ciglarič --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index fae0b27d42..4cc0aafcdc 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -349,7 +349,7 @@ class GemmUniversalAttention { // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense - (is_KV_cache) ? prefetch(tiled_prefetch_v.with(mainloop_params.gmem_tensor_v_cache), pVgV_cache(_, _, _ , nblock)) + is_KV_cache ? prefetch(tiled_prefetch_v.with(mainloop_params.gmem_tensor_v_cache), pVgV_cache(_, _, _ , nblock)) : prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock - nblock_cache)); // 4) Fused softmax From 31b9aa981a1ae9fb5e2b3eff80ce6c62eccc7bdd Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 24 Apr 2025 03:11:46 -0700 Subject: [PATCH 17/20] Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tadej Ciglarič --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 4cc0aafcdc..4206796ae0 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -390,7 +390,7 @@ class GemmUniversalAttention { int row_idx = m * Vec + seq_coord; CUTLASS_PRAGMA_UNROLL for (int row = 0; row < Vec; row++, row_idx++) { // 8 - if ((col_idx - full_tile_offset) > (row_idx - discard_seq_coord)) { + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { tSr(row, m, n) = -INFINITY; } } From 83bf043968904997ae88b3c3907003f14d530931 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Thu, 24 Apr 2025 03:11:53 -0700 Subject: [PATCH 18/20] Update applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Tadej Ciglarič --- .../flash_attention_v2/collective/xe_flash_attn_mma.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index 05321cd4fb..0491d4f137 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -204,7 +204,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, 
tensorV_cache}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, tensorV_cache}; } template From 5620c380bbf17a0f3d0297d5e0241d7558566ed0 Mon Sep 17 00:00:00 2001 From: mehdi-goli Date: Fri, 25 Apr 2025 20:15:21 +0100 Subject: [PATCH 19/20] Applying the comments --- .../collective/xe_flash_attn_mma.hpp | 6 ++---- .../kernel/xe_flash_attn_gemm.hpp | 19 ++++++++++++------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp index 0491d4f137..a653ac5b02 100644 --- a/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp +++ b/applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp @@ -176,8 +176,6 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_V gmem_tiled_copy_v; XE_Copy_K gmem_tiled_copy_k_cache; XE_Copy_V gmem_tiled_copy_v_cache; - TensorK gmem_tensor_k_cache; - TensorV gmem_tensor_v_cache; }; // @@ -204,7 +202,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, tensorV_cache}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; } template @@ -379,7 +377,7 @@ struct CollectiveMmaAttention, ProblemShapeType_, XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; - return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache, tensorK_cache, tensorV_cache}; + return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache}; } } }; diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 4206796ae0..105e458e35 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -289,6 +289,8 @@ class GemmUniversalAttention { auto tiled_prefetch_q = cute::prefetch_selector, Int>, Num_SGs>(mainloop_params.gmem_tiled_copy_q); auto tiled_prefetch_k = cute::prefetch_selector, Int>, Num_SGs>(mainloop_params.gmem_tiled_copy_k); auto tiled_prefetch_v = cute::prefetch_selector, Int>, Num_SGs>(mainloop_params.gmem_tiled_copy_v); + auto tiled_prefetch_k_cache = cute::prefetch_selector, Int>, Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache); + auto tiled_prefetch_v_cache = cute::prefetch_selector, Int>, Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache); auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx); auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx); auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx); @@ -298,18 +300,18 @@ class GemmUniversalAttention { // assuming the copy function is the same, otherwise this needs to have its own tile_prefetch auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache); auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache); - using Pefetch_K = decltype(tiled_prefetch_k); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<3>(pQgQ); i++) { prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); } - Pefetch_K prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k: Pefetch_K{tiled_prefetch_k.with(mainloop_params.gmem_tensor_k_cache)}; + auto& prefetch_K = (seq_len_kv_cache == 0) ? tiled_prefetch_k: tiled_prefetch_k_cache; + auto& pKgK1_ = (seq_len_kv_cache == 0) ?
pKgK: pKgK_cache; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < DispatchPolicy::Stages; i++) { // The head size is the same for the cached and non-cached versions CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<4>(pKgK1_); j++) { prefetch(prefetch_K, pKgK1_(_, _, _ , i, j)); } } @@ -349,7 +351,7 @@ class GemmUniversalAttention { // we only need one block ahead, there is enough gap to prefetch it while doing softmax. because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense - is_KV_cache ? prefetch(tiled_prefetch_v.with(mainloop_params.gmem_tensor_v_cache), pVgV_cache(_, _, _ , nblock)) + is_KV_cache ? prefetch(tiled_prefetch_v_cache, pVgV_cache(_, _, _ , nblock)) : prefetch(tiled_prefetch_v, pVgV(_, _, _ , nblock - nblock_cache)); // 4) Fused softmax @@ -361,10 +363,13 @@ class GemmUniversalAttention { // Prefetch the next K tile // there is no need to gaurd it with if statememt as prefetch will ignore out of bound reading - prefetch_K = (nblock + DispatchPolicy::Stages < nblock_cache) ? Pefetch_K{tiled_prefetch_k.with(mainloop_params.gmem_tensor_k_cache)}: tiled_prefetch_k; + + bool sel_prefetch_k = (nblock + DispatchPolicy::Stages) < nblock_cache; + auto& prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache: tiled_prefetch_k; + auto& pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK; CUTLASS_PRAGMA_UNROLL for (int j = 0; j < size<4>(pKgK); j++) { - prefetch(prefetch_K, pKgK_cache(_, _, _, nblock + DispatchPolicy::Stages, j)); + prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache , j)); } barrier_wait(barrier_scope); } From e5fbdbb1e79ea86690c0386f9aaf041d99c34d6b Mon Sep 17 00:00:00 2001 From: Mehdi Goli Date: Fri, 25 Apr 2025 22:18:14 +0100 Subject: [PATCH 20/20] Update xe_flash_attn_gemm.hpp --- applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp index 105e458e35..ae95b67f02 100644 --- a/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp +++ b/applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp @@ -368,7 +368,7 @@ class GemmUniversalAttention { auto& prefetch_k_selector = sel_prefetch_k ? tiled_prefetch_k_cache: tiled_prefetch_k; auto& pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK; CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < size<4>(pKgK); j++) { + for (int j = 0; j < size<4>(pKgK_); j++) { prefetch(prefetch_k_selector, pKgK_(_, _, _, (nblock + DispatchPolicy::Stages) - (!sel_prefetch_k) * nblock_cache , j)); } barrier_wait(barrier_scope);
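
For context on the KV-cache prefetch changes above, the following is a minimal, self-contained C++ sketch of the block-selection scheme they implement. It is illustrative only, not CUTLASS code: FakeTensor, prefetch_block, and the block counts are hypothetical stand-ins for the CuTe tiled-prefetch objects and the sizes derived from seq_len_kv_cache / seq_len_kv. The main loop walks nblock over the concatenated [KV cache | new KV] range, reads the cache tensors while nblock < nblock_cache, and rebases the index by nblock_cache otherwise; the same rule is applied to the Stages-ahead K prefetch.

#include <cstdio>

// Hypothetical stand-in for a CuTe tiled prefetch object / partitioned tensor.
struct FakeTensor { const char* name; };

// Hypothetical helper: records which tensor and block index would be prefetched.
void prefetch_block(const FakeTensor& t, int block) {
  std::printf("prefetch %s block %d\n", t.name, block);
}

int main() {
  constexpr int Stages = 2;     // plays the role of DispatchPolicy::Stages
  const int nblock_cache = 3;   // K/V blocks coming from the KV cache
  const int nblock_new   = 4;   // K/V blocks coming from the current K/V
  const int nblock_limit = nblock_cache + nblock_new;

  FakeTensor K{"K"}, K_cache{"K_cache"};

  for (int nblock = 0; nblock < nblock_limit; ++nblock) {
    // Same predicate as `is_KV_cache = nblock < nblock_cache` in the kernel:
    // the first nblock_cache iterations read the cached K/V, the remaining
    // iterations read the new K/V, whose block index is rebased by nblock_cache.
    bool is_KV_cache = nblock < nblock_cache;
    prefetch_block(is_KV_cache ? K_cache : K,
                   is_KV_cache ? nblock : nblock - nblock_cache);

    // Next-K prefetch mirrors `sel_prefetch_k`: look Stages blocks ahead and
    // subtract nblock_cache only when the lookahead lands in the non-cached range.
    // An out-of-range lookahead near the end is harmless here (just a printf),
    // as it is in the kernel, where prefetch ignores out-of-bound reads.
    bool sel_prefetch_k = (nblock + Stages) < nblock_cache;
    int next = (nblock + Stages) - (!sel_prefetch_k) * nblock_cache;
    prefetch_block(sel_prefetch_k ? K_cache : K, next);
  }
  return 0;
}

The rebasing by nblock_cache reflects that the cached and newly supplied K/V live in separate allocations, each with its own tiled copy (gmem_tiled_copy_k_cache / gmem_tiled_copy_k), so the non-cached tensors are indexed from zero.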