From 3d50f57f4362afc9a69e39858ea3bda9b0fb5159 Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Mon, 13 Jan 2025 12:43:05 +0800 Subject: [PATCH] Update for fmha_fwd qs_ks_vs pipeline (#1810) * Update for fmha_fwd qs_ks_vs pipeline * Remove _builtin_amdgcn_sched_barrier(0) * Move p_compute to p converting earlier for trying to increase vgprs re-using * Enable GetQKBlockGemm to use WarpGemm-16x16x16 for QLoadOnce==false situation * Re-add __builtin_amdgcn_sched_barrier(0) --------- Co-authored-by: Po Yen Chen --- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 21 ++++---- ..._fmha_pipeline_qs_ks_vs_default_policy.hpp | 34 ++++++++++--- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 49 ++++++++++++++----- 3 files changed, 77 insertions(+), 27 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index b79889bc10..c2223fcee6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS return Policy::template GetSmemSize(); } - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() - { - return Policy::template GetSmemSizeQ(); - } - template {}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + + const auto p = + cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + + __builtin_amdgcn_sched_barrier(0); + // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { @@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS } move_tile_window(v_dram_window, {0, kK1}); - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); - // STAGE 3, KV gemm if constexpr(k1_loops > 1) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp index b312fd07ae..ff8299b4ff 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp @@ -9,11 +9,33 @@ namespace ck_tile { // This pipeline is qkv all located in LDS -using BlockFmhaPipelineQSKSVSDefaultPolicy = - BlockFmhaPipelineQXKSVSCustomPolicy; +struct BlockFmhaPipelineQSKSVSDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::KDataType); + } // namespace ck_tile + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::VDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return max(GetSmemSizeQ() + GetSmemSizeK(), GetSmemSizeV()) + + GetSmemSizeDropout(); + } +}; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 07164ec85c..3db461e971 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() { - using QDataType = remove_cvref_t; - return 16 / sizeof(QDataType); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); + + // this should align with MakeQDramTileDistribution() + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + return min(ElemPerThread, MaxVectorSize); } template @@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; - constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - constexpr index_t M1 = kBlockSize / get_warp_size(); - constexpr index_t M0 = kMPerBlock / (M2 * M1); + constexpr index_t MaxVectorSize = 16 / sizeof(QDataType); + + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, + sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, @@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy typename Problem::BlockFmhaShape::Gemm0BlockWarps, typename Problem::BlockFmhaShape::Gemm0WarpTile>>; + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + constexpr auto warp_gemm = []() { if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; } else if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) { - return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; } else if constexpr(std::is_same_v && std::is_same_v &&