kernel: added stage support for MLA kernel #410

Merged · 3 commits · Feb 16, 2025
195 changes: 117 additions & 78 deletions src/kernels/attention/mla_kernel_sm80.cuh
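Context for the diff: the PR renames what the kernel previously called STAGES (the split of the latent head dimension into kBlockK-wide sub-tiles) to STEPS, and introduces a real STAGES axis that controls how many KV blocks the cp.async prefetch pipeline keeps in flight. Below is a host-side sketch of what that means for the shared-memory tile shapes; the tile sizes are illustrative assumptions, not the kernel's actual Traits values.

// ----- illustration only: not part of the diff below -----
// Host-side sketch of the new STEPS/STAGES split and its smem footprint.
// All sizes here are assumed for illustration.
#include <cstdio>

int main() {
  constexpr int kBlockM = 64;    // query rows per CTA tile (assumed)
  constexpr int kBlockN = 64;    // kv rows per CTA tile (assumed)
  constexpr int kBlockK = 128;   // head-dim columns per step (assumed)
  constexpr int kHeadDim = 512;  // latent head dim (assumed)
  constexpr int kSteps = kHeadDim / kBlockK;  // head-dim sub-tiles = 4
  constexpr int kStages = 2;     // cp.async pipeline depth added by this PR
  constexpr int kBytes = 2;      // fp16/bf16 element size

  // Q (and the reused O tile) are split only along the head dim:
  // (BLK_M, BLK_K, STEPS)
  constexpr int q_smem = kBlockM * kBlockK * kSteps * kBytes;
  // KV gains an extra STAGES mode: (BLK_N, BLK_K, STEPS, STAGES),
  // i.e. one full KV tile held in smem per in-flight pipeline stage.
  constexpr int kv_smem = kBlockN * kBlockK * kSteps * kStages * kBytes;

  printf("Q/O smem per CTA: %d bytes\n", q_smem);
  printf("KV smem per CTA:  %d bytes (%d stages)\n", kv_smem, kStages);
  return 0;
}
// ----- end illustration -----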
@@ -13,7 +13,6 @@
#include "mask.h"
#include "mla_tile.h"
#include "online_softmax.cuh"
#include "ptx.cuh"

namespace llm {

@@ -31,14 +30,15 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
constexpr int kBlockN = Traits::kBlockN;
constexpr int kBlockK = Traits::kBlockK;
constexpr int kHeadDim = Traits::kHeadDim;
constexpr int kSteps = Traits::kSteps;
constexpr int kStages = Traits::kStages;
constexpr int kRopeHeadDim = Traits::kRopeHeadDim;
constexpr int kRowsPerMMA = Traits::kRowsPerMMA;

using _BLK_M = Int<kBlockM>;
using _BLK_N = Int<kBlockN>;
using _BLK_K = Int<kBlockK>;
using _STAGES = Int<kStages>;
using _STEPS = Int<kSteps>;
using _HEAD_DIM = Int<kHeadDim>;
using _ROPE_HEAD_DIM = Int<kRopeHeadDim>;

@@ -99,12 +99,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
}

// Gmem
// (BLK_M, BLK_K, STAGES)
// (BLK_M, BLK_K, STEPS)
Tensor gQ =
local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
Tensor gO =
local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
// (BLK_N, BLK_K, n, STAGES)
// (BLK_N, BLK_K, n, STEPS)
Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));

// (BLK_M, ROPE_HEAD_DIM)
@@ -123,24 +123,24 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});
float* row_sync_smem = (float*)(k_rope_smem + cosize(SmemLayoutKRope{}));

// (BLK_M, BLK_K, STAGES), k-major
// (BLK_M, BLK_K, STEPS), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
// (BLK_N, BLK_K, STAGES), k-major
// (BLK_N, BLK_K, STEPS, STAGES), k-major
Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});

// (BLK_M, BLK_N), k-major
Tensor sP = make_tensor(make_smem_ptr(p_smem), SmemLayoutP{});

// (BLK_M, ROPE_HEAD_DIM), k-major
Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
// (BLK_N, ROPE_HEAD_DIM), k-major
// (BLK_N, ROPE_HEAD_DIM, STAGES), k-major
Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{});

// Tensor for V^t; used in GEMM-II.
// (BLK_K, BLK_N, STAGES)
// (BLK_K, BLK_N, STEPS, STAGES)
Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{});

// (BLK_M, BLK_K, STAGES), reuse smem
// (BLK_M, BLK_K, STEPS), reuse smem
Tensor sO = make_tensor(make_smem_ptr(q_smem), SmemLayoutO{});

// (BLK_M, 2)
@@ -180,10 +180,10 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for q
GmemTiledCopyQ gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
auto produce_q = [&](int stage) {
// gQ/sQ: (BLK_M, BLK_K, STAGES)
auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, stage));
auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, stage));
auto produce_q = [&](int step) {
// gQ/sQ: (BLK_M, BLK_K, STEPS)
auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
};

@@ -199,65 +199,56 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// g2s tiled copy for kv
GmemTiledCopyKV gmem_tiled_copy_KV;
auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
auto produce_kv = [&](int ni, int stage) {
// gKV: (BLK_N, BLK_K, n, STAGES)
auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, stage));
// sK: (BLK_N, BLK_K, STAGES)
auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, stage));
auto produce_kv = [&](int ni, int step, int stage) {
// gKV: (BLK_N, BLK_K, n, STEPS)
// sK: (BLK_N, BLK_K, STEPS, STAGES)
auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
};

// g2s tiled copy for k_rope
GmemTiledCopyKRope gmem_tiled_copy_K_rope;
auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);
Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope);
auto produce_k_rope = [&](int ni) {
auto produce_k_rope = [&](int ni, int stage) {
// gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
// sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
};

// GEMM-I: S = Q@K.T
TiledMma_QK tiled_mma_qk;
auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
// sQ/sK: (BLK_M, BLK_K, STAGES)
// sQ: (BLK_M, BLK_K, STEPS)
auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{}));
auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}));
auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope);
// sK: (BLK_N, BLK_K, STEPS, STAGES)
auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}, _0{}));

// s2r tiled copy for q/q_rope
SmemTiledCopyQ smem_tiled_copy_Q;
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
// (CPY, CPY_M, CPY_K, STAGES)
// (CPY, CPY_M, CPY_K, STEPS)
auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
// (CPY, CPY_M, CPY_K)
auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ);

// (CPY, CPY_M, CPY_K)
auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
// (CPY, CPY_M, CPY_K)
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);

// s2r tiled copy for k/k_rope
SmemTiledCopyK smem_tiled_copy_K;
auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
// (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsK = smem_thr_copy_K.partition_S(sK);
// (CPY, CPY_N, CPY_K)
auto tCrK = smem_thr_copy_K.retile_D(tSrK);

// (CPY, CPY_N, CPY_K)
auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
// (CPY, CPY_N, CPY_K)
auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);

// S = Q@K.T
// tSrS: (MMA,MMA_M,MMA_N)
auto compute_qk = [&](auto& tSrS, int s) {
// (CPY, CPY_M, CPY_K, STAGES)
auto tCsQ_s = tCsQ(_, _, _, s);
auto tCsK_s = tCsK(_, _, _, s);
auto compute_qk = [&](auto& tSrS, int step, int stage) {
// tCsQ: (CPY, CPY_M, CPY_K, STEPS)
auto tCsQ_s = tCsQ(_, _, _, step);
// tCsK: (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsK_s = tCsK(_, _, _, step, stage);
// prefetch the first k-tile of q/k from smem into registers
cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{}));
cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{}));
@@ -274,9 +265,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
}
};

auto compute_qk_rope = [&](auto& tSrS) {
// sQ_rope: (BLK_N, ROPE_HEAD_DIM)
auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
// sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{}));
// (CPY, CPY_M, CPY_K)
auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
// (CPY, CPY_M, CPY_K)
auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
// (CPY, CPY_N, CPY_K, STAGES)
auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
// (CPY, CPY_N, CPY_K)
auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
auto compute_qk_rope = [&](auto& tSrS, int stage) {
auto tCsK_rope_s = tCsK_rope(_, _, _, stage);
cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{}));
cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{}));
cute::copy(
smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{}));

CUTE_UNROLL
for (int k = 0; k < size<2>(tCsQ_rope); ++k) {
@@ -286,7 +291,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
tCsQ_rope(_, _, next_k),
tCrQ_rope(_, _, next_k));
cute::copy(smem_tiled_copy_K,
tCsK_rope(_, _, next_k),
tCsK_rope_s(_, _, next_k),
tCrK_rope(_, _, next_k));
}
cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS);
@@ -298,8 +303,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
// sP: (BLK_M, BLK_N)
auto tOrP = thr_mma_pv.partition_fragment_A(sP);
// sVt: (BLK_K, BLK_N, STAGES)
auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}));
// sVt: (BLK_K, BLK_N, STEPS, STAGES)
auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{}));

// s2r tiled copy for p
SmemTiledCopyP smem_tiled_copy_P;
@@ -312,16 +317,17 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
// s2r tiled copy for vt
SmemTiledCopyVt smem_tiled_copy_Vt;
auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
// (CPY, CPY_N, CPY_K, STAGES)
// (CPY, CPY_N, CPY_K, STEPS, STAGES)
auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
// (CPY, CPY_N, CPY_K)
auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);

// O = P*V = softmax(S)*V
// tOrO: (MMA,MMA_M,MMA_K,STAGES)
auto compute_pv = [&](auto& tOrO, int s) {
auto tOrO_s = tOrO(_, _, _, s);
auto tCsVt_s = tCsVt(_, _, _, s);
// tOrO: (MMA,MMA_M,MMA_K,STEPS)
auto compute_pv = [&](auto& tOrO, int step, int stage) {
auto tOrO_s = tOrO(_, _, _, step);
auto tCsVt_s = tCsVt(_, _, _, step, stage);

cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));

Expand Down Expand Up @@ -350,16 +356,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(smem_tiled_copy_S, tCrS, tCsS);
};

// tOrO: (MMA,MMA_M,MMA_K,STAGES)
// tOrO: (MMA,MMA_M,MMA_K,STEPS)
auto epilogue = [&](const auto& tOrO) {
// write output to gmem
// 1. copy output from reg to smem (reuse sQ)
SmemTiledCopyO smem_tiled_copy_O;
auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
auto tOrO_s = tOrO(_, _, _, s);
auto sO_s = sO(_, _, s);
for (int step = 0; step < kSteps; ++step) {
auto tOrO_s = tOrO(_, _, _, step);
auto sO_s = sO(_, _, step);

// cast Accumulator to Element type
auto tOrO_ = make_tensor_like<DType>(tOrO_s);
Expand All @@ -381,29 +387,48 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
};

// output accumulator: (MMA,MMA_M,MMA_K,STAGES)
// output accumulator: (MMA,MMA_M,MMA_K,STEPS)
auto tOrO =
partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STEPS>{});
auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
clear(tOrO);

const int n_block_min = 0;
const int n_block_max = cute::ceil_div(size<0>(KV), kBlockN);

// ############### Prologue ###############
// produce q_rope: [] => [q_rope, q...]
// g2s async data copy pipelines
// | stage | queue |
// | 1 | [k_r, kv0, kv1] |
// | 2 | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)] |
// | 3 | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)*2] |
// ^ kWait = (kSteps + 1) * (2*kStages - 1) - 1
constexpr int kWait = (kSteps + 1) * (2 * kStages - 1) - 1;
// produce q_rope/q
produce_q_rope();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
produce_q(s);
for (int step = 0; step < kSteps; ++step) {
produce_q(step);
}
// produce k_rope: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
produce_k_rope(0);
cp_async_fence();
// produce k_rope/kv
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
produce_kv(0, s);
for (int stage = 0; stage < kStages; ++stage) {
// insert nops between stages for a perfect pipeline
if (stage != 0) {
cp_async_fence();
CUTE_UNROLL
for (int step = 0; step < kSteps; ++step) {
cp_async_fence();
}
}

produce_k_rope(stage, stage);
cp_async_fence();
CUTE_UNROLL
for (int step = 0; step < kSteps; ++step) {
produce_kv(stage, step, stage);
cp_async_fence();
}
}

// ############### Mainloop ###############
@@ -421,25 +446,26 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
Softmax softmax(params.sm_scale_log2);
Mask mask(q_len, kv_len, group_size, sliding_window);

int stage = 0;

CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
clear(tSrS);

// wait key, queue: [q, q_rope, kv, k_rope] => []
cp_async_wait<kStages>();
cp_async_wait<kWait>();
__syncthreads();

// 1> S = Q_rope@K_rope.T
compute_qk_rope(tSrS);
compute_qk_rope(tSrS, stage);
cp_async_fence();

// 2> S += Q@K.T
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
cp_async_wait<kStages>();
for (int step = 0; step < kSteps; ++step) {
cp_async_wait<kWait>();
__syncthreads();

compute_qk(tSrS, s);
compute_qk(tSrS, step, stage);
cp_async_fence();
}

@@ -452,24 +478,37 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
__syncthreads();

// 3> O = softmax(S)*V
const auto next_ni = ni + 1;
const auto next_ni = ni + kStages;
if (next_ni != n_block_max) {
produce_k_rope(next_ni);
produce_k_rope(next_ni, stage);
cp_async_fence();

CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
for (int step = 0; step < kSteps; ++step) {
compute_pv(tOrO, step, stage);

__syncthreads();

produce_kv(next_ni, s);
produce_kv(next_ni, step, stage);
cp_async_fence();
}
} else {
cp_async_fence();
CUTE_UNROLL
for (int s = 0; s < kStages; ++s) {
compute_pv(tOrO, s);
for (int step = 0; step < kSteps; ++step) {
compute_pv(tOrO, step, stage);
cp_async_fence();
}
}

// move to next stage
if constexpr (kStages == 1) {
// do nothing
} else if constexpr (kStages == 2) {
stage = stage ^ 1;
} else {
stage = (stage + 1) % kStages;
}
}

// ############### Epilogue ###############
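A note on the prologue bookkeeping: the kWait constant, kWait = (kSteps + 1) * (2*kStages - 1) - 1, is sized so that, under a straightforward reading of the queue table in the prologue comment, the first cp_async_wait<kWait>() in the main loop only guarantees the oldest copy group (stage 0's k_rope tile) has landed, and each subsequent wait, issued after one more fence, retires exactly one more group. A small host-side check of that arithmetic, illustrative only and not part of the PR:

// ----- illustration only: not part of the diff above -----
// Sanity check for the kWait fence accounting (assumed reading of the
// prologue comment; the kSteps/kStages values swept below are illustrative).
#include <cassert>
#include <cstdio>

// Copy groups committed by the prologue: each stage commits (kSteps + 1)
// groups (k_rope + kSteps kv tiles), and (kSteps + 1) empty "nop" groups are
// committed between consecutive stages to keep the queue depth uniform.
constexpr int prologue_groups(int kSteps, int kStages) {
  return (kSteps + 1) * (kStages + (kStages - 1));
}

constexpr int kwait(int kSteps, int kStages) {
  return (kSteps + 1) * (2 * kStages - 1) - 1;
}

int main() {
  for (int stages = 1; stages <= 3; ++stages) {
    for (int steps = 1; steps <= 4; ++steps) {
      const int total = prologue_groups(steps, stages);
      const int wait = kwait(steps, stages);
      // cp_async_wait<wait>() lets at most `wait` groups stay pending, so at
      // least (total - wait) of the oldest groups must have completed. With
      // this formula that difference is exactly 1: the first wait in the main
      // loop only requires stage 0's k_rope copy to be done, and every later
      // wait (one more fence having been committed in between) retires one
      // more group in lock-step.
      assert(total - wait == 1);
      printf("kSteps=%d kStages=%d: prologue groups=%d, kWait=%d\n",
             steps, stages, total, wait);
    }
  }
  return 0;
}
// ----- end illustration -----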