kernel: added stage support for MLA kernel
guocuimi committed Feb 16, 2025
1 parent 29646c3 commit 562a56a
Showing 5 changed files with 134 additions and 96 deletions.
182 changes: 104 additions & 78 deletions src/kernels/attention/mla_kernel_sm80.cuh
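Note: at a high level, this change renames the old per-head-dim-chunk axis from STAGES to STEPS and introduces a genuinely new STAGES axis that multi-buffers the K/V and K_rope shared-memory tiles, so cp.async can prefetch upcoming KV blocks while the current one is consumed. A summary of the smem shape changes, read off the diff comments below:

```cpp
// Smem tensor shapes, before -> after (from the diff comments; BLK_* and
// ROPE_HEAD_DIM come from the kernel's Traits):
//   sQ      (BLK_M, BLK_K, STEPS)                       // axis renamed from STAGES
//   sK      (BLK_N, BLK_K, STEPS)  -> (BLK_N, BLK_K, STEPS, STAGES)
//   sK_rope (BLK_N, ROPE_HEAD_DIM) -> (BLK_N, ROPE_HEAD_DIM, STAGES)
//   sVt     (BLK_K, BLK_N, STEPS)  -> (BLK_K, BLK_N, STEPS, STAGES)
//   sO      (BLK_M, BLK_K, STEPS)                       // reuses q_smem, unchanged
```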
@@ -13,7 +13,6 @@
#include "mask.h"
#include "mla_tile.h"
#include "online_softmax.cuh"
#include "ptx.cuh"

namespace llm {

@@ -31,14 +30,15 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
constexpr int kBlockN = Traits::kBlockN;
constexpr int kBlockK = Traits::kBlockK;
constexpr int kHeadDim = Traits::kHeadDim;
constexpr int kSteps = Traits::kSteps;
constexpr int kStages = Traits::kStages;
constexpr int kRopeHeadDim = Traits::kRopeHeadDim;
constexpr int kRowsPerMMA = Traits::kRowsPerMMA;

using _BLK_M = Int<kBlockM>;
using _BLK_N = Int<kBlockN>;
using _BLK_K = Int<kBlockK>;
using _STAGES = Int<kStages>;
using _STEPS = Int<kSteps>;
using _HEAD_DIM = Int<kHeadDim>;
using _ROPE_HEAD_DIM = Int<kRopeHeadDim>;
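Note: kSteps counts the kBlockK-wide chunks the head dimension is cut into (what the old code labeled stages), while kStages is the new cp.async pipeline depth. A hedged numeric example; every value here is assumed for illustration, the real ones come from the Traits template parameter:

```cpp
constexpr int kHeadDim = 512;                 // MLA latent head dim (assumed)
constexpr int kBlockK  = 128;                 // width of one head-dim chunk (assumed)
constexpr int kSteps   = kHeadDim / kBlockK;  // = 4 chunks per KV tile
constexpr int kStages  = 2;                   // cp.async pipeline depth (assumed)
static_assert(kHeadDim % kBlockK == 0, "head dim must tile evenly into steps");
```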

@@ -99,12 +99,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
}

// Gmem
// (BLK_M, BLK_K, STAGES)
// (BLK_M, BLK_K, STEPS)
Tensor gQ =
local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
Tensor gO =
local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
// (BLK_N, BLK_K, n, STAGES)
// (BLK_N, BLK_K, n, STEPS)
Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));

// (BLK_M, ROPE_HEAD_DIM)
@@ -123,24 +123,24 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});
float* row_sync_smem = (float*)(k_rope_smem + cosize(SmemLayoutKRope{}));

// (BLK_M, BLK_K, STAGES), k-major
// (BLK_M, BLK_K, STEPS), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
// (BLK_N, BLK_K, STAGES), k-major
// (BLK_N, BLK_K, STEPS, STAGES), k-major
Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});

// (BLK_M, BLK_N), k-major
Tensor sP = make_tensor(make_smem_ptr(p_smem), SmemLayoutP{});

// (BLK_M, ROPE_HEAD_DIM), k-major
Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
// (BLK_N, ROPE_HEAD_DIM), k-major
// (BLK_N, ROPE_HEAD_DIM, STAGES), k-major
Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{});

// Tensor for V^t; used in GEMM-II.
// (BLK_K, BLK_N, STAGES)
// (BLK_K, BLK_N, STEPS, STAGES)
Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{});

// (BLK_M, BLK_K, STAGES), reuse smem
// (BLK_M, BLK_K, STEPS), reuse smem
Tensor sO = make_tensor(make_smem_ptr(q_smem), SmemLayoutO{});

// (BLK_M, 2)
@@ -180,10 +180,10 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for q
GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
auto produce_q = [&](int stage) {
// gQ/sQ: (BLK_M, BLK_K, STAGES)
auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, stage));
auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, stage));
auto produce_q = [&](int step) {
// gQ/sQ: (BLK_M, BLK_K, STEPS)
auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
};

@@ -199,65 +199,56 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for kv
GmemTiledCopyKV gmem_tiled_copy_KV;
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
auto produce_kv = [&](int ni, int stage) {
// gKV: (BLK_N, BLK_K, n, STAGES)
auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, stage));
// sK: (BLK_N, BLK_K, STAGES)
auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, stage));
auto produce_kv = [&](int ni, int step, int stage) {
// gKV: (BLK_N, BLK_K, n, STEPS)
// sK: (BLK_N, BLK_K, STEPS, STAGES)
auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
};

// g2s tiled copy for k_rope
GmemTiledCopyKRope gmem_tiled_copy_K_rope;
auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);
Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope);
auto produce_k_rope = [&](int ni) {
auto produce_k_rope = [&](int ni, int stage) {
// gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
// sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
};
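Note: every g2s producer now takes an explicit stage and lands its copies in that stage's slice of the smem ring (sK(_, _, step, stage), sK_rope(_, _, stage)). A minimal sketch of the same pattern with raw cp.async primitives instead of cute tiled copies; the names and sizes are illustrative assumptions, not the kernel's real layout:

```cuda
#include <cuda_pipeline.h>

constexpr int kTileBytes = 16 * 1024;  // bytes per (BLK_N, BLK_K) chunk (assumed)
constexpr int kStages    = 2;          // pipeline depth (assumed)

// Fill one ring slot with one KV chunk; the caller picks the stage and later
// closes the commit group with __pipeline_commit(), the counterpart of the
// kernel's cp_async_fence().
__device__ void produce_chunk(char* smem_ring,         // [kStages * kTileBytes] in smem
                              const char* gmem_chunk,  // current chunk in gmem
                              int stage, int tid, int nthreads) {
  char* dst = smem_ring + stage * kTileBytes;  // per-stage slice of the ring
  // Each thread issues 16-byte cp.async copies over its share of the tile.
  for (int off = tid * 16; off < kTileBytes; off += nthreads * 16) {
    __pipeline_memcpy_async(dst + off, gmem_chunk + off, 16);
  }
}
```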

// GEMM-I: S = Q@K.T
TiledMma_QK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
// sQ/sK: (BLK_M, BLK_K, STAGES)
// sQ: (BLK_M, BLK_K, STEPS)
auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{}));
auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}));
auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope);
// sK: (BLK_N, BLK_K, STEPS, STAGES)
auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}, _0{}));

// s2r tiled copy for q/q_rope
SmemTiledCopyQ smem_tiled_copy_Q;
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
// (CPY, CPY_M, CPY_K, STAGES)
// (CPY, CPY_M, CPY_K, STEPS)
auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
// (CPY, CPY_M, CPY_K)
auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ);

// (CPY, CPY_M, CPY_K)
auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
// (CPY, CPY_M, CPY_K)
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);

// s2r tiled copy for k/k_rope
SmemTiledCopyK smem_tiled_copy_K;
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
// (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsK = smem_thr_copy_K.partition_S(sK);
// (CPY, CPY_N, CPY_K)
auto tCrK = smem_thr_copy_K.retile_D(tSrK);

// (CPY, CPY_N, CPY_K)
auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
// (CPY, CPY_N, CPY_K)
auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);

// S = Q@K.T
// tSrS: (MMA,MMA_M,MMA_N)
auto compute_qk = [&](auto& tSrS, int s) {
// (CPY, CPY_M, CPY_K, STAGES)
auto tCsQ_s = tCsQ(_, _, _, s);
auto tCsK_s = tCsK(_, _, _, s);
auto compute_qk = [&](auto& tSrS, int step, int stage) {
// tCsQ: (CPY, CPY_M, CPY_K, STEPS)
auto tCsQ_s = tCsQ(_, _, _, step);
// tCsK: (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsK_s = tCsK(_, _, _, step, stage);
// prefetch kv
cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{}));
cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{}));
@@ -274,9 +265,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
}
};
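Note: inside compute_qk (and its rope/pv siblings) the smem-to-register traffic is itself pipelined: fragment k+1 is copied while the MMA consumes fragment k. A generic sketch of that loop shape, with load/mma standing in for the cute::copy/cute::gemm calls above:

```cpp
#include <functional>

// kFrags plays the role of size<2>(tCsQ): the number of K-dim fragments.
inline void pipelined_mma(int kFrags,
                          const std::function<void(int)>& load,
                          const std::function<void(int)>& mma) {
  load(0);                    // prologue: prefetch the first fragment
  for (int k = 0; k < kFrags; ++k) {
    if (k + 1 < kFrags) {
      load(k + 1);            // hide smem->reg latency behind the MMA
    }
    mma(k);                   // cute::gemm on fragment k
  }
}
```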

auto compute_qk_rope = [&](auto& tSrS) {
// sQ_rope: (BLK_N, ROPE_HEAD_DIM)
auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
// sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{}));
// (CPY, CPY_M, CPY_K)
auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
// (CPY, CPY_M, CPY_K)
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
// (CPY, CPY_N, CPY_K, STAGES)
auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
// (CPY, CPY_N, CPY_K)
auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
auto compute_qk_rope = [&](auto& tSrS, int stage) {
auto tCsK_rope_s = tCsK_rope(_, _, _, stage);
cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{}));
cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{}));
cute::copy(
smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{}));

CUTE_UNROLL
for (int k = 0; k < size<2>(tCsQ_rope); ++k) {
Expand All @@ -286,7 +291,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
tCsQ_rope(_, _, next_k),
tCrQ_rope(_, _, next_k));
cute::copy(smem_tiled_copy_K,
tCsK_rope(_, _, next_k),
tCsK_rope_s(_, _, next_k),
tCrK_rope(_, _, next_k));
}
cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS);
@@ -298,8 +303,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
// sP: (BLK_M, BLK_N)
auto tOrP = thr_mma_pv.partition_fragment_A(sP);
// sVt: (BLK_K, BLK_N, STAGES)
auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}));
// sVt: (BLK_K, BLK_N, STEPS, STAGES)
auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{}));

// s2r tiled copy for p
SmemTiledCopyP smem_tiled_copy_P;
@@ -312,16 +317,17 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// s2r tiled copy for vt
SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
// (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
// (CPY, CPY_N, CPY_K)
auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);

// O = P*V = softmax(S)*V
// tOrO: (MMA,MMA_M,MMA_K,STAGES)
auto compute_pv = [&](auto& tOrO, int s) {
auto tOrO_s = tOrO(_, _, _, s);
auto tCsVt_s = tCsVt(_, _, _, s);
// tOrO: (MMA,MMA_M,MMA_K,STEPS)
auto compute_pv = [&](auto& tOrO, int step, int stage) {
auto tOrO_s = tOrO(_, _, _, step);
auto tCsVt_s = tCsVt(_, _, _, step, stage);

cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));

Expand Down Expand Up @@ -350,16 +356,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(smem_tiled_copy_S, tCrS, tCsS);
};

// tOrO: (MMA,MMA_M,MMA_K,STAGES)
// tOrO: (MMA,MMA_M,MMA_K,STEPS)
auto epilogue = [&](const auto& tOrO) {
// write output to gmem
// 1. copy output from reg to smem (reuse sQ)
SmemTiledCopyO smem_tiled_copy_O;
auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
auto tOrO_s = tOrO(_, _, _, s);
auto sO_s = sO(_, _, s);
for (int step = 0; step < kSteps; ++step) {
auto tOrO_s = tOrO(_, _, _, step);
auto sO_s = sO(_, _, step);

// cast Accumulator to Element type
auto tOrO_ = make_tensor_like<DType>(tOrO_s);
@@ -381,29 +387,32 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
};

// output accumulator: (MMA,MMA_M,MMA_K,STAGES)
// output accumulator: (MMA,MMA_M,MMA_K,STEPS)
auto tOrO =
partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STEPS>{});
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
clear(tOrO);

const int n_block_min = 0;
const int n_block_max = cute::ceil_div(size<0>(KV), kBlockN);

// ############### Prologue ###############
// produce q_rope: [] => [q_rope, q...]
// produce q_rope/q: [] => [q_rope, q...]
produce_q_rope();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
produce_q(s);
for (int step = 0; step < kSteps; ++step) {
produce_q(step);
}
// produce k_rope: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
produce_k_rope(0);
cp_async_fence();
// produce k_rope/kv: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
produce_kv(0, s);
for (int stage = 0; stage < kStages; ++stage) {
produce_k_rope(stage, stage);
cp_async_fence();
CUTE_UNROLL
for (int step = 0; step < kSteps; ++step) {
produce_kv(stage, step, stage);
cp_async_fence();
}
}
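Note: the prologue now fills every pipeline stage before the main loop starts: for each stage it commits one k_rope group (the very first also carries q_rope and all q steps) plus kSteps kv groups, leaving kStages * (kSteps + 1) cp.async commit groups in flight. A toy host-side model of that bookkeeping, with kSteps/kStages assumed:

```cpp
#include <cassert>
#include <deque>

// cp_async_fence() pushes a commit group; cp_async_wait<N>() blocks until at
// most N groups remain outstanding (oldest groups land first).
constexpr int kSteps  = 4;                           // assumed
constexpr int kStages = 2;                           // assumed
constexpr int kWait   = kStages * (kSteps + 1) - 1;  // = 9 with these numbers

int main() {
  std::deque<int> in_flight;
  int next = 0;
  auto fence = [&] { in_flight.push_back(next++); };
  auto wait  = [&](int n) { while ((int)in_flight.size() > n) in_flight.pop_front(); };

  for (int stage = 0; stage < kStages; ++stage) {      // prologue
    fence();                                           // k_rope group
    for (int step = 0; step < kSteps; ++step) fence(); // kv groups
  }
  assert((int)in_flight.size() == kStages * (kSteps + 1));

  wait(kWait);  // first mainloop wait retires exactly the oldest group,
                // whose tiles are the ones consumed next
  assert((int)in_flight.size() == kWait);
  return 0;
}
```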

// ############### Mainloop ###############
Expand All @@ -421,25 +430,29 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
Softmax softmax(params.sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);

constexpr int kWait = kStages * (kSteps + 1) - 1;
int stage = 0;

CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
clear(tSrS);

// wait key, queue: [q, q_rope, kv, k_rope] => []
cp_async_wait<kStages>();
// wait queue: [q_rope, q..., (k_rope, kv...), (k_rope, kv...)]
// => [kv..., (k_rope, kv...)]
cp_async_wait<kWait>();
__syncthreads();

// 1> S = Q_rope@K_rope.T
compute_qk_rope(tSrS);
compute_qk_rope(tSrS, stage);
cp_async_fence();

// 2> S += Q@K.T
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
cp_async_wait<kStages>();
for (int step = 0; step < kSteps; ++step) {
cp_async_wait<kWait>();
__syncthreads();

compute_qk(tSrS, s);
compute_qk(tSrS, step, stage);
cp_async_fence();
}
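Note: a single compile-time kWait = kStages * (kSteps + 1) - 1 only works because every cp_async_wait is paired with exactly one cp_async_fence, even when there is nothing new to commit (see the empty fences in the last-iteration branch below). Keeping the retire/commit counts in lockstep pins the number of outstanding groups across iterations. Schematically, per consumed chunk — a sketch, not the literal code:

```cpp
// cp_async_wait<kWait>();        // the group holding (step, stage) has landed
// __syncthreads();               // make the smem tile visible to all warps
// compute_qk(tSrS, step, stage); // consume the tile
// cp_async_fence();              // commit (possibly empty) to keep counts aligned
```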

@@ -452,24 +465,37 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
__syncthreads();

// 3> O = softmax(S)*V
const auto next_ni = ni + 1;
const auto next_ni = ni + kStages;
if (next_ni != n_block_max) {
produce_k_rope(next_ni);
produce_k_rope(next_ni, stage);
cp_async_fence();

CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
for (int step = 0; step < kSteps; ++step) {
compute_pv(tOrO, step, stage);

__syncthreads();

produce_kv(next_ni, s);
produce_kv(next_ni, step, stage);
cp_async_fence();
}
} else {
cp_async_fence();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
for (int step = 0; step < kSteps; ++step) {
compute_pv(tOrO, step, stage);
cp_async_fence();
}
}

// move to next stage
if constexpr (kStages == 1) {
// do nothing
} else if constexpr (kStages == 2) {
stage = stage ^ 1;
} else {
stage = (stage + 1) % kStages;
}
}
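Note: the prefetch target is next_ni = ni + kStages because the buffer refilled in iteration ni is the one the ring returns to kStages iterations later. The stage index itself advances by one per KV block; for two stages an XOR is the cheapest rotation. A standalone sketch of the same advance logic:

```cpp
// Mirrors the kernel's if constexpr chain (C++17).
template <int kStages>
constexpr int next_stage(int stage) {
  if constexpr (kStages == 1) {
    return stage;                  // single buffer: nothing to rotate
  } else if constexpr (kStages == 2) {
    return stage ^ 1;              // flip between 0 and 1 without a modulo
  } else {
    return (stage + 1) % kStages;  // general ring advance
  }
}
static_assert(next_stage<2>(0) == 1 && next_stage<2>(1) == 0);
static_assert(next_stage<3>(2) == 0);
```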

// ############### Epilogue ###############
