diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh
index 4ea7b477..42a675e9 100644
--- a/src/kernels/attention/mla_kernel_sm80.cuh
+++ b/src/kernels/attention/mla_kernel_sm80.cuh
@@ -13,7 +13,6 @@
 #include "mask.h"
 #include "mla_tile.h"
 #include "online_softmax.cuh"
-#include "ptx.cuh"
 
 namespace llm {
 
@@ -31,6 +30,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   constexpr int kBlockN = Traits::kBlockN;
   constexpr int kBlockK = Traits::kBlockK;
   constexpr int kHeadDim = Traits::kHeadDim;
+  constexpr int kSteps = Traits::kSteps;
   constexpr int kStages = Traits::kStages;
   constexpr int kRopeHeadDim = Traits::kRopeHeadDim;
   constexpr int kRowsPerMMA = Traits::kRowsPerMMA;
@@ -38,7 +38,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   using _BLK_M = Int<kBlockM>;
   using _BLK_N = Int<kBlockN>;
   using _BLK_K = Int<kBlockK>;
-  using _STAGES = Int<kStages>;
+  using _STEPS = Int<kSteps>;
   using _HEAD_DIM = Int<kHeadDim>;
   using _ROPE_HEAD_DIM = Int<kRopeHeadDim>;
@@ -99,12 +99,12 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   }
 
   // Gmem
-  // (BLK_M, BLK_K, STAGES)
+  // (BLK_M, BLK_K, STEPS)
   Tensor gQ =
       local_tile(Q, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
   Tensor gO =
       local_tile(O, Shape<_BLK_M, _BLK_K>{}, make_coord(m_block_idx, _));
-  // (BLK_N, BLK_K, n, STAGES)
+  // (BLK_N, BLK_K, n, STEPS)
   Tensor gKV = local_tile(KV, Shape<_BLK_N, _BLK_K>{}, make_coord(_, _));
 
   // (BLK_M, ROPE_HEAD_DIM)
@@ -123,9 +123,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{});
   float* row_sync_smem = (float*)(k_rope_smem + cosize(SmemLayoutKRope{}));
 
-  // (BLK_M, BLK_K, STAGES), k-major
+  // (BLK_M, BLK_K, STEPS), k-major
   Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
-  // (BLK_N, BLK_K, STAGES), k-major
+  // (BLK_N, BLK_K, STEPS, STAGES), k-major
   Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{});
 
   // (BLK_M, BLK_N), k-major
@@ -133,14 +133,14 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
 
   // (BLK_M, ROPE_HEAD_DIM), k-major
   Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{});
-  // (BLK_N, ROPE_HEAD_DIM), k-major
+  // (BLK_N, ROPE_HEAD_DIM, STAGES), k-major
   Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{});
 
   // Tensor for V^t; used in GEMM-II.
-  // (BLK_K, BLK_N, STAGES)
+  // (BLK_K, BLK_N, STEPS, STAGES)
   Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{});
 
-  // (BLK_M, BLK_K, STAGES), reuse smem
+  // (BLK_M, BLK_K, STEPS), reuse smem
   Tensor sO = make_tensor(make_smem_ptr(q_smem), SmemLayoutO{});
 
   // (BLK_M, 2)
@@ -180,10 +180,10 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // g2s tiled copy for q
   GmemTiledCopyQ gmem_tiled_copy_Q;
   auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_slice(tidx);
-  auto produce_q = [&](int stage) {
-    // gQ/sQ: (BLK_M, BLK_K, STAGES)
-    auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, stage));
-    auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, stage));
+  auto produce_q = [&](int step) {
+    // gQ/sQ: (BLK_M, BLK_K, STEPS)
+    auto tCgQ = gmem_thr_copy_Q.partition_S(gQ(_, _, step));
+    auto tCsQ = gmem_thr_copy_Q.partition_D(sQ(_, _, step));
     cute::copy(gmem_tiled_copy_Q, tCgQ, tCsQ);
   };
 
@@ -199,65 +199,56 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // g2s tiled copy for kv
   GmemTiledCopyKV gmem_tiled_copy_KV;
   auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
-  auto produce_kv = [&](int ni, int stage) {
-    // gKV: (BLK_N, BLK_K, n, STAGES)
-    auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, stage));
-    // sK: (BLK_N, BLK_K, STAGES)
-    auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, stage));
+  auto produce_kv = [&](int ni, int step, int stage) {
+    // gKV: (BLK_N, BLK_K, n, STEPS)
+    // sK: (BLK_N, BLK_K, STEPS, STAGES)
+    auto tCgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni, step));
+    auto tCsKV = gmem_thr_copy_KV.partition_D(sK(_, _, step, stage));
     cute::copy(gmem_tiled_copy_KV, tCgKV, tCsKV);
   };
 
   // g2s tiled copy for k_rope
   GmemTiledCopyKRope gmem_tiled_copy_K_rope;
   auto gmem_thr_copy_K_rope = gmem_tiled_copy_K_rope.get_slice(tidx);
-  Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope);
-  auto produce_k_rope = [&](int ni) {
+  auto produce_k_rope = [&](int ni, int stage) {
+    // gK_rope: (BLK_N, ROPE_HEAD_DIM, n)
+    // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
     auto tKgK_rope = gmem_thr_copy_K_rope.partition_S(gK_rope(_, _, ni));
+    Tensor tKsK_rope = gmem_thr_copy_K_rope.partition_D(sK_rope(_, _, stage));
     cute::copy(gmem_tiled_copy_K_rope, tKgK_rope, tKsK_rope);
   };
 
   // GEMM-I: S = Q@K.T
   TiledMma_QK tiled_mma_qk;
   auto thr_mma_qk = tiled_mma_qk.get_slice(tidx);
-  // sQ/sK: (BLK_M, BLK_K, STAGES)
+  // sQ: (BLK_M, BLK_K, STEPS)
   auto tSrQ = thr_mma_qk.partition_fragment_A(sQ(_, _, _0{}));
-  auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}));
-  auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
-  auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope);
+  // sK: (BLK_N, BLK_K, STEPS, STAGES)
+  auto tSrK = thr_mma_qk.partition_fragment_B(sK(_, _, _0{}, _0{}));
 
   // s2r tiled copy for q/q_rope
   SmemTiledCopyQ smem_tiled_copy_Q;
   auto smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx);
-  // (CPY, CPY_M, CPY_K, STAGES)
+  // (CPY, CPY_M, CPY_K, STEPS)
   auto tCsQ = smem_thr_copy_Q.partition_S(sQ);
   // (CPY, CPY_M, CPY_K)
   auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ);
 
-  // (CPY, CPY_M, CPY_K)
-  auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
-  // (CPY, CPY_M, CPY_K)
-  auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
-
   // s2r tiled copy for k/k_rope
   SmemTiledCopyK smem_tiled_copy_K;
   auto smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
+  // (CPY, CPY_N, CPY_K, STEPS, STAGES)
   auto tCsK = smem_thr_copy_K.partition_S(sK);
   // (CPY, CPY_N, CPY_K)
   auto tCrK = smem_thr_copy_K.retile_D(tSrK);
 
-  // (CPY, CPY_N, CPY_K)
-  auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
-  // (CPY, CPY_N, CPY_K)
-  auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
-
   // S = Q@K.T
   // tSrS: (MMA,MMA_M,MMA_N)
-  auto compute_qk = [&](auto& tSrS, int s) {
-    // (CPY, CPY_M, CPY_K, STAGES)
-    auto tCsQ_s = tCsQ(_, _, _, s);
-    auto tCsK_s = tCsK(_, _, _, s);
+  auto compute_qk = [&](auto& tSrS, int step, int stage) {
+    // tCsQ: (CPY, CPY_M, CPY_K, STEPS)
+    auto tCsQ_s = tCsQ(_, _, _, step);
+    // tCsK: (CPY, CPY_N, CPY_K, STEPS, STAGES)
+    auto tCsK_s = tCsK(_, _, _, step, stage);
     // prefetch kv
     cute::copy(smem_tiled_copy_Q, tCsQ_s(_, _, _0{}), tCrQ(_, _, _0{}));
     cute::copy(smem_tiled_copy_K, tCsK_s(_, _, _0{}), tCrK(_, _, _0{}));
@@ -274,9 +265,23 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     }
   };
 
-  auto compute_qk_rope = [&](auto& tSrS) {
+  // sQ_rope: (BLK_M, ROPE_HEAD_DIM)
+  auto tSrQ_rope = thr_mma_qk.partition_fragment_A(sQ_rope);
+  // sK_rope: (BLK_N, ROPE_HEAD_DIM, STAGES)
+  auto tSrK_rope = thr_mma_qk.partition_fragment_B(sK_rope(_, _, _0{}));
+  // (CPY, CPY_M, CPY_K)
+  auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope);
+  // (CPY, CPY_M, CPY_K)
+  auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope);
+  // (CPY, CPY_N, CPY_K, STAGES)
+  auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope);
+  // (CPY, CPY_N, CPY_K)
+  auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope);
+  auto compute_qk_rope = [&](auto& tSrS, int stage) {
+    auto tCsK_rope_s = tCsK_rope(_, _, _, stage);
     cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{}));
-    cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{}));
+    cute::copy(
+        smem_tiled_copy_K, tCsK_rope_s(_, _, _0{}), tCrK_rope(_, _, _0{}));
 
     CUTE_UNROLL
     for (int k = 0; k < size<2>(tCsQ_rope); ++k) {
@@ -286,7 +291,7 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
                  tCsQ_rope(_, _, next_k),
                  tCrQ_rope(_, _, next_k));
       cute::copy(smem_tiled_copy_K,
-                 tCsK_rope(_, _, next_k),
+                 tCsK_rope_s(_, _, next_k),
                  tCrK_rope(_, _, next_k));
     }
     cute::gemm(tiled_mma_qk, tSrQ_rope(_, _, k), tSrK_rope(_, _, k), tSrS);
@@ -298,8 +303,8 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   auto thr_mma_pv = tiled_mma_pv.get_slice(tidx);
   // sP: (BLK_M, BLK_N)
   auto tOrP = thr_mma_pv.partition_fragment_A(sP);
-  // sVt: (BLK_K, BLK_N, STAGES)
-  auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}));
+  // sVt: (BLK_K, BLK_N, STEPS, STAGES)
+  auto tOrVt = thr_mma_pv.partition_fragment_B(sVt(_, _, _0{}, _0{}));
 
   // s2r tiled copy for p
   SmemTiledCopyP smem_tiled_copy_P;
@@ -312,16 +317,17 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   // s2r tiled copy for vt
   SmemTiledCopyVt smem_tiled_copy_Vt;
   auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_slice(tidx);
-  // (CPY, CPY_N, CPY_K, STAGES)
+  // (CPY, CPY_N, CPY_K, STEPS, STAGES)
   auto tCsVt = smem_thr_copy_Vt.partition_S(sVt);
   // (CPY, CPY_N, CPY_K)
   auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt);
 
   // O = P*V = softmax(S)*V
-  // tOrO: (MMA,MMA_M,MMA_K,STAGES)
-  auto compute_pv = [&](auto& tOrO, int s) {
-    auto tOrO_s = tOrO(_, _, _, s);
-    auto tCsVt_s = tCsVt(_, _, _, s);
+  // tOrO: (MMA,MMA_M,MMA_K,STEPS)
+  auto compute_pv = [&](auto& tOrO, int step, int stage) {
+    auto tOrO_s = tOrO(_, _, _, step);
+    auto tCsVt_s = tCsVt(_, _, _, step, stage);
+
     cute::copy(smem_tiled_copy_P, tCsP(_, _, _0{}), tCrP(_, _, _0{}));
     cute::copy(smem_tiled_copy_Vt, tCsVt_s(_, _, _0{}), tCrVt(_, _, _0{}));
@@ -350,16 +356,16 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     cute::copy(smem_tiled_copy_S, tCrS, tCsS);
   };
 
-  // tOrO: (MMA,MMA_M,MMA_K,STAGES)
+  // tOrO: (MMA,MMA_M,MMA_K,STEPS)
   auto epilogue = [&](const auto& tOrO) {
     // write output to gmem
     // 1. copy output from reg to smem (reuse sQ)
     SmemTiledCopyO smem_tiled_copy_O;
     auto smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx);
     CUTE_UNROLL
-    for (int s = 0; s < kStages; ++s) {
-      auto tOrO_s = tOrO(_, _, _, s);
-      auto sO_s = sO(_, _, s);
+    for (int step = 0; step < kSteps; ++step) {
+      auto tOrO_s = tOrO(_, _, _, step);
+      auto sO_s = sO(_, _, step);
 
       // cast Accumulator to Element type
       auto tOrO_ = make_tensor_like<DType>(tOrO_s);
@@ -381,9 +387,9 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     cute::copy(gmem_tiled_copy_O, tCsO, tCgO);
   };
 
-  // output accumulator: (MMA,MMA_M,MMA_K,STAGES)
+  // output accumulator: (MMA,MMA_M,MMA_K,STEPS)
   auto tOrO =
-      partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STAGES>{});
+      partition_fragment_C(tiled_mma_pv, Shape<_BLK_M, _BLK_K, _STEPS>{});
   auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_mns(tOrO.layout()));
   clear(tOrO);
 
@@ -391,19 +397,38 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   const int n_block_max = cute::ceil_div(size<0>(KV), kBlockN);
 
   // ############### Prologue ###############
-  // produce q_rope: [] => [q_rope, q...]
+  // g2s async data copy pipelines
+  // | stage |                      queue                        |
+  // |   1   | [k_r, kv0, kv1]                                   |
+  // |   2   | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)]   |
+  // |   3   | [k_r, kv0, kv1, (nop, nop, nop, k_r, kv0, kv1)*2] |
+  //           ^ kWait = (kSteps + 1) * (2*kStages - 1) - 1
+  constexpr int kWait = (kSteps + 1) * (2 * kStages - 1) - 1;
+
+  // produce q_rope/q
   produce_q_rope();
   CUTE_UNROLL
-  for (int s = 0; s < kStages; ++s) {
-    produce_q(s);
+  for (int step = 0; step < kSteps; ++step) {
+    produce_q(step);
   }
 
-  // produce k_rope: [q_rope, q...] => [q_rope, q..., k_rope, kv...]
-  produce_k_rope(0);
-  cp_async_fence();
+  // produce k_rope/kv
   CUTE_UNROLL
-  for (int s = 0; s < kStages; ++s) {
-    produce_kv(0, s);
+  for (int stage = 0; stage < kStages; ++stage) {
+    // insert nops between stages for a perfect pipeline
+    if (stage != 0) {
+      cp_async_fence();
+      CUTE_UNROLL
+      for (int step = 0; step < kSteps; ++step) {
+        cp_async_fence();
+      }
+    }
+
+    produce_k_rope(stage, stage);
     cp_async_fence();
+    CUTE_UNROLL
+    for (int step = 0; step < kSteps; ++step) {
+      produce_kv(stage, step, stage);
+      cp_async_fence();
+    }
   }
 
   // ############### Mainloop ###############
@@ -421,25 +446,26 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
   Softmax softmax(params.sm_scale_log2);
   Mask mask(q_len, kv_len, group_size, sliding_window);
 
+  int stage = 0;
+
   CUTE_NO_UNROLL
   for (int ni = n_block_min; ni < n_block_max; ++ni) {
     clear(tSrS);
 
-    // wait key, queue: [q, q_rope, kv, k_rope] => []
-    cp_async_wait<kStages>();
+    cp_async_wait<kWait>();
     __syncthreads();
 
     // 1> S = Q_rope@K_rope.T
-    compute_qk_rope(tSrS);
+    compute_qk_rope(tSrS, stage);
     cp_async_fence();
 
     // 2> S += Q@K.T
     CUTE_UNROLL
-    for (int s = 0; s < kStages; ++s) {
-      cp_async_wait<kStages>();
+    for (int step = 0; step < kSteps; ++step) {
+      cp_async_wait<kWait>();
       __syncthreads();
 
-      compute_qk(tSrS, s);
+      compute_qk(tSrS, step, stage);
       cp_async_fence();
     }
 
@@ -452,24 +478,37 @@ __global__ __launch_bounds__(Traits::kThreadNum) void mla_kernel_sm80(
     __syncthreads();
 
     // 3> O = softmax(S)*V
-    const auto next_ni = ni + 1;
+    const auto next_ni = ni + kStages;
     if (next_ni != n_block_max) {
-      produce_k_rope(next_ni);
+      produce_k_rope(next_ni, stage);
       cp_async_fence();
+
       CUTE_UNROLL
-      for (int s = 0; s < kStages; ++s) {
-        compute_pv(tOrO, s);
+      for (int step = 0; step < kSteps; ++step) {
+        compute_pv(tOrO, step, stage);
+
         __syncthreads();
-        produce_kv(next_ni, s);
+        produce_kv(next_ni, step, stage);
         cp_async_fence();
       }
     } else {
+      cp_async_fence();
       CUTE_UNROLL
-      for (int s = 0; s < kStages; ++s) {
-        compute_pv(tOrO, s);
+      for (int step = 0; step < kSteps; ++step) {
+        compute_pv(tOrO, step, stage);
+        cp_async_fence();
       }
     }
+
+    // move to next stage
+    if constexpr (kStages == 1) {
+      // do nothing
+    } else if constexpr (kStages == 2) {
+      stage = stage ^ 1;
+    } else {
+      stage = (stage + 1) % kStages;
+    }
   }
 
   // ############### Epilogue ###############
diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu
index e5bc078d..b315ac24 100644
--- a/src/kernels/attention/mla_kernel_sm80_test.cu
+++ b/src/kernels/attention/mla_kernel_sm80_test.cu
@@ -19,19 +19,22 @@ namespace llm {
     constexpr static int HEAD_DIM_NAME = 128;                        \
     constexpr static int BLK_M = 64;                                 \
     constexpr static int BLK_N = 64;                                 \
-    constexpr static int BLK_K = 64;                                 \
+    constexpr static int BLK_K = 128;                                \
+    constexpr static int STAGES = 2;                                 \
     return __VA_ARGS__();                                            \
   } else if (HEAD_DIM_V <= 256) {                                    \
     constexpr static int HEAD_DIM_NAME = 256;                        \
     constexpr static int BLK_M = 64;                                 \
     constexpr static int BLK_N = 32;                                 \
     constexpr static int BLK_K = 128;                                \
+    constexpr static int STAGES = 2;                                 \
     return __VA_ARGS__();                                            \
   } else if (HEAD_DIM_V <= 512) {                                    \
     constexpr static int HEAD_DIM_NAME = 512;                        \
     constexpr static int BLK_M = 64;                                 \
     constexpr static int BLK_N = 16;                                 \
     constexpr static int BLK_K = 128;                                \
+    constexpr static int STAGES = 1;                                 \
     return __VA_ARGS__();                                            \
   } else {                                                           \
     assert(false);                                                   \
@@ -97,7 +100,8 @@ torch::Tensor mla_sm80(
                                  ROPE_HEAD_DIM,
                                  BLK_M,
                                  BLK_N,
-                                 BLK_K>;
+                                 BLK_K,
+                                 STAGES>;
 
     launch_mla_kernel_sm80<Traits>(params, nullptr);
   });
@@ -158,7 +162,7 @@ INSTANTIATE_TEST_SUITE_P(
     ::testing::Combine(::testing::Values(torch::kHalf),      // q_dtype
                        ::testing::Values(1, 2, 4, 10),       // batch_size
                        ::testing::Values(64),                // q_len
-                       ::testing::Values(64, 128),           // kv_len
+                       ::testing::Values(64, 128, 1024),     // kv_len
                        ::testing::Values(1, 8, 128),         // n_heads
                        ::testing::Values(128, 256, 512),     // head_dim
                        ::testing::Values(64)                 // rope_head_dim
diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu
index 3f293fa6..2ab62ffd 100644
--- a/src/kernels/attention/mla_sm80_bench.cu
+++ b/src/kernels/attention/mla_sm80_bench.cu
@@ -103,7 +103,8 @@ void mla_bench_sm80(nvbench::state& state) {
                                /*ROPE_HEAD_DIM=*/64,
                                BLK_M,
                                BLK_N,
-                               BLK_K>;
+                               BLK_K,
+                               /*STAGES=*/1>;
 
     launch_mla_kernel_sm80<Traits>(params, launch.get_stream());
   });
diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h
index 05cff384..090e3aad 100644
--- a/src/kernels/attention/mla_traits_sm80.h
+++ b/src/kernels/attention/mla_traits_sm80.h
@@ -13,7 +13,7 @@ namespace detail {
 // Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN
 struct LayoutConvertor {
   // Convert fragment layout to rowcol layout for iterating
-  // (MMA=4, MMA_M, MMA_N, STAGES) => ((2, MMA_M), (2, MMA_N), STAGES)
+  // (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N))
   template <typename LayoutC>
   CUTE_HOST_DEVICE static constexpr auto to_mn(const LayoutC& layout) {
     auto l = logical_divide(layout, Shape<_2>{});
@@ -21,6 +21,7 @@ struct LayoutConvertor {
                        make_layout(get<0, 0>(l), get<2>(l)));
   }
 
+  // (MMA=4, MMA_M, MMA_N, STEPS) => ((2, MMA_M), (2, MMA_N), STEPS)
   template <typename LayoutC>
   CUTE_HOST_DEVICE static constexpr auto to_mns(const LayoutC& layout) {
     auto l = logical_divide(layout, Shape<_2>{});
@@ -37,14 +38,15 @@
 template <typename DTYPE,
           int HEAD_DIM,
          int ROPE_HEAD_DIM,
          int BLK_M,
          int BLK_N,
-          int BLK_K>
+          int BLK_K,
+          int STAGES>
 struct MLATraitsSM80 {
-  // helpful aliases
   static constexpr int kHeadDim = HEAD_DIM;
   static constexpr int kRopeHeadDim = ROPE_HEAD_DIM;
   static constexpr int kBlockM = BLK_M;
   static constexpr int kBlockN = BLK_N;
   static constexpr int kBlockK = BLK_K;
+  static constexpr int kStages = STAGES;
   static constexpr int kRowsPerMMA = 2;
 
   static_assert(kHeadDim % 64 == 0);
@@ -53,14 +55,18 @@ struct MLATraitsSM80 {
   static_assert(kBlockM % 64 == 0);
   static_assert(kBlockN % 16 == 0);
   static_assert(kBlockK % 64 == 0);
+  static_assert(kStages == 1 || kStages == 2);
   static_assert(kHeadDim % kBlockK == 0);
 
-  static constexpr int kStages = kHeadDim / kBlockK;
+  // number of steps per stage
+  static constexpr int kSteps = kHeadDim / kBlockK;
 
+  // helpful aliases
   using DType = DTYPE;
   using _BLK_M = Int<kBlockM>;
   using _BLK_N = Int<kBlockN>;
   using _BLK_K = Int<kBlockK>;
+  using _STEPS = Int<kSteps>;
   using _STAGES = Int<kStages>;
   using _HEAD_DIM = Int<kHeadDim>;
   using _ROPE_HEAD_DIM = Int<kRopeHeadDim>;
@@ -118,31 +124,31 @@ struct MLATraitsSM80 {
                                              SmemLayoutAtom_8x32>;
 
   // SMEM layout for Q/K/V/P
-  // Q smem: (BLK_M, BLK_K, STAGES)
+  // Q smem: (BLK_M, BLK_K, STEPS)
   using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomK{},
-                                             Shape<_BLK_M, _BLK_K, _STAGES>{}));
+                                             Shape<_BLK_M, _BLK_K, _STEPS>{}));
 
-  // KV smem: (BLK_N, BLK_K, STAGES)
+  // KV smem: (BLK_N, BLK_K, STEPS, STAGES)
   using SmemLayoutKV =
       decltype(tile_to_shape(SmemLayoutAtomK{},
-                             Shape<_BLK_N, _BLK_K, _STAGES>{}));
+                             Shape<_BLK_N, _BLK_K, _STEPS, _STAGES>{}));
 
   // P smem: (BLK_M, BLK_N)
   using SmemLayoutP =
       decltype(tile_to_shape(SmemLayoutAtomN{}, Shape<_BLK_M, _BLK_N>{}));
 
-  // V^T smem: (BLK_K, BLK_N, STAGES)
-  using SmemLayoutVt = decltype(permute<1, 0, 2>(SmemLayoutKV{}));
+  // V^T smem: (BLK_K, BLK_N, STEPS, STAGES)
+  using SmemLayoutVt = decltype(permute<1, 0, 2, 3>(SmemLayoutKV{}));
 
   // QRope smem: (BLK_M, ROPE_HEAD_DIM)
   using SmemLayoutQRope =
       decltype(tile_to_shape(SmemLayoutAtomR{},
                              Shape<_BLK_M, _ROPE_HEAD_DIM>{}));
 
-  // KRoep smem: (BLK_N, ROPE_HEAD_DIM)
+  // KRope smem: (BLK_N, ROPE_HEAD_DIM, STAGES)
   using SmemLayoutKRope =
       decltype(tile_to_shape(SmemLayoutAtomR{},
-                             Shape<_BLK_N, _ROPE_HEAD_DIM>{}));
+                             Shape<_BLK_N, _ROPE_HEAD_DIM, _STAGES>{}));
 
   // Shared memory for reduce between warps
   // rowmax/rowsum smem: (_BLK_M, _2)
@@ -225,9 +231,9 @@ struct MLATraitsSM80 {
   // ******* Epilogue *******
 
-  // O smem: (BLK_M, BLK_K, STAGES)
+  // O smem: (BLK_M, BLK_K, STEPS)
   using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomK{},
-                                             Shape<_BLK_M, _BLK_K, _STAGES>{}));
+                                             Shape<_BLK_M, _BLK_K, _STEPS>{}));
 
   // s2g tiled copy for O (32x64)
   using GmemTiledCopyO = decltype(make_tiled_copy(
diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp
index c094d6be..69a7132d 100644
--- a/src/kernels/attention/mla_traits_test.cpp
+++ b/src/kernels/attention/mla_traits_test.cpp
@@ -59,7 +59,8 @@ TEST(MLATraitsTest, TraitsSM80) {
                    /*ROPE_HEAD_DIM=*/64,
                    /*BLK_M=*/64,
                    /*BLK_N=*/64,
-                   /*BLK_K=*/64>>();
+                   /*BLK_K=*/64,
+                   /*STAGES=*/1>>();
 }
 
 }  // namespace llm
\ No newline at end of file
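Reviewer note (appended for review, not part of the patch): the subtlest piece of this change is the commit-group bookkeeping behind kWait = (kSteps + 1) * (2 * kStages - 1) - 1. Each cp_async_fence() commits one group into a hardware FIFO and cp_async_wait<N>() blocks until at most N groups are still in flight, so the nop groups inserted between prologue stages are exactly what lets every wait in the mainloop use the same constant depth. The host-only C++17 sketch below replays the patch's fence/wait sequence against a plain queue and asserts that each wait(kWait) retires the group whose data the next compute step consumes. All names in it (fifo, done, fence, wait, retired) are invented for this illustration, and it simplifies the loop tail by always prefetching, whereas the kernel stops producing once next_ni reaches n_block_max.

// pipeline_check.cpp -- host-only model of the cp.async commit-group
// bookkeeping in mla_kernel_sm80.cuh. Build with: g++ -std=c++17.
#include <cassert>
#include <cstdio>
#include <deque>
#include <string>
#include <vector>

int main() {
  for (int kStages : {1, 2}) {
    for (int kSteps : {1, 2, 4}) {
      const int kWait = (kSteps + 1) * (2 * kStages - 1) - 1;

      std::deque<std::string> fifo;   // in-flight commit groups, oldest first
      std::vector<std::string> done;  // retired commit groups
      // cp_async_fence(): commit everything issued so far as one group.
      auto fence = [&](const std::string& tag) { fifo.push_back(tag); };
      // cp_async_wait<n>(): block until at most n groups are in flight.
      auto wait = [&](int n) {
        while (static_cast<int>(fifo.size()) > n) {
          done.push_back(fifo.front());
          fifo.pop_front();
        }
      };
      auto retired = [&](const std::string& tag) {
        for (const auto& t : done) {
          if (t == tag) return true;
        }
        return false;
      };
      auto kr = [](int b) { return "k_rope#" + std::to_string(b); };
      auto kv = [](int b, int s) {
        return "kv#" + std::to_string(b) + "." + std::to_string(s);
      };

      // Prologue: per stage, one k_rope group plus kSteps kv groups, with
      // (kSteps + 1) empty "nop" groups inserted between stages.
      for (int stage = 0; stage < kStages; ++stage) {
        if (stage != 0) {
          for (int i = 0; i < kSteps + 1; ++i) fence("nop");
        }
        fence(kr(stage));
        for (int step = 0; step < kSteps; ++step) fence(kv(stage, step));
      }

      // Mainloop: every cp_async_wait<kWait>() must have retired exactly
      // the group the next compute step reads from shared memory.
      for (int ni = 0; ni < 8; ++ni) {
        wait(kWait);
        assert(retired(kr(ni)));  // k_rope ready: compute_qk_rope is safe
        fence("compute");         // fence issued after compute_qk_rope
        for (int step = 0; step < kSteps; ++step) {
          wait(kWait);
          assert(retired(kv(ni, step)));  // kv ready: compute_qk is safe
          fence("compute");               // fence inside the qk step loop
        }
        // Prefetch the block kStages ahead into the stage buffer that the
        // waits above just drained.
        fence(kr(ni + kStages));
        for (int step = 0; step < kSteps; ++step) {
          fence(kv(ni + kStages, step));
        }
      }
      std::printf("kStages=%d kSteps=%d kWait=%d: ok\n",
                  kStages, kSteps, kWait);
    }
  }
  return 0;
}

Running the sketch prints one ok line per (kStages, kSteps) combination; raising kWait by one makes the first assert fire, which shows the formula is tight. It also makes clear why the mainloop's stage advance can special-case kStages == 2 with stage = stage ^ 1: with the traits asserting kStages == 1 || kStages == 2, the XOR is just a cheap substitute for the modulo.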