kernel: added stage support for MLA kernel
Showing 5 changed files with 134 additions and 96 deletions.
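
In short: before this change the kernel used a single index, called "stages", both to split the head dimension into BLK_K-wide chunks and to pace the asynchronous copies. The commit separates the two concepts: STEPS now names the head-dim split, while STAGES becomes the depth of the cp.async multi-buffering pipeline, so the KV tiles of several n-blocks can be in flight at once. A minimal sketch of the resulting shared-memory shape, with illustrative sizes (the real values come from the Traits types, which are defined in one of the other changed files and not shown here):

    #include <cute/tensor.hpp>
    using namespace cute;

    // Illustrative sizes only; the kernel reads these from Traits.
    constexpr int kBlockN = 64;   // rows of one KV tile
    constexpr int kBlockK = 128;  // width of one head-dim chunk
    constexpr int kSteps  = 4;    // head-dim chunks per tile
    constexpr int kStages = 2;    // cp.async buffers (double buffering)

    // Before: sK was (BLK_N, BLK_K, STAGES), one slot per head-dim chunk.
    // After:  sK is (BLK_N, BLK_K, STEPS, STAGES), chunks x pipeline buffers.
    using SmemShapeKV =
        Shape<Int<kBlockN>, Int<kBlockK>, Int<kSteps>, Int<kStages>>;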
@@ -13,7 +13,6 @@
 #include "mask.h"
 #include "mla_tile.h"
 #include "online_softmax.cuh"
-#include "ptx.cuh"
 
 namespace llm {
 
@@ -31,14 +30,15 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   constexpr int kBlockN = Traits::kBlockN;
   constexpr int kBlockK = Traits::kBlockK;
   constexpr int kHeadDim = Traits::kHeadDim;
+  constexpr int kSteps = Traits::kSteps;
   constexpr int kStages = Traits::kStages;
   constexpr int kRopeHeadDim = Traits::kRopeHeadDim;
   constexpr int kRowsPerMMA = Traits::kRowsPerMMA;
 
   using _BLK_M = Int<kBlockM>;
   using _BLK_N = Int<kBlockN>;
   using _BLK_K = Int<kBlockK>;
-  using _STAGES = Int<kStages>;
+  using _STEPS = Int<kSteps>;
   using _HEAD_DIM = Int<kHeadDim>;
   using _ROPE_HEAD_DIM = Int<kRopeHeadDim>;
 
@@ -99,12 +99,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   }
 
   // Gmem
-  // (BLK_M, BLK_K, STAGES)
+  // (BLK_M, BLK_K, STEPS)
   Tensor gQ =
       local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
   Tensor gO =
       local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
-  // (BLK_N, BLK_K, n, STAGES)
+  // (BLK_N, BLK_K, n, STEPS)
   Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));
 
   // (BLK_M, ROPE_HEAD_DIM)
@@ -123,24 +123,24 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});
   float* row_sync_smem = (float*)(k_rope_smem + cosize(SmemLayoutKRope{}));
 
-  // (BLK_M, BLK_K, STAGES), k-major
+  // (BLK_M, BLK_K, STEPS), k-major
   Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
-  // (BLK_N, BLK_K, STAGES), k-major
+  // (BLK_N, BLK_K, STEPS, STAGES), k-major
   Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});
 
   // (BLK_M, BLK_N), k-major
   Tensor sP = make_tensor(make_smem_ptr(p_smem), SmemLayoutP{});
 
   // (BLK_M, ROPE_HEAD_DIM), k-major
   Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
-  // (BLK_N, ROPE_HEAD_DIM), k-major
+  // (BLK_N, ROPE_HEAD_DIM, STAGES), k-major
   Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{});
 
   // Tensor for V^t; used in GEMM-II.
-  // (BLK_K, BLK_N, STAGES)
+  // (BLK_K, BLK_N, STEPS, STAGES)
   Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{});
 
-  // (BLK_M, BLK_K, STAGES), reuse smem
+  // (BLK_M, BLK_K, STEPS), reuse smem
   Tensor sO = make_tensor(make_smem_ptr(q_smem), SmemLayoutO{});
 
   // (BLK_M, 2)
@@ -180,10 +180,10 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // g2s tiled copy for q
   GmemTiledCopyQ gmem_tiled_copy_Q;
   auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
-  auto produce_q = [&](int stage) {
-    // gQ/sQ: (BLK_M, BLK_K, STAGES)
-    auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, stage));
-    auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, stage));
+  auto produce_q = [&](int step) {
+    // gQ/sQ: (BLK_M, BLK_K, STEPS)
+    auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
+    auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
     cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
   };
 
@@ -199,65 +199,56 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // g2s tiled copy for kv
   GmemTiledCopyKV gmem_tiled_copy_KV;
   auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
-  auto produce_kv = [&](int ni, int stage) {
-    // gKV: (BLK_N, BLK_K, n, STAGES)
-    auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, stage));
-    // sK: (BLK_N, BLK_K, STAGES)
-    auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, stage));
+  auto produce_kv = [&](int ni, int step, int stage) {
+    // gKV: (BLK_N, BLK_K, n, STEPS)
+    // sK: (BLK_N, BLK_K, STEPS, STAGES)
+    auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
+    auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
     cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
   };
 
   // g2s tiled copy for k_rope
   GmemTiledCopyKRope gmem_tiled_copy_K_rope;
   auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);
-  Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope);
-  auto produce_k_rope = [&](int ni) {
+  auto produce_k_rope = [&](int ni, int stage) {
     // gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
+    // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
     auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
+    Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
     cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
   };
 
   // GEMM-I: S = Q@K.T
   TiledMma_QK tiled_mma_qk;
   auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
-  // sQ/sK: (BLK_M, BLK_K, STAGES)
+  // sQ: (BLK_M, BLK_K, STEPS)
   auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{}));
-  auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}));
-  auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
-  auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope);
+  // sK: (BLK_N, BLK_K, STEPS, STAGES)
+  auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}, _0{}));
 
   // s2r tiled copy for q/q_rope
   SmemTiledCopyQ smem_tiled_copy_Q;
   auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
-  // (CPY, CPY_M, CPY_K, STAGES)
+  // (CPY, CPY_M, CPY_K, STEPS)
   auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
   // (CPY, CPY_M, CPY_K)
   auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ);
 
-  // (CPY, CPY_M, CPY_K)
-  auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
-  // (CPY, CPY_M, CPY_K)
-  auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
-
   // s2r tiled copy for k/k_rope
   SmemTiledCopyK smem_tiled_copy_K;
   auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
+  // (CPY, CPY_N, CPY_K, STEPS, STAGES)
   auto tCsK = smem_thr_copy_K.partition_S(sK);
   // (CPY, CPY_N, CPY_K)
   auto tCrK = smem_thr_copy_K.retile_D(tSrK);
 
-  // (CPY, CPY_N, CPY_K)
-  auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
-  // (CPY, CPY_N, CPY_K)
-  auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
-
   // S = Q@K.T
   // tSrS: (MMA,MMA_M,MMA_N)
-  auto compute_qk = [&](auto& tSrS, int s) {
-    // (CPY, CPY_M, CPY_K, STAGES)
-    auto tCsQ_s = tCsQ(_, _, _, s);
-    auto tCsK_s = tCsK(_, _, _, s);
+  auto compute_qk = [&](auto& tSrS, int step, int stage) {
+    // tCsQ: (CPY, CPY_M, CPY_K, STEPS)
+    auto tCsQ_s = tCsQ(_, _, _, step);
+    // tCsK: (CPY, CPY_N, CPY_K, STEPS, STAGES)
+    auto tCsK_s = tCsK(_, _, _, step, stage);
     // prefetch kv
     cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{}));
     cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{}));
 
@@ -274,9 +265,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     }
   };
 
-  auto compute_qk_rope = [&](auto& tSrS) {
+  // sQ_rope: (BLK_M, ROPE_HEAD_DIM)
+  auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
+  // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
+  auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{}));
+  // (CPY, CPY_M, CPY_K)
+  auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
+  // (CPY, CPY_M, CPY_K)
+  auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
+  // (CPY, CPY_N, CPY_K, STAGES)
+  auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
+  // (CPY, CPY_N, CPY_K)
+  auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
+  auto compute_qk_rope = [&](auto& tSrS, int stage) {
+    auto tCsK_rope_s = tCsK_rope(_, _, _, stage);
     cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{}));
-    cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{}));
+    cute::copy(
+        smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{}));
 
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCsQ_rope); ++k) {
@@ -286,7 +291,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
                  tCsQ_rope(_, _, next_k),
                  tCrQ_rope(_, _, next_k));
       cute::copy(smem_tiled_copy_K,
-                 tCsK_rope(_, _, next_k),
+                 tCsK_rope_s(_, _, next_k),
                  tCrK_rope(_, _, next_k));
     }
     cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS);
 
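Note the pattern in this hunk and the one above: compute_qk and compute_qk_rope issue the smem-to-register copy for slice next_k before running the MMA on slice k, so shared-memory loads overlap the tensor-core math. A self-contained sketch of that prefetch-one-ahead loop, with hypothetical copy_slice/gemm_slice stand-ins for the cute::copy and cute::gemm calls:

    // Hypothetical stand-ins for the s2r copy and the MMA on one k-slice.
    __device__ void copy_slice(int k) { /* cute::copy of slice k to registers */ }
    __device__ void gemm_slice(int k) { /* cute::gemm on register slice k */ }

    // Prefetch one slice ahead of the MMA.
    __device__ void pipelined_gemm(int num_slices) {
      copy_slice(0);  // prime the pipeline with the first slice
      for (int k = 0; k < num_slices; ++k) {
        if (k + 1 < num_slices) {
          copy_slice(k + 1);  // issue the next load before computing on k
        }
        gemm_slice(k);  // MMA overlaps the outstanding smem loads
      }
    }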
@@ -298,8 +303,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
   // sP: (BLK_M, BLK_N)
   auto tOrP = thr_mma_pv.partition_fragment_A(sP);
-  // sVt: (BLK_K, BLK_N, STAGES)
-  auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}));
+  // sVt: (BLK_K, BLK_N, STEPS, STAGES)
+  auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{}));
 
   // s2r tiled copy for p
   SmemTiledCopyP smem_tiled_copy_P;
 
@@ -312,16 +317,17 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // s2r tiled copy for vt
   SmemTiledCopyVt smem_tiled_copy_Vt;
   auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
+  // (CPY, CPY_N, CPY_K, STEPS, STAGES)
   auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
   // (CPY, CPY_N, CPY_K)
   auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);
 
   // O = P*V = softmax(S)*V
-  // tOrO: (MMA,MMA_M,MMA_K,STAGES)
-  auto compute_pv = [&](auto& tOrO, int s) {
-    auto tOrO_s = tOrO(_, _, _, s);
-    auto tCsVt_s = tCsVt(_, _, _, s);
+  // tOrO: (MMA,MMA_M,MMA_K,STEPS)
+  auto compute_pv = [&](auto& tOrO, int step, int stage) {
+    auto tOrO_s = tOrO(_, _, _, step);
+    auto tCsVt_s = tCsVt(_, _, _, step, stage);
 
     cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
     cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));
 
@@ -350,16 +356,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     cute::copy(smem_tiled_copy_S, tCrS, tCsS);
   };
 
-  // tOrO: (MMA,MMA_M,MMA_K,STAGES)
+  // tOrO: (MMA,MMA_M,MMA_K,STEPS)
   auto epilogue = [&](const auto& tOrO) {
     // write output to gmem
     // 1. copy output from reg to smem (reuse sQ)
     SmemTiledCopyO smem_tiled_copy_O;
     auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
     CUTE_UNROLL
-    for (int s = 0; s < kStages; ++s) {
-      auto tOrO_s = tOrO(_, _, _, s);
-      auto sO_s = sO(_, _, s);
+    for (int step = 0; step < kSteps; ++step) {
+      auto tOrO_s = tOrO(_, _, _, step);
+      auto sO_s = sO(_, _, step);
 
       // cast Accumulator to Element type
       auto tOrO_ = make_tensor_like<DType>(tOrO_s);
 
@@ -381,29 +387,32 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
   };
 
-  // output accumulator: (MMA,MMA_M,MMA_K,STAGES)
+  // output accumulator: (MMA,MMA_M,MMA_K,STEPS)
   auto tOrO =
-      partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
+      partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STEPS>{});
   auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
   clear(tOrO);
 
   const int n_block_min = 0;
   const int n_block_max = cute::ceil_div(size<0>(KV), kBlockN);
 
   // ############### Prologue ###############
-  // produce q_rope: [] => [q_rope, q...]
+  // produce q_rope/q: [] => [q_rope, q...]
   produce_q_rope();
   CUTE_UNROLL
-  for (int s = 0; s < kStages; ++s) {
-    produce_q(s);
+  for (int step = 0; step < kSteps; ++step) {
+    produce_q(step);
   }
-  // produce k_rope: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
-  produce_k_rope(0);
-  cp_async_fence();
+  // produce k_rope/kv: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
   CUTE_UNROLL
-  for (int s = 0; s < kStages; ++s) {
-    produce_kv(0, s);
-    cp_async_fence();
-  }
+  for (int stage = 0; stage < kStages; ++stage) {
+    produce_k_rope(stage, stage);
+    cp_async_fence();
+    CUTE_UNROLL
+    for (int step = 0; step < kSteps; ++step) {
+      produce_kv(stage, step, stage);
+      cp_async_fence();
+    }
+  }
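
After this prologue, kStages * (kSteps + 1) cp.async commit groups are outstanding: per stage, one group for k_rope (the very first group also carries the q/q_rope copies, which were issued before any fence) plus one group per kv step. A self-contained sketch of the producer side under those assumptions, with hypothetical produce_* stand-ins for the kernel's lambdas:

    #include <cute/arch/copy_sm80.hpp>

    // Hypothetical stand-ins for the kernel's produce_* lambdas.
    __device__ void produce_k_rope(int ni, int stage) { /* issue cp.async */ }
    __device__ void produce_kv(int ni, int step, int stage) { /* issue cp.async */ }

    template <int kStages, int kSteps>
    __device__ void prologue() {
      for (int stage = 0; stage < kStages; ++stage) {
        produce_k_rope(stage, stage);  // block `stage` into buffer `stage`
        cute::cp_async_fence();        // commit group: [k_rope]
        for (int step = 0; step < kSteps; ++step) {
          produce_kv(stage, step, stage);  // one head-dim chunk of KV
          cute::cp_async_fence();          // commit group: [one kv step]
        }
      }
      // kStages * (kSteps + 1) groups are now in flight.
    }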
 
   // ############### Mainloop ###############
 
@@ -421,25 +430,29 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   Softmax softmax(params.sm_scale_log2);
   Mask mask(q_len, kv_len, group_size, sliding_window);
 
+  constexpr int kWait = kStages * (kSteps + 1) - 1;
+  int stage = 0;
+
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
     clear(tSrS);
 
-    // wait key, queue: [q, q_rope, kv, k_rope] => []
-    cp_async_wait<kStages>();
+    // wait queue: [q_rope, q..., (k_rope, kv...), (k_rope, kv...)]
+    //          => [kv..., (k_rope, kv...)]
+    cp_async_wait<kWait>();
     __syncthreads();
 
     // 1> S = Q_rope@K_rope.T
-    compute_qk_rope(tSrS);
+    compute_qk_rope(tSrS, stage);
     cp_async_fence();
 
     // 2> S += Q@K.T
     CUTE_UNROLL
-    for (int s = 0; s < kStages; ++s) {
-      cp_async_wait<kStages>();
+    for (int step = 0; step < kSteps; ++step) {
+      cp_async_wait<kWait>();
       __syncthreads();
 
-      compute_qk(tSrS, s);
+      compute_qk(tSrS, step, stage);
       cp_async_fence();
     }
 
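The new wait depth follows from the prologue's accounting. The fences issued right after compute_qk_rope and compute_qk commit empty groups (no cp.async is in flight at those points), which keeps the fence and wait counts aligned: cp_async_wait<kWait> always leaves the freshly prefetched stages in flight while guaranteeing that the group the next compute step consumes has landed. A worked instance for the first wait of the loop, assuming kStages = 2 and kSteps = 2:

    // Assumed sizes, for the arithmetic only.
    constexpr int kStages = 2;
    constexpr int kSteps = 2;

    // The prologue commits kStages * (kSteps + 1) = 6 cp.async groups.
    // cp_async_wait<kWait> blocks until at most kWait groups are still
    // outstanding, so the first wait retires the oldest group: stage 0's
    // k_rope, which also carries the q/q_rope copies the rope GEMM needs.
    constexpr int kWait = kStages * (kSteps + 1) - 1;
    static_assert(kWait == 5, "first wait retires exactly the oldest group");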
@@ -452,24 +465,37 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     __syncthreads();
 
     // 3> O = softmax(S)*V
-    const auto next_ni = ni + 1;
+    const auto next_ni = ni + kStages;
     if (next_ni != n_block_max) {
-      produce_k_rope(next_ni);
+      produce_k_rope(next_ni, stage);
       cp_async_fence();
 
       CUTE_UNROLL
-      for (int s = 0; s < kStages; ++s) {
-        compute_pv(tOrO, s);
+      for (int step = 0; step < kSteps; ++step) {
+        compute_pv(tOrO, step, stage);
 
         __syncthreads();
 
-        produce_kv(next_ni, s);
+        produce_kv(next_ni, step, stage);
         cp_async_fence();
       }
     } else {
       cp_async_fence();
       CUTE_UNROLL
-      for (int s = 0; s < kStages; ++s) {
-        compute_pv(tOrO, s);
+      for (int step = 0; step < kSteps; ++step) {
+        compute_pv(tOrO, step, stage);
         cp_async_fence();
       }
     }
 
+    // move to next stage
+    if constexpr (kStages == 1) {
+      // do nothing
+    } else if constexpr (kStages == 2) {
+      stage = stage ^ 1;
+    } else {
+      stage = (stage + 1) % kStages;
+    }
   }
 
   // ############### Epilogue ###############
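
One note on the stage-advance branch at the end of the main loop: for the common two-stage case, an XOR ping-pongs between buffers 0 and 1 and avoids the modulo. Restated as a standalone helper (same logic as the diff; the function name is hypothetical):

    // Standalone restatement of the diff's stage advance.
    template <int kStages>
    __device__ __forceinline__ int next_stage(int stage) {
      if constexpr (kStages == 1) {
        return stage;                  // single buffer: nothing to rotate
      } else if constexpr (kStages == 2) {
        return stage ^ 1;              // ping-pong between 0 and 1
      } else {
        return (stage + 1) % kStages;  // general rotation
      }
    }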