Skip to content

Commit

Permalink
Update for fmha_fwd qs_ks_vs pipeline (#1810)
Browse files Browse the repository at this point in the history
* Update for fmha_fwd qs_ks_vs pipeline

* Remove _builtin_amdgcn_sched_barrier(0)

* Move p_compute to p converting earlier for trying to increase vgprs re-using

* Enable GetQKBlockGemm to use WarpGemm-16x16x16 for QLoadOnce==false situation

* Re-add __builtin_amdgcn_sched_barrier(0)

---------

Co-authored-by: Po Yen Chen <[email protected]>
  • Loading branch information
qianfengz and poyenc authored Jan 13, 2025
1 parent fd46a01 commit 3d50f57
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 27 deletions.
21 changes: 11 additions & 10 deletions include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,6 @@ struct BlockFmhaPipelineQSKSVS
return Policy::template GetSmemSize<Problem>();
}

CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
{
return Policy::template GetSmemSizeQ<Problem>();
}

template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
Expand Down Expand Up @@ -328,8 +323,7 @@ struct BlockFmhaPipelineQSKSVS
});
}

const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
{ // tail
{ // tail
block_sync_lds();
gemm_0(s_acc, q_lds_window, k_lds_window);
block_sync_lds();
Expand All @@ -341,6 +335,10 @@ struct BlockFmhaPipelineQSKSVS
gemm_0(s_acc, q_lds_window, k_lds_window);
}

__builtin_amdgcn_sched_barrier(0);
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
__builtin_amdgcn_sched_barrier(0);

// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
Expand Down Expand Up @@ -462,6 +460,12 @@ struct BlockFmhaPipelineQSKSVS
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})

block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});

const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

__builtin_amdgcn_sched_barrier(0);

// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
Expand Down Expand Up @@ -509,9 +513,6 @@ struct BlockFmhaPipelineQSKSVS
}
move_tile_window(v_dram_window, {0, kK1});

const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));

// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,33 @@
namespace ck_tile {

// This pipeline is qkv all located in LDS
using BlockFmhaPipelineQSKSVSDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>;
struct BlockFmhaPipelineQSKSVSDefaultPolicy
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ false,
/* AsyncCopyK = */ false,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 1,
/* NumPrefetchV = */ 1>
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
{
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::KDataType);
} // namespace ck_tile

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
{
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
sizeof(typename Problem::VDataType);
}

template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return max(GetSmemSizeQ<Problem>() + GetSmemSizeK<Problem>(), GetSmemSizeV<Problem>()) +
GetSmemSizeDropout<Problem>();
}
};

} // namespace ck_tile
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);

// this should align with MakeQDramTileDistribution()
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
return min(ElemPerThread, MaxVectorSize);
}

template <typename Problem>
Expand All @@ -156,19 +164,25 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
using QDataType = remove_cvref_t<typename Problem::QDataType>;

constexpr index_t kBlockSize = Problem::kBlockSize;

constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

constexpr index_t K1 = 16 / sizeof(QDataType); // use dwordx4. TODO: change this
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);

constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
static_assert(0 < ElemPerThread);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);

constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);

return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
Expand Down Expand Up @@ -215,18 +229,31 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;

constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);

constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaF16F16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
else // WarpGemmM == 4
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
Expand Down

0 comments on commit 3d50f57

Please sign in to comment.