From 4d424c0edbbe73e53aafc1e9b7b88c3549fac132 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 4 Feb 2025 16:04:14 -0800 Subject: [PATCH 1/9] kernel: added MLA kernel --- src/kernels/attention/CMakeLists.txt | 16 + src/kernels/attention/mla_kernel_sm80.cuh | 459 ++++++++++++++++++ src/kernels/attention/mla_kernel_sm80_test.cu | 112 +++++ src/kernels/attention/mla_params.h | 57 +++ src/kernels/attention/mla_ref.h | 71 +++ src/kernels/attention/mla_tile.h | 82 ++++ src/kernels/attention/mla_traits_sm80.h | 145 ++++++ 7 files changed, 942 insertions(+) create mode 100644 src/kernels/attention/mla_kernel_sm80.cuh create mode 100644 src/kernels/attention/mla_kernel_sm80_test.cu create mode 100644 src/kernels/attention/mla_params.h create mode 100644 src/kernels/attention/mla_ref.h create mode 100644 src/kernels/attention/mla_tile.h create mode 100644 src/kernels/attention/mla_traits_sm80.h diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index 5faec0be..b3dc5aec 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -17,6 +17,10 @@ cc_library( mha_traits_sm80.h mha_kernel_sm80.cuh mha_dispatch_sm80.cuh + mla_params.h + mla_tile.h + mla_traits_sm80.h + mla_kernel_sm80.cuh DEPS cutlass ) @@ -67,6 +71,18 @@ cc_test( torch ) +cc_test( + NAME + mla_kernel_test + SRCS + mla_kernel_sm80_test.cu + DEPS + :attention.template + absl::random_random + GTest::gtest_main + torch +) + nvbench_binary( NAME mha_sm80_bench diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh new file mode 100644 index 00000000..a5145da9 --- /dev/null +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -0,0 +1,459 @@ +#pragma once + +#include +#include + +#include +#include + +#include "cute/config.hpp" +#include "cute_extensions.cuh" +#include "fast_cast.cuh" +#include "mask.h" +#include "mla_tile.h" +#include "online_softmax.cuh" +#include "ptx.cuh" + +namespace llm { + +template +__global__ void mla_kernel_sm80(__grid_constant__ const Params params) { + using namespace cute; + + constexpr int kBlockM = Traits::kBlockM; + constexpr int kBlockN = Traits::kBlockN; + constexpr int kHeadDim = Traits::kHeadDim; + constexpr int kRowsPerMMA = Traits::kRowsPerMMA; + + using _BLK_M = Int; + using _BLK_N = Int; + using _HEAD_DIM = Int; + + // type alias + using DType = typename Traits::DType; + + using TiledMma = typename Traits::TiledMma; + using Layout = typename Traits::LayoutConvertor; + + using SmemLayoutQ = typename Traits::SmemLayoutQ; + using SmemLayoutK = typename Traits::SmemLayoutK; + using SmemLayoutV = typename Traits::SmemLayoutV; + using SmemLayoutVt = typename Traits::SmemLayoutVt; + using SmemLayoutO = typename Traits::SmemLayoutO; + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; + using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; + using GmemTiledCopyO = typename Traits::GmemTiledCopyO; + + using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; + using SmemTiledCopyK = typename Traits::SmemTiledCopyK; + using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; + using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + + const int m_block = blockIdx.x; + const int batch_idx = blockIdx.y; + const int kv_head_idx = blockIdx.z; + const int tidx = threadIdx.x; + + MLATile tile(params); + + // preprocess input parameters + const int head_dim = params.head_dim; + const int group_size = params.group_size; + const float logits_soft_cap = params.logits_soft_cap; + const float 
sm_scale = params.sm_scale; + const float sm_scale_log2 = params.sm_scale_log2; + + // ProblemShape + // (q_packed_len, HEAD_DIM) + auto [Q, O] = tile.template get_qo_tile(batch_idx, kv_head_idx); + // (kv_len, HEAD_DIM) + auto [K, V] = tile.template get_kv_tile(batch_idx, kv_head_idx); + + const int q_packed_len = size<0>(Q); + const int q_len = q_packed_len / group_size; + const int kv_len = size<0>(K); + + if (m_block * kBlockM >= q_packed_len) { + // m out of bound, return + return; + } + + const int sliding_window = LOCAL ? params.sliding_window : kv_len; + + // Gmem + // (BLK_M, HEAD_DIM) + Tensor gQ = + local_tile(Q, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block, _0{})); + Tensor gO = + local_tile(O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block, _0{})); + // (BLK_N, HEAD_DIM, n) + Tensor gK = local_tile(K, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); + Tensor gV = local_tile(V, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); + + // Smem + extern __shared__ char smem[]; + DType* q_smem = (DType*)smem; + DType* k_smem = q_smem + cosize(SmemLayoutQ{}); + DType* v_smem = k_smem + cosize(SmemLayoutK{}); + + // (BLK_M, HEAD_DIM), k-major + Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{}); + // (BLK_N, HEAD_DIM), k-major + Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{}); + + // Tensor for V^t; used in GEMM-II. + // (HEAD_DIM, BLK_N), m-major + Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{}); + + // Tiled Copy + // g2s tiled copy for qkv + GmemTiledCopyQ gmem_tiled_copy_Q; + GmemTiledCopyKV gmem_tiled_copy_KV; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); + auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); + + // coordinate tensor for oob handling + // (BLK_M, HEAD_DIM) -> (blk_m, head_dim) + Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{}); + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); + + auto produce_query = [&]() { + auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); + auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); + auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim); + safe_copy( + gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord); + }; + + // (BLK_N, HEAD_DIM) -> (blk_n, head_dim) + Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{}); + Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV); + + Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); + auto produce_key = [&](int ni) { + auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); + auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); + // skip ZFILL_MN for key since Mask will mask out oob with -inf + safe_copy( + gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord); + }; + + // produce key without oob handling + auto produce_key_no_oob = [&](int ni) { + auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); + auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); + safe_copy( + gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord); + }; + + Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); + auto produce_value = [&](int ni) { + auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); + auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); + // skipping ZFILL_MN for v may cause nan issue + safe_copy( + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord); + }; + + // produce value without oob handling + auto produce_value_no_oob = [&](int ni) { + auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); + auto max_coord = 
make_coord(kv_len - ni * kBlockN, head_dim); + safe_copy( + gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord); + }; + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(tidx); + // GEMM-I: S = Q@K.T + auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + // s2r tiled copy for qkv + SmemTiledCopyQ smem_tiled_copy_Q; + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + auto tSsQ = smem_thr_copy_Q.partition_S(sQ); + auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + + SmemTiledCopyK smem_tiled_copy_K; + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + auto tSsK = smem_thr_copy_K.partition_S(sK); + auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); + + // S = Q@K.T + // tSrAccS: (MMA,MMA_M,MMA_N) + auto compute_qk = [&](auto& tSrAccS) { + // prefetch kv + cute::copy(smem_tiled_copy_Q, tSsQ(_, _, _0{}), tSrQ_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{})); + + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tSrQ); ++ki) { + // prefetch next kv + if (ki != size<2>(tSrQ) - 1) { + const auto next_ki = ki + 1; + cute::copy(smem_tiled_copy_Q, + tSsQ(_, _, next_ki), + tSrQ_copy_view(_, _, next_ki)); + cute::copy(smem_tiled_copy_K, + tSsK(_, _, next_ki), + tSrK_copy_view(_, _, next_ki)); + } + cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS); + } + }; + + // GEMM-II: O = softmax(S)@V + auto tOrVt = thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N) + + SmemTiledCopyVt smem_tiled_copy_Vt; + auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx); + auto tOsVt = smem_thr_copy_Vt.partition_S(sVt); + auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt); + + // O = softmax(S)*V + // tSrAccS: (MMA,MMA_M,MMA_N) + // tOrAccO: (MMA,MMA_M,MMA_K) + auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) { + // cast scores from Accumulator to Element + auto tSrS = make_tensor_like(tSrAccS); + fast_cast(tSrAccS, tSrS); + + // convert layout from gemm-I C to gemm-II A + auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout())); + + // prefetch V^t + cute::copy( + smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{})); + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tOrS); ++ki) { + // prefetch next V^t + if (ki != size<2>(tOrS) - 1) { + const auto next_ki = ki + 1; + cute::copy(smem_tiled_copy_Vt, + tOsVt(_, _, next_ki), + tOrVt_copy_view(_, _, next_ki)); + } + cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO); + } + }; + + // tOrAccO: (MMA,MMA_M,MMA_K) + auto epilogue = [&](const auto& tOrAccO) { + // write output to gmem + // 1> cast output from ElementAccumulator to Element + auto tOrO = make_tensor_like(tOrAccO); + fast_cast(tOrAccO, tOrO); + + // 2. copy output from reg to smem (reuse sQ) + auto sO = make_tensor(sQ.data(), SmemLayoutO{}); + + SmemTiledCopyO smem_tiled_copy_O; + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + auto taccOrO = smem_thr_copy_O.retile_S(tOrO); + auto taccOsO = smem_thr_copy_O.partition_D(sO); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + // 3. 
copy output from smem to gmem + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + + // (BLK_M, HEAD_DIM) -> (blk_m, head_dim) + auto cO = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{}); + + auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) + auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) + // (CPY,CPY_M,CPY_K) -> (blk_m, head_dim) + auto tOcO = gmem_thr_copy_O.partition_D(cO); + + // wait for smem copy done before gmem copy + __syncthreads(); + + auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim); + safe_copy( + gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord); + }; + + // output accumulator, (MMA,MMA_M,MMA_K) + auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); + auto tOrAccO_rc_view = + make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout())); + clear(tOrAccO); + + const int diagonal = (m_block * kBlockM) / group_size + kv_len - q_len; + // process kv in range: [kv_idx_min, kv_idx_max) + const int kv_idx_min = std::max(0, diagonal - sliding_window); + const int kv_idx_max = std::min(kv_len, diagonal + kBlockM); + const int n_block_min = LOCAL ? kv_idx_min / kBlockN : 0; + const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN); + + if (n_block_min >= n_block_max) { + // write output to gmem + epilogue(tOrAccO); + return; + } + + // ############### Prologue ############### + int n_block_idx = n_block_max - 1; + // produce query: [] => [q] + produce_query(); + cp_async_fence(); + // produce key: [q] => [q, k] + produce_key(n_block_idx); + cp_async_fence(); + + // ############### Mainloop ############### + // attention score accumulator, (MMA,MMA_M,MMA_N) + auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + auto tSrAccS_rc_view = + make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); + + auto apply_logits_soft_cap = [&](auto& tSrAccS) { + if constexpr (SOFT_CAP) { + CUTE_UNROLL + for (int i = 0; i < size(tSrAccS); ++i) { + tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap); + } + } + }; + + constexpr int kMMA_M = size<1>(tSrAccS); + using Softmax = OnlineSoftmax; + using Mask = Mask; + + Softmax softmax(sm_scale_log2); + Mask mask(tidx, + m_block, + q_len, + kv_len, + kv_head_idx, + group_size, + sliding_window, + sm_scale, + params.alibi_slopes_ptr); + + // seperate oob mask iterations for better performance + constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1; + + // oob mask iterations + CUTE_UNROLL + for (int i = 0; i < n_oob_mask; ++i) { + clear(tSrAccS); + + // wait key, queue: [q, k] => [] + cp_async_wait<0>(); + __syncthreads(); + + // produce value, [] => [v] + if (i == 0) { + produce_value(n_block_idx); + } else { + produce_value_no_oob(n_block_idx); + } + cp_async_fence(); + + // 1> S = Q@K.T + compute_qk(tSrAccS); + + if constexpr (SOFT_CAP) { + apply_logits_soft_cap(tSrAccS); + } + mask.apply(tSrAccS_rc_view, n_block_idx); + softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view); + + // wait value, [v] => [] + cp_async_wait<0>(); + __syncthreads(); + + // produce next key: [] => [k] + if (n_block_idx > n_block_min) { + produce_key_no_oob(n_block_idx - 1); + } + cp_async_fence(); + + // 2> O = softmax(S)*V + compute_sv(tSrAccS, tOrAccO); + + --n_block_idx; + if (n_block_idx < n_block_min) { + // no more kv blocks to process + break; + } + } + + // non-oob mask iterations + CUTE_NO_UNROLL + for (; n_block_idx >= n_block_min; --n_block_idx) { + clear(tSrAccS); + + // wait key, queue: [q, k] 
=> [] + cp_async_wait<0>(); + __syncthreads(); + + // produce value, [] => [v] + produce_value_no_oob(n_block_idx); + cp_async_fence(); + + // 1> S = Q@K.T + compute_qk(tSrAccS); + + if constexpr (SOFT_CAP) { + apply_logits_soft_cap(tSrAccS); + } + mask.apply(tSrAccS_rc_view, n_block_idx); + softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view); + + // wait value, [v] => [] + cp_async_wait<0>(); + __syncthreads(); + + // produce next key: [] => [k] + if (n_block_idx > n_block_min) { + produce_key_no_oob(n_block_idx - 1); + } + cp_async_fence(); + + // 2> O = softmax(S)*V + compute_sv(tSrAccS, tOrAccO); + } + + // ############### Epilogue ############### + + // normalize output: o /= rowsum + softmax.finalize(tOrAccO_rc_view); + + // write output to gmem + epilogue(tOrAccO); +} + +template +void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { + const auto batch_size = params.batch_size; + const auto n_kv_heads = params.n_kv_heads; + const auto max_q_packed_len = params.max_q_len * params.group_size; + + const auto smem_size = Traits::kSmemSize; + auto mla_kernel = + mla_kernel_sm80; + cudaFuncSetAttribute( + mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // TODO: support persistent kernels + dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), + batch_size, + n_kv_heads); + dim3 block = Traits::kThreadNum; + mla_kernel<<>>(params); +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu new file mode 100644 index 00000000..6f4ca917 --- /dev/null +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -0,0 +1,112 @@ +#include +#include + +#include + +#include "cute/layout.hpp" +#include "mla_kernel_sm80.cuh" // IWYU pragma: keep +#include "mla_params.h" +#include "mla_ref.h" + +namespace llm { + +namespace { +torch::Tensor mla_sm80( + torch::Tensor query, // [batch_size, q_len, n_heads, head_dim] + torch::Tensor key, // [batch_size, kv_len, n_kv_heads, head_dim] + torch::Tensor value, // [batch_size, kv_len, n_kv_heads, head_dim] + torch::optional alibi_slopes, //[n_heads] + float logits_soft_cap, + int32_t sliding_window, + int32_t max_q_len) { + const auto batch_size = query.size(0); + const auto q_len = query.size(-3); + const auto kv_len = key.size(-3); + const auto n_heads = query.size(-2); + const auto n_kv_heads = key.size(-2); + const auto head_dim = query.size(-1); + + auto out = torch::empty_like(query); + + const float sm_scale = 1.0 / sqrt(head_dim); + + // construct attention params + MLAParams params; + params.q_ptr = query.const_data_ptr(); + params.q_stride = + make_stride(query.stride(0), query.stride(1), query.stride(2)); + params.k_ptr = key.const_data_ptr(); + params.kv_stride = make_stride(key.stride(0), key.stride(1), key.stride(2)); + + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); + + params.batch_size = batch_size; + params.max_q_len = max_q_len; + params.n_heads = n_heads; + params.q_len = q_len; + params.kv_len = kv_len; + // params.head_dim = head_dim; + + // DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] { + // DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { + // run_mha_kernel_sm80(params); + // }); + // }); + return out; +} + +} // namespace + +class MLAKernelTest + : public ::testing::TestWithParam> { + public: + void SetUp() override { + // Set random seed for test stability + torch::manual_seed(0); + } +}; + +TEST_P(MLAKernelTest, MLA) { + 
const auto [dtype, batch_size, q_len, kv_len, n_heads, head_dim] = GetParam(); + + const auto options = torch::dtype(dtype).device(torch::kCUDA); + + // construct non-contiguous query, key and value + // const auto data = torch::randn( + // {batch_size, q_len, n_heads + 2 * n_kv_heads, head_dim}, options); + // const auto qkv = + // data.split(/*split_size=*/{n_heads, n_kv_heads, n_kv_heads}, /*dim=*/2); + // const auto& query = qkv[0]; + // const auto& key = qkv[1]; + // const auto& value = qkv[2]; + + // auto ref_out = mla_batch_ref( + // query, key, value, alibi_slopes, logits_soft_cap, sliding_window); + // auto out = mla_sm80( + // query, key, value, alibi_slopes, logits_soft_cap, sliding_window, q_len); + + // if (dtype == torch::kBFloat16) { + // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2)); + // } else { + // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + // } +} + +INSTANTIATE_TEST_SUITE_P( + MLA, + MLAKernelTest, + ::testing::Combine(::testing::Values(torch::kHalf), // q_dtype + ::testing::Values(1), // batch_size + ::testing::Values(64), // q_len + ::testing::Values(64), // kv_len + ::testing::Values(8), // n_heads + ::testing::Values(64) // head_dim + )); + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h new file mode 100644 index 00000000..8abf1dd4 --- /dev/null +++ b/src/kernels/attention/mla_params.h @@ -0,0 +1,57 @@ +#pragma once + +#include +#include + +#include "cute/layout.hpp" +namespace llm { + +// common params for attention kernels +struct MLAParamsCommon { + const void* __restrict__ q_ptr = nullptr; + const void* __restrict__ k_ptr = nullptr; + const void* __restrict__ v_ptr = nullptr; + void* __restrict__ o_ptr = nullptr; + + // input shapes + int batch_size = 0; + int n_heads = 0; + // int n_kv_heads = 0; + // int head_dim = 0; + + int qk_nope_head_dim = 0; + int qk_rope_head_dim = 0; + int v_head_dim = 0; + + // used for scheduling + // TODO: remove it after persistent kernel + int max_q_len = 0; + + // private: + // used for performance optimization, don't change it + bool normalized = false; + + // used to initialize the params that used for performance optimization + void normalize() { + if (normalized) { + // already normalized + return; + } + normalized = true; + } +}; + +struct MLAParams : public MLAParamsCommon { + // (batch, seq, head, dim): last dimension is contiguous + using Stride = cute::Stride; + + Stride q_stride; + Stride kv_stride; + Stride o_stride; + + // input shapes + int q_len = 0; + int kv_len = 0; +}; + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h new file mode 100644 index 00000000..3406c4a4 --- /dev/null +++ b/src/kernels/attention/mla_ref.h @@ -0,0 +1,71 @@ +#pragma once + +#include + +namespace llm { +// Multi-head latten attention implementation using pytorch +inline torch::Tensor mla_batch_ref( + torch::Tensor query, // [batch_size, q_len, n_heads, head_dim] + torch::Tensor key, // [batch_size, kv_len, n_kv_heads, head_dim] + torch::Tensor value, // [batch_size, kv_len, n_kv_heads, head_dim] + torch::optional alibi_slopes, //[n_heads] + float logits_soft_cap, + int32_t sliding_window) { + // const auto q_len = query.size(-3); + // const auto kv_len = key.size(-3); + // const auto n_heads = query.size(-2); + // const auto n_kv_heads = key.size(-2); + // const auto head_dim = query.size(-1); + // assert(kv_len >= 
q_len); + + // if (n_heads != n_kv_heads) { + // assert(n_heads % n_kv_heads == 0); + // const auto group_size = n_heads / n_kv_heads; + // key = key.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2); + // value = value.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2); + // } + + // const float sm_scale = 1.0 / sqrt(head_dim); + // // query * key => [n_heads, q_seq_len, seq_len] + // auto scores = torch::einsum("bqhd,bkhd->bhqk", + // {query.to(torch::kFloat), key.to(torch::kFloat)}); + // // apply scale + // scores *= sm_scale; + + // // apply softcap if needed + // if (logits_soft_cap != 0.0) { + // scores = torch::tanh(scores / logits_soft_cap) * logits_soft_cap; + // } + + // // apply alibi bias + // if (alibi_slopes) { + // const auto& slopes = alibi_slopes.value(); + // // calculate alibi attention bias + // // since it's causal mask, we can just use [0, 1, ...,, kv_len) + // auto distance = torch::arange(0, kv_len, query.options()); + // // [n_heads, 1, kv_len] + // auto bias = distance.view({1, 1, kv_len}) * slopes.view({n_heads, 1, 1}); + // scores += bias; + // } + + // auto mask = torch::ones({q_len, kv_len}, torch::kBool); + // if (sliding_window >= 0) { + // // sliding window mask + // // returns the upper triangular part of a matrix + // mask = torch::triu(mask, /*diagonal=*/kv_len - q_len - sliding_window); + // } + + // // apply causal mask + // // causal mask: returns the lower triangular part of a matrix + // mask = torch::tril(mask, /*diagonal=*/kv_len - q_len).to(query); + // scores = scores.masked_fill(mask == 0, -INFINITY); + + // // safe softmax + // scores = torch::softmax(scores, /*dim=*/-1); + + // // score * value => [batch_size, q_len, n_heads, head_dim] + // return torch::einsum("bhqk,bkhd->bqhd", {scores, value.to(torch::kFloat)}) + // .type_as(query); +} + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h new file mode 100644 index 00000000..d3d0ac06 --- /dev/null +++ b/src/kernels/attention/mla_tile.h @@ -0,0 +1,82 @@ +#pragma once +#include +#include + +#include "gather_tensor.hpp" +#include "mla_params.h" + +namespace llm { +using namespace cute; + +template +struct MLATile { + static_assert(cute::dependent_false, "not implemented"); +}; + +// AttentionTile specialization for AttentionParams +template <> +struct MLATile { + // NOLINTNEXTLINE + const MLAParams& params_; + + CUTE_HOST_DEVICE MLATile(const MLAParams& params) : params_(params) {} + + // return the query/output tile: (q_len, head_dim) + template + CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const { + // (batch, seq, head, dim) + + // packed all q/o in the same kv head group together + // q/o [batch, n_tokens, n_heads, dim] + // => q/o [*batch_idx, n_tokens, n_heads, dim] + // => q/o [n_tokens, group_size, n_kv_heads, dim] + // => q/o [n_tokens, group_size, *kv_head_idx, dim] + // => q/o [(group_size, n_tokens), dim] + // => q/o [packed_len, dim] + const auto group_size = params_.n_heads; + const auto head_base = kv_head_idx * group_size; + auto packed_idx_to_coord = [group_size, head_base](int packed_idx) { + const int idx = packed_idx / group_size; + const int offset = packed_idx % group_size; + // (group_size, n_tokens) + return make_coord(head_base + offset, idx); + }; + + const auto packed_len = params_.q_len * group_size; + const auto q_offset = batch_idx * get<0>(params_.q_stride); + auto q = make_gather_tensor( + make_gmem_ptr((const Element*)params_.q_ptr + q_offset), + 
make_shape(packed_len, + params_.qk_nope_head_dim + params_.qk_rope_head_dim), + make_stride( + make_stride(get<2>(params_.q_stride), get<1>(params_.q_stride)), + _1{}), + packed_idx_to_coord); + + const auto o_offset = batch_idx * get<0>(params_.o_stride); + auto o = make_gather_tensor( + make_gmem_ptr((Element*)params_.o_ptr + o_offset), + make_shape(packed_len, params_.qk_nope_head_dim), + make_stride( + make_stride(get<2>(params_.o_stride), get<1>(params_.o_stride)), + _1{}), + packed_idx_to_coord); + return make_tuple(q, o); + } + + // return the key/value tile: (kv_len, head_dim) + template + CUTE_HOST_DEVICE auto get_kv_tile(int batch_idx, int kv_head_idx) const { + // (batch, seq, kv_head, dim) + const auto kv_offset = batch_idx * get<0>(params_.kv_stride) + + kv_head_idx * get<2>(params_.kv_stride); + // k[batch_idx, :, kv_head_idx, :] + auto kv = + make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + kv_offset), + make_shape(params_.kv_len, params_.v_head_dim), + make_stride(get<1>(params_.kv_stride), _1{})); + return kv; + } +}; + +} // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h new file mode 100644 index 00000000..19508fbc --- /dev/null +++ b/src/kernels/attention/mla_traits_sm80.h @@ -0,0 +1,145 @@ +#pragma once +#include +#include + +namespace llm { +using namespace cute; + +namespace detail { + +// Convert fragment layout for different purposes +// Only works for TiledMMA (64x16x16) with SM80_16x8x16_F32F16F16F32_TN +struct LayoutConvertor { + // Convert fragment layout to rowcol layout for iterating + // (MMA=4, MMA_M, MMA_N) => ((2, MMA_M), (2, MMA_N)) + template + CUTE_HOST_DEVICE static constexpr auto to_rowcol(const LayoutC& layout) { + auto l = logical_divide(layout, Shape<_2>{}); + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), + make_layout(get<0, 0>(l), get<2>(l))); + } + + // Convert fragment layout from gemm-I C to gemm-II A + // (MMA_C=4,MMA_M,MMA_N) => (MMA_A=(4, 2), MMA_M, MMA_N/2) + template + CUTE_HOST_DEVICE static constexpr auto to_mma_a(const LayoutC& layout) { + auto l = logical_divide(layout.layout(), Shape{}); + return make_layout( + make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +} // namespace detail + +template +struct MLATraitsSM80 { + // helpful aliases + static constexpr int kHeadDim = HEAD_DIM; + static constexpr int kBlockM = BLK_M; + static constexpr int kBlockN = BLK_N; + static constexpr int kBlockK = BLK_K; + static constexpr int kRowsPerMMA = 2; + + static_assert(kHeadDim % kBlockK == 0); + + using DType = DTYPE; + using _BLK_M = Int; + using _BLK_N = Int; + using _BLK_K = Int; + using _HEAD_DIM = Int; + + // ******* Mainloop ******* + // TiledMMA (64x16x16) for gemm-I and gemm-II + // choose MMA_Atom based on Element type + using MMA_Atom_ = + std::conditional_t, + MMA_Atom, + MMA_Atom>; + using TiledMma = TiledMMA>, // warp layout 4x1x1 + Tile<_64, _16, _16>>; // Prom Shape 64x16x16 + + // Layout convertor for TiledMMA (64x16x16) + using LayoutConvertor = detail::LayoutConvertor; + + // SMEM layout for QKV + // Atom layout: (8, BLK_K):(BLK_K, 1) k-major + using SmemLayoutAtom = + decltype(composition(Swizzle<3, 3, 3>{}, + Layout, Stride<_BLK_K, _1>>{})); + + // Q smem: (BLK_M, HEAD_DIM) + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{})); + + // KV smem: (BLK_N, HEAD_DIM) + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, 
_HEAD_DIM>{})); + + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); + + // V^T smem: (HEAD_DIM, BLK_N) row-major + using SmemLayoutVt = decltype(composition( + SmemLayoutV{}, + make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{}))); + + // Thr layout for gmem copy + using GmemCopyThrLayout = + std::conditional_t, Stride<_4, _1>>, + Layout, Stride<_8, _1>>>; + + // Tiled copy for QKV + // g2s tiled copy for q + using GmemTiledCopyQ = decltype(make_tiled_copy( + Copy_Atom, DType>{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // g2s tiled copy for kv + using GmemTiledCopyKV = GmemTiledCopyQ; + + // s2r tiled copy for gemm-I + using SmemTiledCopyQ = + decltype(make_tiled_copy_A(Copy_Atom{}, + TiledMma{})); + using SmemTiledCopyK = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma{})); + + // s2r tiled copy for gemm-II + using SmemTiledCopyVt = + decltype(make_tiled_copy_B(Copy_Atom{}, + TiledMma{})); + + // ******* Epilogue ******* + + // O smem: (BLK_M, K):(K, 1), k-major, same as Q + using SmemLayoutO = SmemLayoutQ; + + // use 128-bit vectorizing copy + using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + + // s2g tiled copy for O + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom{}, + GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) + Layout>{} // Val layout: 8 vals per read + )); + + // r2s tiled copy for O + using SmemTiledCopyO = + decltype(make_tiled_copy_C(Copy_Atom{}, + TiledMma{})); + + // constexpr values for kernel launch + static constexpr size_t kSmemSize = + (cosize(SmemLayoutQ{}) + cosize(SmemLayoutK{}) + cosize(SmemLayoutV{})) * + sizeof(DType); + + static constexpr size_t kThreadNum = size(TiledMma{}); +}; + +} // namespace llm \ No newline at end of file From f1c4b862684b16f7041657cfc6ba1dae8d17c069 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 4 Feb 2025 16:37:37 -0800 Subject: [PATCH 2/9] added max dynamic smem size for different sm --- src/kernels/attention/mha_dispatch_sm80.cuh | 10 ++++++++++ src/kernels/attention/tools/CMakeLists.txt | 1 + src/kernels/attention/tools/mha_traits_viewer.cpp | 5 +++++ 3 files changed, 16 insertions(+) diff --git a/src/kernels/attention/mha_dispatch_sm80.cuh b/src/kernels/attention/mha_dispatch_sm80.cuh index 7ba20b39..314cbb6b 100644 --- a/src/kernels/attention/mha_dispatch_sm80.cuh +++ b/src/kernels/attention/mha_dispatch_sm80.cuh @@ -46,6 +46,16 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) { params.normalize(); // TODO: tune block shape MNK based on the head dim and smem size + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability + // SM | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0| + // Max SMEM (KB)| 96 | 64 | 164 | 100 | 164 | 100 | 228 | 100 | + // valid dynamic shared memory sizes for different compute capabilities: + // * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96 + // * 7.5 : 0, 32, 64 + // * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164 + // * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100 + // * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228 + // * 12.0 : 0, 8, 16, 32, 64, 100 if constexpr (HEAD_DIM == 64) { using Traits = MHATraitsSM80 #include "../mha_traits_sm80.h" +#include "common/pretty_print.h" #include "print_svg.hpp" using namespace cute; @@ -33,6 +34,10 @@ void print_attn_traits() { using SmemTiledCopyK = typename 
Traits::SmemTiledCopyK; using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + // print dynamic smem size + print("Dynamic Smem Size: "); + print(readable_size(Traits::kSmemSize).c_str()); + print("\n"); // print tiled mma print("TiledMma: \n"); From b46120d7a3f94d71efa1548a35d5af6b926afaa5 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 4 Feb 2025 18:07:46 -0800 Subject: [PATCH 3/9] added mla_ref implementation --- src/kernels/attention/mla_kernel_sm80_test.cu | 149 +++++++++--------- src/kernels/attention/mla_ref.h | 86 +++------- 2 files changed, 103 insertions(+), 132 deletions(-) diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu index 6f4ca917..db6fa332 100644 --- a/src/kernels/attention/mla_kernel_sm80_test.cu +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -2,8 +2,8 @@ #include #include +#include -#include "cute/layout.hpp" #include "mla_kernel_sm80.cuh" // IWYU pragma: keep #include "mla_params.h" #include "mla_ref.h" @@ -12,59 +12,59 @@ namespace llm { namespace { torch::Tensor mla_sm80( - torch::Tensor query, // [batch_size, q_len, n_heads, head_dim] - torch::Tensor key, // [batch_size, kv_len, n_kv_heads, head_dim] - torch::Tensor value, // [batch_size, kv_len, n_kv_heads, head_dim] - torch::optional alibi_slopes, //[n_heads] - float logits_soft_cap, - int32_t sliding_window, - int32_t max_q_len) { - const auto batch_size = query.size(0); - const auto q_len = query.size(-3); - const auto kv_len = key.size(-3); - const auto n_heads = query.size(-2); - const auto n_kv_heads = key.size(-2); - const auto head_dim = query.size(-1); - - auto out = torch::empty_like(query); - - const float sm_scale = 1.0 / sqrt(head_dim); - - // construct attention params - MLAParams params; - params.q_ptr = query.const_data_ptr(); - params.q_stride = - make_stride(query.stride(0), query.stride(1), query.stride(2)); - params.k_ptr = key.const_data_ptr(); - params.kv_stride = make_stride(key.stride(0), key.stride(1), key.stride(2)); - - params.o_ptr = out.mutable_data_ptr(); - params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); - - params.batch_size = batch_size; - params.max_q_len = max_q_len; - params.n_heads = n_heads; - params.q_len = q_len; - params.kv_len = kv_len; - // params.head_dim = head_dim; - - // DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] { - // DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { - // run_mha_kernel_sm80(params); - // }); - // }); + torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] + torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] + torch::Tensor kv, // [batch, kv_len, kv_lora_rank] + torch::Tensor k_rope, // [batch, kv_len, qk_rope_head_dim] + float sm_scale) { + // const auto batch_size = query.size(0); + // const auto q_len = query.size(-3); + // const auto kv_len = key.size(-3); + // const auto n_heads = query.size(-2); + // const auto n_kv_heads = key.size(-2); + // const auto head_dim = query.size(-1); + + auto out = torch::empty_like(q); + + // const float sm_scale = 1.0 / sqrt(head_dim); + + // // construct attention params + // MLAParams params; + // params.q_ptr = query.const_data_ptr(); + // params.q_stride = + // make_stride(query.stride(0), query.stride(1), query.stride(2)); + // params.k_ptr = key.const_data_ptr(); + // params.kv_stride = make_stride(key.stride(0), key.stride(1), + // key.stride(2)); + + // params.o_ptr = out.mutable_data_ptr(); + // params.o_stride = 
make_stride(out.stride(0), out.stride(1), out.stride(2)); + + // params.batch_size = batch_size; + // params.max_q_len = max_q_len; + // params.n_heads = n_heads; + // params.q_len = q_len; + // params.kv_len = kv_len; + // // params.head_dim = head_dim; + + // // DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] { + // // DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { + // // run_mha_kernel_sm80(params); + // // }); + // // }); return out; } } // namespace -class MLAKernelTest - : public ::testing::TestWithParam> { +class MLAKernelTest : public ::testing::TestWithParam< + std::tuple> { public: void SetUp() override { // Set random seed for test stability @@ -73,29 +73,35 @@ class MLAKernelTest }; TEST_P(MLAKernelTest, MLA) { - const auto [dtype, batch_size, q_len, kv_len, n_heads, head_dim] = GetParam(); - + const auto [dtype, + batch_size, + q_len, + kv_len, + n_heads, + kv_lora_rank, + qk_rope_head_dim] = GetParam(); + const auto head_dim = kv_lora_rank + qk_rope_head_dim; const auto options = torch::dtype(dtype).device(torch::kCUDA); - // construct non-contiguous query, key and value - // const auto data = torch::randn( - // {batch_size, q_len, n_heads + 2 * n_kv_heads, head_dim}, options); - // const auto qkv = - // data.split(/*split_size=*/{n_heads, n_kv_heads, n_kv_heads}, /*dim=*/2); - // const auto& query = qkv[0]; - // const auto& key = qkv[1]; - // const auto& value = qkv[2]; - - // auto ref_out = mla_batch_ref( - // query, key, value, alibi_slopes, logits_soft_cap, sliding_window); - // auto out = mla_sm80( - // query, key, value, alibi_slopes, logits_soft_cap, sliding_window, q_len); - - // if (dtype == torch::kBFloat16) { - // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2)); - // } else { - // EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); - // } + const auto q = + torch::randn({batch_size, q_len, n_heads, kv_lora_rank}, options); + const auto q_rope = + torch::randn({batch_size, q_len, n_heads, qk_rope_head_dim}, options); + + const auto kv = torch::randn({batch_size, kv_len, kv_lora_rank}, options); + const auto k_rope = + torch::randn({batch_size, kv_len, qk_rope_head_dim}, options); + + const float sm_scale = 1.0 / sqrt(head_dim); + + auto ref_out = mla_batch_ref(q, q_rope, kv, k_rope, sm_scale); + auto out = mla_sm80(q, q_rope, kv, k_rope, sm_scale); + + if (dtype == torch::kBFloat16) { + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2)); + } else { + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); + } } INSTANTIATE_TEST_SUITE_P( @@ -106,7 +112,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(64), // q_len ::testing::Values(64), // kv_len ::testing::Values(8), // n_heads - ::testing::Values(64) // head_dim + ::testing::Values(64), // kv_lora_rank + ::testing::Values(64) // qk_rope_head_dim )); } // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h index 3406c4a4..0a48e4dd 100644 --- a/src/kernels/attention/mla_ref.h +++ b/src/kernels/attention/mla_ref.h @@ -4,68 +4,32 @@ namespace llm { // Multi-head latten attention implementation using pytorch +// reference implementation: +// https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L477 inline torch::Tensor mla_batch_ref( - torch::Tensor query, // [batch_size, q_len, n_heads, head_dim] - torch::Tensor key, // [batch_size, kv_len, n_kv_heads, head_dim] - torch::Tensor value, // [batch_size, kv_len, n_kv_heads, head_dim] - torch::optional 
alibi_slopes, //[n_heads] - float logits_soft_cap, - int32_t sliding_window) { - // const auto q_len = query.size(-3); - // const auto kv_len = key.size(-3); - // const auto n_heads = query.size(-2); - // const auto n_kv_heads = key.size(-2); - // const auto head_dim = query.size(-1); - // assert(kv_len >= q_len); - - // if (n_heads != n_kv_heads) { - // assert(n_heads % n_kv_heads == 0); - // const auto group_size = n_heads / n_kv_heads; - // key = key.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2); - // value = value.repeat_interleave(/*repeats=*/group_size, /*dim=*/-2); - // } - - // const float sm_scale = 1.0 / sqrt(head_dim); - // // query * key => [n_heads, q_seq_len, seq_len] - // auto scores = torch::einsum("bqhd,bkhd->bhqk", - // {query.to(torch::kFloat), key.to(torch::kFloat)}); - // // apply scale - // scores *= sm_scale; - - // // apply softcap if needed - // if (logits_soft_cap != 0.0) { - // scores = torch::tanh(scores / logits_soft_cap) * logits_soft_cap; - // } - - // // apply alibi bias - // if (alibi_slopes) { - // const auto& slopes = alibi_slopes.value(); - // // calculate alibi attention bias - // // since it's causal mask, we can just use [0, 1, ...,, kv_len) - // auto distance = torch::arange(0, kv_len, query.options()); - // // [n_heads, 1, kv_len] - // auto bias = distance.view({1, 1, kv_len}) * slopes.view({n_heads, 1, 1}); - // scores += bias; - // } - - // auto mask = torch::ones({q_len, kv_len}, torch::kBool); - // if (sliding_window >= 0) { - // // sliding window mask - // // returns the upper triangular part of a matrix - // mask = torch::triu(mask, /*diagonal=*/kv_len - q_len - sliding_window); - // } - - // // apply causal mask - // // causal mask: returns the lower triangular part of a matrix - // mask = torch::tril(mask, /*diagonal=*/kv_len - q_len).to(query); - // scores = scores.masked_fill(mask == 0, -INFINITY); - - // // safe softmax - // scores = torch::softmax(scores, /*dim=*/-1); - - // // score * value => [batch_size, q_len, n_heads, head_dim] - // return torch::einsum("bhqk,bkhd->bqhd", {scores, value.to(torch::kFloat)}) - // .type_as(query); + torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] + torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] + torch::Tensor kv, // [batch, kv_len, kv_lora_rank] + torch::Tensor k_rope, // [batch, kv_len, qk_rope_head_dim] + float sm_scale) { + const auto q_len = q.size(-3); + const auto n_heads = q.size(-2); + const auto kv_len = kv.size(-2); + const auto kv_lora_rank = kv.size(-1); + const auto qk_rope_head_dim = q_rope.size(-1); + assert(kv_len >= q_len); + + // query * key => [batch, q_len, n_heads, kv_len] + auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + + torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); + // apply scale + scores *= sm_scale; + + // safe softmax + scores = scores.softmax(/*dim=*/-1, /*dtype=*/torch::kFloat).type_as(q); + + // score * value => [batch_size, q_len, n_heads, kv_lora_rank] + return torch::einsum("bqhk,bkr->bqhr", {scores, kv}); } } // namespace llm \ No newline at end of file From 613821eb13e5e9948c7177e5d3c8a30163f36fed Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Tue, 4 Feb 2025 18:07:46 -0800 Subject: [PATCH 4/9] added mla_ref implementation --- src/kernels/attention/mla_ref.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h index 0a48e4dd..cb206747 100644 --- a/src/kernels/attention/mla_ref.h +++ b/src/kernels/attention/mla_ref.h @@ -3,7 +3,7 @@ 
#include namespace llm { -// Multi-head latten attention implementation using pytorch +// Multi-head latent attention implementation using pytorch // reference implementation: // https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L477 inline torch::Tensor mla_batch_ref( From b0e54981181e2b1359c29e5797a706223279e88b Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 7 Feb 2025 18:52:07 -0800 Subject: [PATCH 5/9] simple Q*K^T*V to profile huge head_dim perf --- src/kernels/attention/CMakeLists.txt | 9 + src/kernels/attention/mla_kernel_sm80.cuh | 254 ++++-------------- src/kernels/attention/mla_kernel_sm80_test.cu | 120 +++++---- src/kernels/attention/mla_params.h | 25 +- src/kernels/attention/mla_ref.h | 15 +- src/kernels/attention/mla_sm80_bench.cu | 70 +++++ src/kernels/attention/mla_tile.h | 56 ++-- src/kernels/attention/mla_traits_sm80.h | 27 +- 8 files changed, 243 insertions(+), 333 deletions(-) create mode 100644 src/kernels/attention/mla_sm80_bench.cu diff --git a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index b3dc5aec..d4273f73 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -102,4 +102,13 @@ nvbench_binary( :attention.template ) +nvbench_binary( + NAME + mla_sm80_bench + SRCS + mla_sm80_bench.cu + DEPS + :attention.template +) + add_subdirectory(tools) \ No newline at end of file diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh index a5145da9..36c9d398 100644 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -28,11 +28,13 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { constexpr int kBlockM = Traits::kBlockM; constexpr int kBlockN = Traits::kBlockN; constexpr int kHeadDim = Traits::kHeadDim; - constexpr int kRowsPerMMA = Traits::kRowsPerMMA; + constexpr int kRopeHeadDim = Traits::kRopeHeadDim; + // constexpr int kRowsPerMMA = Traits::kRowsPerMMA; using _BLK_M = Int; using _BLK_N = Int; using _HEAD_DIM = Int; + using _ROPE_HEAD_DIM = Int; // type alias using DType = typename Traits::DType; @@ -41,10 +43,10 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { using Layout = typename Traits::LayoutConvertor; using SmemLayoutQ = typename Traits::SmemLayoutQ; - using SmemLayoutK = typename Traits::SmemLayoutK; - using SmemLayoutV = typename Traits::SmemLayoutV; + using SmemLayoutKV = typename Traits::SmemLayoutKV; using SmemLayoutVt = typename Traits::SmemLayoutVt; using SmemLayoutO = typename Traits::SmemLayoutO; + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; using GmemTiledCopyO = typename Traits::GmemTiledCopyO; @@ -56,35 +58,28 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { const int m_block = blockIdx.x; const int batch_idx = blockIdx.y; - const int kv_head_idx = blockIdx.z; const int tidx = threadIdx.x; MLATile tile(params); - // preprocess input parameters - const int head_dim = params.head_dim; - const int group_size = params.group_size; - const float logits_soft_cap = params.logits_soft_cap; - const float sm_scale = params.sm_scale; - const float sm_scale_log2 = params.sm_scale_log2; - // ProblemShape - // (q_packed_len, HEAD_DIM) - auto [Q, O] = tile.template get_qo_tile(batch_idx, kv_head_idx); - // (kv_len, HEAD_DIM) - auto [K, V] = tile.template get_kv_tile(batch_idx, kv_head_idx); + // Q/O: (q_packed_len, HEAD_DIM) + auto 
[Q, O] = tile.template get_qo_tile(batch_idx); + // KV: (kv_len, HEAD_DIM) + auto KV = tile.template get_kv_tile(batch_idx); + + // Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM) + // auto [Q_ROPE, K_ROPE] = tile.template get_qk_rope_tile(batch_idx); const int q_packed_len = size<0>(Q); - const int q_len = q_packed_len / group_size; - const int kv_len = size<0>(K); + // const int q_len = q_packed_len / group_size; + const int kv_len = size<0>(KV); if (m_block * kBlockM >= q_packed_len) { // m out of bound, return return; } - const int sliding_window = LOCAL ? params.sliding_window : kv_len; - // Gmem // (BLK_M, HEAD_DIM) Tensor gQ = @@ -92,24 +87,21 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { Tensor gO = local_tile(O, Shape<_BLK_M, _HEAD_DIM>{}, make_coord(m_block, _0{})); // (BLK_N, HEAD_DIM, n) - Tensor gK = local_tile(K, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); - Tensor gV = local_tile(V, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); + Tensor gKV = local_tile(KV, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); // Smem extern __shared__ char smem[]; DType* q_smem = (DType*)smem; - DType* k_smem = q_smem + cosize(SmemLayoutQ{}); - DType* v_smem = k_smem + cosize(SmemLayoutK{}); + DType* kv_smem = q_smem + cosize(SmemLayoutQ{}); // (BLK_M, HEAD_DIM), k-major Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{}); // (BLK_N, HEAD_DIM), k-major - Tensor sK = make_tensor(make_smem_ptr(k_smem), SmemLayoutK{}); - Tensor sV = make_tensor(make_smem_ptr(v_smem), SmemLayoutV{}); + Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{}); // Tensor for V^t; used in GEMM-II. // (HEAD_DIM, BLK_N), m-major - Tensor sVt = make_tensor(make_smem_ptr(v_smem), SmemLayoutVt{}); + Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{}); // Tiled Copy // g2s tiled copy for qkv @@ -118,55 +110,16 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx); auto gmem_thr_copy_KV = gmem_tiled_copy_KV.get_thread_slice(tidx); - // coordinate tensor for oob handling - // (BLK_M, HEAD_DIM) -> (blk_m, head_dim) - Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{}); - Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); - - auto produce_query = [&]() { + auto produce_q = [&]() { auto tQgQ = gmem_thr_copy_Q.partition_S(gQ); auto tQsQ = gmem_thr_copy_Q.partition_D(sQ); - auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim); - safe_copy( - gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, max_coord); - }; - - // (BLK_N, HEAD_DIM) -> (blk_n, head_dim) - Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{}); - Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV); - - Tensor tKsK = gmem_thr_copy_KV.partition_D(sK); - auto produce_key = [&](int ni) { - auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); - auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); - // skip ZFILL_MN for key since Mask will mask out oob with -inf - safe_copy( - gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord); - }; - - // produce key without oob handling - auto produce_key_no_oob = [&](int ni) { - auto tKgK = gmem_thr_copy_KV.partition_S(gK(_, _, ni)); - auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); - safe_copy( - gmem_tiled_copy_KV, tKgK, tKsK, tKVcKV, max_coord); - }; - - Tensor tVsV = gmem_thr_copy_KV.partition_D(sV); - auto produce_value = [&](int ni) { - auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); - auto max_coord = make_coord(kv_len - ni * kBlockN, 
head_dim); - // skipping ZFILL_MN for v may cause nan issue - safe_copy( - gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord); + cute::copy(gmem_tiled_copy_Q, tQgQ, tQsQ); }; - // produce value without oob handling - auto produce_value_no_oob = [&](int ni) { - auto tVgV = gmem_thr_copy_KV.partition_S(gV(_, _, ni)); - auto max_coord = make_coord(kv_len - ni * kBlockN, head_dim); - safe_copy( - gmem_tiled_copy_KV, tVgV, tVsV, tKVcKV, max_coord); + Tensor tKsKV = gmem_thr_copy_KV.partition_D(sK); + auto produce_kv = [&](int ni) { + auto tKgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni)); + cute::copy(gmem_tiled_copy_KV, tKgKV, tKsKV); }; TiledMma tiled_mma; @@ -264,20 +217,12 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - // (BLK_M, HEAD_DIM) -> (blk_m, head_dim) - auto cO = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{}); - auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) - // (CPY,CPY_M,CPY_K) -> (blk_m, head_dim) - auto tOcO = gmem_thr_copy_O.partition_D(cO); // wait for smem copy done before gmem copy __syncthreads(); - - auto max_coord = make_coord(q_packed_len - m_block * kBlockM, head_dim); - safe_copy( - gmem_tiled_copy_O, tOsO, tOgO, tOcO, max_coord); + cute::copy(gmem_tiled_copy_O, tOsO, tOgO); }; // output accumulator, (MMA,MMA_M,MMA_K) @@ -286,162 +231,57 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout())); clear(tOrAccO); - const int diagonal = (m_block * kBlockM) / group_size + kv_len - q_len; - // process kv in range: [kv_idx_min, kv_idx_max) - const int kv_idx_min = std::max(0, diagonal - sliding_window); - const int kv_idx_max = std::min(kv_len, diagonal + kBlockM); - const int n_block_min = LOCAL ? 
kv_idx_min / kBlockN : 0; - const int n_block_max = cute::ceil_div(kv_idx_max, kBlockN); - - if (n_block_min >= n_block_max) { - // write output to gmem - epilogue(tOrAccO); - return; - } + const int n_block_min = 0; + const int n_block_max = cute::ceil_div(kv_len, kBlockN); // ############### Prologue ############### - int n_block_idx = n_block_max - 1; // produce query: [] => [q] - produce_query(); + produce_q(); cp_async_fence(); - // produce key: [q] => [q, k] - produce_key(n_block_idx); + // produce key: [q] => [q, kv] + produce_kv(0); cp_async_fence(); // ############### Mainloop ############### - // attention score accumulator, (MMA,MMA_M,MMA_N) - auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); - auto tSrAccS_rc_view = - make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); - - auto apply_logits_soft_cap = [&](auto& tSrAccS) { - if constexpr (SOFT_CAP) { - CUTE_UNROLL - for (int i = 0; i < size(tSrAccS); ++i) { - tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap); - } - } - }; - - constexpr int kMMA_M = size<1>(tSrAccS); - using Softmax = OnlineSoftmax; - using Mask = Mask; - - Softmax softmax(sm_scale_log2); - Mask mask(tidx, - m_block, - q_len, - kv_len, - kv_head_idx, - group_size, - sliding_window, - sm_scale, - params.alibi_slopes_ptr); - - // seperate oob mask iterations for better performance - constexpr int n_oob_mask = cute::ceil_div(kBlockM, kBlockN) + 1; - - // oob mask iterations - CUTE_UNROLL - for (int i = 0; i < n_oob_mask; ++i) { + CUTE_NO_UNROLL + for (int ni = n_block_min; ni < n_block_max; ++ni) { + // attention score accumulator, (MMA,MMA_M,MMA_N) + auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + auto tSrAccS_rc_view = + make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); clear(tSrAccS); - // wait key, queue: [q, k] => [] + // wait key, queue: [q, kv] => [] cp_async_wait<0>(); __syncthreads(); - // produce value, [] => [v] - if (i == 0) { - produce_value(n_block_idx); - } else { - produce_value_no_oob(n_block_idx); - } - cp_async_fence(); - // 1> S = Q@K.T compute_qk(tSrAccS); - if constexpr (SOFT_CAP) { - apply_logits_soft_cap(tSrAccS); - } - mask.apply(tSrAccS_rc_view, n_block_idx); - softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view); - - // wait value, [v] => [] - cp_async_wait<0>(); - __syncthreads(); - - // produce next key: [] => [k] - if (n_block_idx > n_block_min) { - produce_key_no_oob(n_block_idx - 1); - } - cp_async_fence(); - // 2> O = softmax(S)*V compute_sv(tSrAccS, tOrAccO); - --n_block_idx; - if (n_block_idx < n_block_min) { - // no more kv blocks to process - break; - } - } - - // non-oob mask iterations - CUTE_NO_UNROLL - for (; n_block_idx >= n_block_min; --n_block_idx) { - clear(tSrAccS); - - // wait key, queue: [q, k] => [] - cp_async_wait<0>(); - __syncthreads(); - - // produce value, [] => [v] - produce_value_no_oob(n_block_idx); - cp_async_fence(); - - // 1> S = Q@K.T - compute_qk(tSrAccS); - - if constexpr (SOFT_CAP) { - apply_logits_soft_cap(tSrAccS); - } - mask.apply(tSrAccS_rc_view, n_block_idx); - softmax.rescale(tSrAccS_rc_view, tOrAccO_rc_view); - - // wait value, [v] => [] - cp_async_wait<0>(); - __syncthreads(); - - // produce next key: [] => [k] - if (n_block_idx > n_block_min) { - produce_key_no_oob(n_block_idx - 1); + // produce next key: [] => [kv] + if (ni != n_block_max - 1) { + produce_kv(ni + 1); } cp_async_fence(); - - // 2> O = softmax(S)*V - compute_sv(tSrAccS, tOrAccO); } // ############### Epilogue ############### - - // normalize output: 
o /= rowsum - softmax.finalize(tOrAccO_rc_view); - // write output to gmem epilogue(tOrAccO); } template + bool EVEN_K = false, + bool ALIBI = false, + bool SOFT_CAP = false, + bool LOCAL = false> void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { const auto batch_size = params.batch_size; - const auto n_kv_heads = params.n_kv_heads; - const auto max_q_packed_len = params.max_q_len * params.group_size; + const auto max_q_packed_len = params.max_q_len * params.n_heads; const auto smem_size = Traits::kSmemSize; auto mla_kernel = @@ -449,9 +289,7 @@ void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { cudaFuncSetAttribute( mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); // TODO: support persistent kernels - dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), - batch_size, - n_kv_heads); + dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); dim3 block = Traits::kThreadNum; mla_kernel<<>>(params); } diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu index db6fa332..51858587 100644 --- a/src/kernels/attention/mla_kernel_sm80_test.cu +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -4,67 +4,70 @@ #include #include +#include "cute/numeric/numeric_types.hpp" #include "mla_kernel_sm80.cuh" // IWYU pragma: keep #include "mla_params.h" #include "mla_ref.h" +#include "mla_traits_sm80.h" namespace llm { namespace { torch::Tensor mla_sm80( - torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] - torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] - torch::Tensor kv, // [batch, kv_len, kv_lora_rank] - torch::Tensor k_rope, // [batch, kv_len, qk_rope_head_dim] + torch::Tensor q, // [batch, q_len, n_heads, head_dim] + torch::Tensor kv, // [batch, kv_len, head_dim] + torch::Tensor q_rope, // [batch, q_len, n_heads, rope_head_dim] + torch::Tensor k_rope, // [batch, kv_len, rope_head_dim] float sm_scale) { - // const auto batch_size = query.size(0); - // const auto q_len = query.size(-3); - // const auto kv_len = key.size(-3); - // const auto n_heads = query.size(-2); - // const auto n_kv_heads = key.size(-2); - // const auto head_dim = query.size(-1); + const auto batch_size = q.size(0); + const auto q_len = q.size(-3); + const auto kv_len = kv.size(-3); + const auto n_heads = q.size(-2); + const auto head_dim = q.size(-1); + const auto rope_head_dim = q_rope.size(-1); auto out = torch::empty_like(q); - // const float sm_scale = 1.0 / sqrt(head_dim); - - // // construct attention params - // MLAParams params; - // params.q_ptr = query.const_data_ptr(); - // params.q_stride = - // make_stride(query.stride(0), query.stride(1), query.stride(2)); - // params.k_ptr = key.const_data_ptr(); - // params.kv_stride = make_stride(key.stride(0), key.stride(1), - // key.stride(2)); - - // params.o_ptr = out.mutable_data_ptr(); - // params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); - - // params.batch_size = batch_size; - // params.max_q_len = max_q_len; - // params.n_heads = n_heads; - // params.q_len = q_len; - // params.kv_len = kv_len; - // // params.head_dim = head_dim; - - // // DISPATCH_TORCH_DTYPE_(query.dtype(), DTYPE, [&] { - // // DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { - // // run_mha_kernel_sm80(params); - // // }); - // // }); + // construct attention params + MLAParams params; + params.q_ptr = q.const_data_ptr(); + params.q_stride = make_stride(q.stride(0), q.stride(1), q.stride(2)); + params.kv_ptr = 
kv.const_data_ptr(); + params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); + + params.batch_size = batch_size; + params.max_q_len = q_len; + params.n_heads = n_heads; + params.q_len = q_len; + params.kv_len = kv_len; + params.head_dim = head_dim; + params.rope_head_dim = rope_head_dim; + // params.sm_scale = sm_scale; + + using Traits = MLATraitsSM80; + + launch_mla_kernel_sm80(params, nullptr); return out; } } // namespace -class MLAKernelTest : public ::testing::TestWithParam< - std::tuple> { +class MLAKernelTest + : public ::testing::TestWithParam> { public: void SetUp() override { // Set random seed for test stability @@ -78,24 +81,27 @@ TEST_P(MLAKernelTest, MLA) { q_len, kv_len, n_heads, - kv_lora_rank, - qk_rope_head_dim] = GetParam(); - const auto head_dim = kv_lora_rank + qk_rope_head_dim; + head_dim, + rope_head_dim] = GetParam(); + // const auto head_dim = kv_lora_rank + rope_head_dim; const auto options = torch::dtype(dtype).device(torch::kCUDA); - const auto q = - torch::randn({batch_size, q_len, n_heads, kv_lora_rank}, options); - const auto q_rope = - torch::randn({batch_size, q_len, n_heads, qk_rope_head_dim}, options); + // q: [batch, len, n_heads, head_dim] + // kv: [batch, len, head_dim] + const auto q = torch::randn({batch_size, q_len, n_heads, head_dim}, options); + const auto kv = torch::randn({batch_size, kv_len, head_dim}, options); - const auto kv = torch::randn({batch_size, kv_len, kv_lora_rank}, options); + // q_rope: [batch, len, n_heads, rope_head_dim] + // kv_rope: [batch, len, rope_head_dim] + const auto q_rope = + torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options); const auto k_rope = - torch::randn({batch_size, kv_len, qk_rope_head_dim}, options); + torch::randn({batch_size, kv_len, rope_head_dim}, options); const float sm_scale = 1.0 / sqrt(head_dim); - auto ref_out = mla_batch_ref(q, q_rope, kv, k_rope, sm_scale); - auto out = mla_sm80(q, q_rope, kv, k_rope, sm_scale); + auto ref_out = mla_batch_ref(q, kv, q_rope, k_rope, sm_scale); + auto out = mla_sm80(q, kv, q_rope, k_rope, sm_scale); if (dtype == torch::kBFloat16) { EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-2, /*atol=*/1e-2)); @@ -112,8 +118,8 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(64), // q_len ::testing::Values(64), // kv_len ::testing::Values(8), // n_heads - ::testing::Values(64), // kv_lora_rank - ::testing::Values(64) // qk_rope_head_dim + ::testing::Values(64), // head_dim + ::testing::Values(64) // rope_head_dim )); } // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h index 8abf1dd4..bcc88d9c 100644 --- a/src/kernels/attention/mla_params.h +++ b/src/kernels/attention/mla_params.h @@ -9,19 +9,20 @@ namespace llm { // common params for attention kernels struct MLAParamsCommon { const void* __restrict__ q_ptr = nullptr; - const void* __restrict__ k_ptr = nullptr; - const void* __restrict__ v_ptr = nullptr; + const void* __restrict__ q_rope_ptr = nullptr; + const void* __restrict__ kv_ptr = nullptr; + const void* __restrict__ k_rope_ptr = nullptr; + void* __restrict__ o_ptr = nullptr; // input shapes int batch_size = 0; + int n_heads = 0; - // int n_kv_heads = 0; - // int head_dim = 0; + int head_dim = 0; + int rope_head_dim = 0; - int qk_nope_head_dim = 0; - int qk_rope_head_dim = 0; - int v_head_dim = 0; + // int v_head_dim = 0; // used for scheduling // 
TODO: remove it after persistent kernel @@ -42,11 +43,17 @@ struct MLAParamsCommon { }; struct MLAParams : public MLAParamsCommon { - // (batch, seq, head, dim): last dimension is contiguous + // Q/O: (batch, seq, head, dim): last dimension is contiguous using Stride = cute::Stride; + // KV: (batch, seq, dim): last dimension is contiguous + using KV_Stride = cute::Stride; Stride q_stride; - Stride kv_stride; + Stride q_rope_stride; + + KV_Stride kv_stride; + KV_Stride k_rope_stride; + Stride o_stride; // input shapes diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h index cb206747..5e60dcce 100644 --- a/src/kernels/attention/mla_ref.h +++ b/src/kernels/attention/mla_ref.h @@ -7,9 +7,9 @@ namespace llm { // reference implementation: // https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L477 inline torch::Tensor mla_batch_ref( - torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] - torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] - torch::Tensor kv, // [batch, kv_len, kv_lora_rank] + torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] + torch::Tensor kv, // [batch, kv_len, kv_lora_rank] + torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] torch::Tensor k_rope, // [batch, kv_len, qk_rope_head_dim] float sm_scale) { const auto q_len = q.size(-3); @@ -20,13 +20,14 @@ inline torch::Tensor mla_batch_ref( assert(kv_len >= q_len); // query * key => [batch, q_len, n_heads, kv_len] - auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + - torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); + // auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + + // torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); + auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}); // apply scale - scores *= sm_scale; + // scores *= sm_scale; // safe softmax - scores = scores.softmax(/*dim=*/-1, /*dtype=*/torch::kFloat).type_as(q); + // scores = scores.softmax(/*dim=*/-1, /*dtype=*/torch::kFloat).type_as(q); // score * value => [batch_size, q_len, n_heads, kv_lora_rank] return torch::einsum("bqhk,bkr->bqhr", {scores, kv}); diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu new file mode 100644 index 00000000..f80546ab --- /dev/null +++ b/src/kernels/attention/mla_sm80_bench.cu @@ -0,0 +1,70 @@ +#include +#include + +#include +#include + +#include "mla_kernel_sm80.cuh" // IWYU pragma: keep +#include "mla_params.h" +#include "mla_traits_sm80.h" + +using namespace llm; + +void mla_bench_sm80(nvbench::state& state) { + // Collect CUPTI metrics + state.collect_cupti_metrics(); + + // Get the parameters + const auto batch_size = state.get_int64("batch_size"); + const auto q_len = state.get_int64("q_len"); + const auto kv_len = state.get_int64("kv_len"); + const auto n_heads = state.get_int64("n_heads"); + const auto head_dim = state.get_int64("head_dim"); + const auto rope_head_dim = state.get_int64("rope_head_dim"); + + const auto options = torch::dtype(torch::kHalf).device(torch::kCUDA); + const auto q = torch::randn({batch_size, q_len, n_heads, head_dim}, options); + const auto kv = torch::randn({batch_size, kv_len, head_dim}, options); + + const auto q_rope = + torch::randn({batch_size, q_len, n_heads, rope_head_dim}, options); + const auto k_rope = + torch::randn({batch_size, kv_len, rope_head_dim}, options); + + auto out = torch::empty_like(q); + + // construct attention params + MLAParams params; + params.q_ptr = q.const_data_ptr(); + params.q_stride = make_stride(q.stride(0), 
q.stride(1), q.stride(2)); + params.kv_ptr = kv.const_data_ptr(); + params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + + params.o_ptr = out.mutable_data_ptr(); + params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); + + params.batch_size = batch_size; + params.max_q_len = q_len; + params.n_heads = n_heads; + params.q_len = q_len; + params.kv_len = kv_len; + params.head_dim = head_dim; + params.rope_head_dim = rope_head_dim; + + using Traits = MLATraitsSM80; + + launch_mla_kernel_sm80(params, nullptr); +} + +NVBENCH_BENCH(mla_bench_sm80) + .add_int64_axis("batch_size", {1}) + .add_int64_axis("q_len", {1024}) + .add_int64_axis("kv_len", {1024}) + .add_int64_axis("n_heads", {8}) + .add_int64_axis("head_dim", {64}) + .add_int64_axis("rope_head_dim", {64}); diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h index d3d0ac06..8d7ffc58 100644 --- a/src/kernels/attention/mla_tile.h +++ b/src/kernels/attention/mla_tile.h @@ -21,59 +21,33 @@ struct MLATile { CUTE_HOST_DEVICE MLATile(const MLAParams& params) : params_(params) {} - // return the query/output tile: (q_len, head_dim) + // return the query/output tile: (q_packed_len, head_dim) template - CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx, int kv_head_idx) const { + CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx) const { // (batch, seq, head, dim) - - // packed all q/o in the same kv head group together - // q/o [batch, n_tokens, n_heads, dim] - // => q/o [*batch_idx, n_tokens, n_heads, dim] - // => q/o [n_tokens, group_size, n_kv_heads, dim] - // => q/o [n_tokens, group_size, *kv_head_idx, dim] - // => q/o [(group_size, n_tokens), dim] - // => q/o [packed_len, dim] - const auto group_size = params_.n_heads; - const auto head_base = kv_head_idx * group_size; - auto packed_idx_to_coord = [group_size, head_base](int packed_idx) { - const int idx = packed_idx / group_size; - const int offset = packed_idx % group_size; - // (group_size, n_tokens) - return make_coord(head_base + offset, idx); - }; - - const auto packed_len = params_.q_len * group_size; + const auto q_packed_len = params_.q_len * params_.n_heads; const auto q_offset = batch_idx * get<0>(params_.q_stride); - auto q = make_gather_tensor( - make_gmem_ptr((const Element*)params_.q_ptr + q_offset), - make_shape(packed_len, - params_.qk_nope_head_dim + params_.qk_rope_head_dim), - make_stride( - make_stride(get<2>(params_.q_stride), get<1>(params_.q_stride)), - _1{}), - packed_idx_to_coord); + auto q = + make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset), + make_shape(q_packed_len, params_.head_dim), + make_stride(get<2>(params_.q_stride), _1{})); const auto o_offset = batch_idx * get<0>(params_.o_stride); - auto o = make_gather_tensor( - make_gmem_ptr((Element*)params_.o_ptr + o_offset), - make_shape(packed_len, params_.qk_nope_head_dim), - make_stride( - make_stride(get<2>(params_.o_stride), get<1>(params_.o_stride)), - _1{}), - packed_idx_to_coord); + auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset), + make_shape(q_packed_len, params_.head_dim), + make_stride(get<2>(params_.o_stride), _1{})); return make_tuple(q, o); } // return the key/value tile: (kv_len, head_dim) template - CUTE_HOST_DEVICE auto get_kv_tile(int batch_idx, int kv_head_idx) const { - // (batch, seq, kv_head, dim) - const auto kv_offset = batch_idx * get<0>(params_.kv_stride) + - kv_head_idx * get<2>(params_.kv_stride); + CUTE_HOST_DEVICE auto get_kv_tile(int batch_idx) const { + // (batch, seq, dim) + const auto 
kv_offset = batch_idx * get<0>(params_.kv_stride); // k[batch_idx, :, kv_head_idx, :] auto kv = - make_tensor(make_gmem_ptr((const Element*)params_.k_ptr + kv_offset), - make_shape(params_.kv_len, params_.v_head_dim), + make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset), + make_shape(params_.kv_len, params_.head_dim), make_stride(get<1>(params_.kv_stride), _1{})); return kv; } diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h index 19508fbc..cc009b0b 100644 --- a/src/kernels/attention/mla_traits_sm80.h +++ b/src/kernels/attention/mla_traits_sm80.h @@ -31,22 +31,30 @@ struct LayoutConvertor { } // namespace detail -template +template struct MLATraitsSM80 { // helpful aliases static constexpr int kHeadDim = HEAD_DIM; + static constexpr int kRopeHeadDim = ROPE_HEAD_DIM; static constexpr int kBlockM = BLK_M; static constexpr int kBlockN = BLK_N; static constexpr int kBlockK = BLK_K; static constexpr int kRowsPerMMA = 2; static_assert(kHeadDim % kBlockK == 0); + static_assert(kRopeHeadDim % kBlockK == 0); using DType = DTYPE; using _BLK_M = Int; using _BLK_N = Int; using _BLK_K = Int; using _HEAD_DIM = Int; + using _ROPE_HEAD_DIM = Int; // ******* Mainloop ******* // TiledMMA (64x16x16) for gemm-I and gemm-II @@ -73,15 +81,12 @@ struct MLATraitsSM80 { decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{})); // KV smem: (BLK_N, HEAD_DIM) - using SmemLayoutK = - decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); - - using SmemLayoutV = + using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); // V^T smem: (HEAD_DIM, BLK_N) row-major using SmemLayoutVt = decltype(composition( - SmemLayoutV{}, + SmemLayoutKV{}, make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{}))); // Thr layout for gmem copy @@ -93,7 +98,7 @@ struct MLATraitsSM80 { // Tiled copy for QKV // g2s tiled copy for q using GmemTiledCopyQ = decltype(make_tiled_copy( - Copy_Atom, DType>{}, + Copy_Atom, DType>{}, GmemCopyThrLayout{}, // Thr layout: (_16,_8)/(_32, _4) Layout>{} // Val layout: 8 vals per read )); @@ -116,8 +121,9 @@ struct MLATraitsSM80 { // ******* Epilogue ******* - // O smem: (BLK_M, K):(K, 1), k-major, same as Q - using SmemLayoutO = SmemLayoutQ; + // O smem: (BLK_M, HEAD_DIM):(K, 1), k-major + using SmemLayoutO = + decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{})); // use 128-bit vectorizing copy using VectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; @@ -136,8 +142,7 @@ struct MLATraitsSM80 { // constexpr values for kernel launch static constexpr size_t kSmemSize = - (cosize(SmemLayoutQ{}) + cosize(SmemLayoutK{}) + cosize(SmemLayoutV{})) * - sizeof(DType); + (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{})) * sizeof(DType); static constexpr size_t kThreadNum = size(TiledMma{}); }; From 13140c433965a5a0b0f930808cf38f9e0d21de18 Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 7 Feb 2025 20:09:57 -0800 Subject: [PATCH 6/9] update --- src/kernels/attention/CMakeLists.txt | 1 + src/kernels/attention/cute_extensions.cuh | 6 +++ src/kernels/attention/mha_traits_sm80.h | 7 ++- src/kernels/attention/mla_kernel_sm80.cuh | 7 ++- src/kernels/attention/mla_sm80_bench.cu | 47 +++++++++++++++---- src/kernels/attention/mla_traits_sm80.h | 8 ++-- src/kernels/attention/mla_traits_test.cpp | 56 +++++++++++++++++++++++ 7 files changed, 114 insertions(+), 18 deletions(-) create mode 100644 src/kernels/attention/mla_traits_test.cpp diff --git 
a/src/kernels/attention/CMakeLists.txt b/src/kernels/attention/CMakeLists.txt index d4273f73..d1cf0149 100644 --- a/src/kernels/attention/CMakeLists.txt +++ b/src/kernels/attention/CMakeLists.txt @@ -75,6 +75,7 @@ cc_test( NAME mla_kernel_test SRCS + mla_traits_test.cpp mla_kernel_sm80_test.cu DEPS :attention.template diff --git a/src/kernels/attention/cute_extensions.cuh b/src/kernels/attention/cute_extensions.cuh index 7f490062..36d7bd7a 100644 --- a/src/kernels/attention/cute_extensions.cuh +++ b/src/kernels/attention/cute_extensions.cuh @@ -20,6 +20,12 @@ constexpr bool .with(declval()))>> = true; } // namespace detail +template +CUTE_HOST_DEVICE constexpr auto permute( + const ComposedLayout, Offset, LayoutB>& c) { + return composition(c.layout_a(), c.offset(), select(c.layout_b())); +} + template CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a, IntTupleB const& b) { diff --git a/src/kernels/attention/mha_traits_sm80.h b/src/kernels/attention/mha_traits_sm80.h index 6299d80e..9132413a 100644 --- a/src/kernels/attention/mha_traits_sm80.h +++ b/src/kernels/attention/mha_traits_sm80.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include "cute_extensions.cuh" namespace llm { using namespace cute; @@ -79,10 +80,8 @@ struct MHATraitsSM80 { using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); - // V^T smem: (HEAD_DIM, BLK_N) row-major - using SmemLayoutVt = decltype(composition( - SmemLayoutV{}, - make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{}))); + // V^T smem: (HEAD_DIM, BLK_N) + using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutV{})); // Thr layout for gmem copy using GmemCopyThrLayout = diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh index 36c9d398..2d0e065f 100644 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -284,10 +285,12 @@ void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { const auto max_q_packed_len = params.max_q_len * params.n_heads; const auto smem_size = Traits::kSmemSize; + print("smem_size: %d \n", smem_size); + auto mla_kernel = mla_kernel_sm80; - cudaFuncSetAttribute( - mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + C10_CUDA_CHECK(cudaFuncSetAttribute( + mla_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // TODO: support persistent kernels dim3 grid(cute::ceil_div(max_q_packed_len, Traits::kBlockM), batch_size, 1); dim3 block = Traits::kThreadNum; diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu index f80546ab..49f474ff 100644 --- a/src/kernels/attention/mla_sm80_bench.cu +++ b/src/kernels/attention/mla_sm80_bench.cu @@ -10,6 +10,33 @@ using namespace llm; +#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) 
\ + [&] { \ + if (HEAD_DIM_V <= 64) { \ + constexpr static int HEAD_DIM_NAME = 64; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 128) { \ + constexpr static int HEAD_DIM_NAME = 128; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 256) { \ + constexpr static int HEAD_DIM_NAME = 256; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 384) { \ + constexpr static int HEAD_DIM_NAME = 384; \ + constexpr static int BLK_N = 64; \ + return __VA_ARGS__(); \ + } else if (HEAD_DIM_V <= 512) { \ + constexpr static int HEAD_DIM_NAME = 512; \ + constexpr static int BLK_N = 32; \ + return __VA_ARGS__(); \ + } else { \ + assert(false); \ + } \ + }() + void mla_bench_sm80(nvbench::state& state) { // Collect CUPTI metrics state.collect_cupti_metrics(); @@ -51,14 +78,18 @@ void mla_bench_sm80(nvbench::state& state) { params.head_dim = head_dim; params.rope_head_dim = rope_head_dim; - using Traits = MLATraitsSM80; + state.exec([&](nvbench::launch& launch) { + DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { + using Traits = MLATraitsSM80; - launch_mla_kernel_sm80(params, nullptr); + launch_mla_kernel_sm80(params, launch.get_stream()); + }); + }); } NVBENCH_BENCH(mla_bench_sm80) @@ -66,5 +97,5 @@ NVBENCH_BENCH(mla_bench_sm80) .add_int64_axis("q_len", {1024}) .add_int64_axis("kv_len", {1024}) .add_int64_axis("n_heads", {8}) - .add_int64_axis("head_dim", {64}) + .add_int64_axis("head_dim", {256}) .add_int64_axis("rope_head_dim", {64}); diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h index cc009b0b..5684991a 100644 --- a/src/kernels/attention/mla_traits_sm80.h +++ b/src/kernels/attention/mla_traits_sm80.h @@ -2,6 +2,8 @@ #include #include +#include "cute_extensions.cuh" + namespace llm { using namespace cute; @@ -84,10 +86,8 @@ struct MLATraitsSM80 { using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); - // V^T smem: (HEAD_DIM, BLK_N) row-major - using SmemLayoutVt = decltype(composition( - SmemLayoutKV{}, - make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{}))); + // V^T smem: (HEAD_DIM, BLK_N) + using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutKV{})); // Thr layout for gmem copy using GmemCopyThrLayout = diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp new file mode 100644 index 00000000..330fe0f0 --- /dev/null +++ b/src/kernels/attention/mla_traits_test.cpp @@ -0,0 +1,56 @@ +#include + +#include + +#include "cute_extensions.cuh" +#include "gather_tensor.hpp" +#include "mla_traits_sm80.h" + +namespace llm { + +using namespace cute; + +template +void test_mla_traits() { + // type alias + using TiledMma = typename Traits::TiledMma; + using Layout = typename Traits::LayoutConvertor; + + using SmemLayoutQ = typename Traits::SmemLayoutQ; + using SmemLayoutKV = typename Traits::SmemLayoutKV; + using SmemLayoutVt = typename Traits::SmemLayoutVt; + using SmemLayoutO = typename Traits::SmemLayoutO; + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; + using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; + using GmemTiledCopyO = typename Traits::GmemTiledCopyO; + + using SmemTiledCopyQ = typename Traits::SmemTiledCopyQ; + using SmemTiledCopyK = typename Traits::SmemTiledCopyK; + using SmemTiledCopyVt = typename Traits::SmemTiledCopyVt; + using SmemTiledCopyO = typename Traits::SmemTiledCopyO; + + // test layout conversation + Tensor sQ = 
make_tensor(counting_iterator(0), SmemLayoutQ{}); + Tensor sKV = make_tensor(counting_iterator(0), SmemLayoutKV{}); + Tensor sVt = make_tensor(sKV.data(), SmemLayoutVt{}); + + // print("sQ:"); print(sQ);print("\n"); + // print("sKV:"); print(sKV);print("\n"); + // print("sVt:"); print(sVt);print("\n"); + + TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_slice(0); + auto tOrVt = thr_mma.partition_fragment_B(sVt); + // TODO: add tests for layout conformance +} + +TEST(MLATraitsTest, TraitsSM80) { + test_mla_traits>(); +} + +} // namespace llm \ No newline at end of file From 619b092ffec23303c681da0e6ed99378f8216d2d Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 7 Feb 2025 21:38:55 -0800 Subject: [PATCH 7/9] added qk_rope part --- src/kernels/attention/mla_kernel_sm80.cuh | 88 ++++++++++++++++--- src/kernels/attention/mla_kernel_sm80_test.cu | 6 ++ src/kernels/attention/mla_ref.h | 13 ++- src/kernels/attention/mla_sm80_bench.cu | 6 ++ src/kernels/attention/mla_tile.h | 20 ++++- src/kernels/attention/mla_traits_sm80.h | 11 ++- src/kernels/attention/mla_traits_test.cpp | 20 +++-- 7 files changed, 139 insertions(+), 25 deletions(-) diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh index 2d0e065f..79cbfc30 100644 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -45,6 +45,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { using SmemLayoutQ = typename Traits::SmemLayoutQ; using SmemLayoutKV = typename Traits::SmemLayoutKV; + using SmemLayoutQRope = typename Traits::SmemLayoutQRope; + using SmemLayoutKRope = typename Traits::SmemLayoutKRope; using SmemLayoutVt = typename Traits::SmemLayoutVt; using SmemLayoutO = typename Traits::SmemLayoutO; @@ -65,12 +67,10 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // ProblemShape // Q/O: (q_packed_len, HEAD_DIM) - auto [Q, O] = tile.template get_qo_tile(batch_idx); // KV: (kv_len, HEAD_DIM) - auto KV = tile.template get_kv_tile(batch_idx); - // Q/K_ROPE: (q_packed_len, ROPE_HEAD_DIM) - // auto [Q_ROPE, K_ROPE] = tile.template get_qk_rope_tile(batch_idx); + auto [Q, Q_ROPE, O] = tile.template get_qo_tile(batch_idx); + auto [KV, K_ROPE] = tile.template get_kv_tile(batch_idx); const int q_packed_len = size<0>(Q); // const int q_len = q_packed_len / group_size; @@ -90,16 +90,30 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // (BLK_N, HEAD_DIM, n) Tensor gKV = local_tile(KV, Shape<_BLK_N, _HEAD_DIM>{}, make_coord(_, _0{})); + // (BLK_M, ROPE_HEAD_DIM) + Tensor gQ_rope = local_tile( + Q_ROPE, Shape<_BLK_M, _ROPE_HEAD_DIM>{}, make_coord(m_block, _0{})); + // (BLK_N, ROPE_HEAD_DIM, n) + Tensor gK_rope = + local_tile(K_ROPE, Shape<_BLK_N, _ROPE_HEAD_DIM>{}, make_coord(_, _0{})); + // Smem extern __shared__ char smem[]; DType* q_smem = (DType*)smem; DType* kv_smem = q_smem + cosize(SmemLayoutQ{}); + DType* q_rope_smem = kv_smem + cosize(SmemLayoutKV{}); + DType* k_rope_smem = q_rope_smem + cosize(SmemLayoutQRope{}); // (BLK_M, HEAD_DIM), k-major Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{}); // (BLK_N, HEAD_DIM), k-major Tensor sK = make_tensor(make_smem_ptr(kv_smem), SmemLayoutKV{}); + // (BLK_M, ROPE_HEAD_DIM), k-major + Tensor sQ_rope = make_tensor(make_smem_ptr(q_rope_smem), SmemLayoutQRope{}); + // (BLK_N, ROPE_HEAD_DIM), k-major + Tensor sK_rope = make_tensor(make_smem_ptr(k_rope_smem), SmemLayoutKRope{}); + // Tensor for V^t; used in GEMM-II. 
// (HEAD_DIM, BLK_N), m-major Tensor sVt = make_tensor(make_smem_ptr(kv_smem), SmemLayoutVt{}); @@ -117,17 +131,32 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { cute::copy(gmem_tiled_copy_Q, tQgQ, tQsQ); }; + auto produce_q_rope = [&]() { + auto tQgQ_rope = gmem_thr_copy_Q.partition_S(gQ_rope); + auto tQsQ_rope = gmem_thr_copy_Q.partition_D(sQ_rope); + cute::copy(gmem_tiled_copy_Q, tQgQ_rope, tQsQ_rope); + }; + Tensor tKsKV = gmem_thr_copy_KV.partition_D(sK); auto produce_kv = [&](int ni) { auto tKgKV = gmem_thr_copy_KV.partition_S(gKV(_, _, ni)); cute::copy(gmem_tiled_copy_KV, tKgKV, tKsKV); }; + Tensor tKsK_rope = gmem_thr_copy_KV.partition_D(sK_rope); + auto produce_k_rope = [&](int ni) { + auto tKgK_rope = gmem_thr_copy_KV.partition_S(gK_rope(_, _, ni)); + cute::copy(gmem_tiled_copy_KV, tKgK_rope, tKsK_rope); + }; + TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(tidx); // GEMM-I: S = Q@K.T - auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + auto tSrQ_rope = thr_mma.partition_fragment_A(sQ_rope); // (MMA,MMA_M,MMA_K) + auto tSrK_rope = thr_mma.partition_fragment_B(sK_rope); // (MMA,MMA_N,MMA_K) // s2r tiled copy for qkv SmemTiledCopyQ smem_tiled_copy_Q; @@ -135,11 +164,17 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { auto tSsQ = smem_thr_copy_Q.partition_S(sQ); auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + auto tSsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); + auto tSrQ_rope_copy_view = smem_thr_copy_Q.retile_D(tSrQ_rope); + SmemTiledCopyK smem_tiled_copy_K; auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); auto tSsK = smem_thr_copy_K.partition_S(sK); auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); + auto tSsK_rope = smem_thr_copy_K.partition_S(sK_rope); + auto tSrK_rope_copy_view = smem_thr_copy_K.retile_D(tSrK_rope); + // S = Q@K.T // tSrAccS: (MMA,MMA_M,MMA_N) auto compute_qk = [&](auto& tSrAccS) { @@ -163,6 +198,31 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { } }; + auto compute_qk_rope = [&](auto& tSrAccS) { + // prefetch qk_rope + cute::copy(smem_tiled_copy_Q, + tSsQ_rope(_, _, _0{}), + tSrQ_rope_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_K, + tSsK_rope(_, _, _0{}), + tSrK_rope_copy_view(_, _, _0{})); + + CUTE_UNROLL + for (int ki = 0; ki < size<2>(tSrQ_rope); ++ki) { + // prefetch next qk_rope + if (ki != size<2>(tSrQ_rope) - 1) { + const auto next_ki = ki + 1; + cute::copy(smem_tiled_copy_Q, + tSsQ_rope(_, _, next_ki), + tSrQ_rope_copy_view(_, _, next_ki)); + cute::copy(smem_tiled_copy_K, + tSsK_rope(_, _, next_ki), + tSrK_rope_copy_view(_, _, next_ki)); + } + cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrAccS); + } + }; + // GEMM-II: O = softmax(S)@V auto tOrVt = thr_mma.partition_fragment_B(sVt); // (MMA,MMA_K,MMA_N) @@ -238,9 +298,13 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // ############### Prologue ############### // produce query: [] => [q] produce_q(); + // produce q_rope: [q] => [q, q_rope] + produce_q_rope(); cp_async_fence(); - // produce key: [q] => [q, kv] + // produce key: [q, q_rope] => [q, q_rope, kv] produce_kv(0); + // produce k_rope: [q, q_rope, kv] => [q, q_rope, kv, k_rope] + produce_k_rope(0); cp_async_fence(); // ############### Mainloop 
############### @@ -252,19 +316,23 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); clear(tSrAccS); - // wait key, queue: [q, kv] => [] + // wait key, queue: [q, q_rope, kv, k_rope] => [] cp_async_wait<0>(); __syncthreads(); // 1> S = Q@K.T compute_qk(tSrAccS); - // 2> O = softmax(S)*V + // 2> S = Q@K.T + Q_rope@K_rope.T + compute_qk_rope(tSrAccS); + + // 3> O = softmax(S)*V compute_sv(tSrAccS, tOrAccO); - // produce next key: [] => [kv] + // produce next key: [] => [kv, k_rope] if (ni != n_block_max - 1) { produce_kv(ni + 1); + produce_k_rope(ni + 1); } cp_async_fence(); } diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu index 51858587..c0f8cc0c 100644 --- a/src/kernels/attention/mla_kernel_sm80_test.cu +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -35,6 +35,12 @@ torch::Tensor mla_sm80( params.kv_ptr = kv.const_data_ptr(); params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + params.q_rope_ptr = q_rope.const_data_ptr(); + params.q_rope_stride = + make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2)); + params.k_rope_ptr = k_rope.const_data_ptr(); + params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1)); + params.o_ptr = out.mutable_data_ptr(); params.o_stride = make_stride(out.stride(0), out.stride(1), out.stride(2)); diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h index 5e60dcce..9c0f4165 100644 --- a/src/kernels/attention/mla_ref.h +++ b/src/kernels/attention/mla_ref.h @@ -7,10 +7,10 @@ namespace llm { // reference implementation: // https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py#L477 inline torch::Tensor mla_batch_ref( - torch::Tensor q, // [batch, q_len, n_heads, kv_lora_rank] - torch::Tensor kv, // [batch, kv_len, kv_lora_rank] - torch::Tensor q_rope, // [batch, q_len, n_heads, qk_rope_head_dim] - torch::Tensor k_rope, // [batch, kv_len, qk_rope_head_dim] + torch::Tensor q, // [batch, q_len, n_heads, head_dim] + torch::Tensor kv, // [batch, kv_len, head_dim] + torch::Tensor q_rope, // [batch, q_len, n_heads, rope_head_dim] + torch::Tensor k_rope, // [batch, kv_len, rope_head_dim] float sm_scale) { const auto q_len = q.size(-3); const auto n_heads = q.size(-2); @@ -20,9 +20,8 @@ inline torch::Tensor mla_batch_ref( assert(kv_len >= q_len); // query * key => [batch, q_len, n_heads, kv_len] - // auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + - // torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); - auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}); + auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + + torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); // apply scale // scores *= sm_scale; diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu index 49f474ff..7385ffff 100644 --- a/src/kernels/attention/mla_sm80_bench.cu +++ b/src/kernels/attention/mla_sm80_bench.cu @@ -67,6 +67,12 @@ void mla_bench_sm80(nvbench::state& state) { params.kv_ptr = kv.const_data_ptr(); params.kv_stride = make_stride(kv.stride(0), kv.stride(1)); + params.q_rope_ptr = q_rope.const_data_ptr(); + params.q_rope_stride = + make_stride(q_rope.stride(0), q_rope.stride(1), q_rope.stride(2)); + params.k_rope_ptr = k_rope.const_data_ptr(); + params.k_rope_stride = make_stride(k_rope.stride(0), k_rope.stride(1)); + params.o_ptr = out.mutable_data_ptr(); params.o_stride = make_stride(out.stride(0), out.stride(1), 
out.stride(2)); diff --git a/src/kernels/attention/mla_tile.h b/src/kernels/attention/mla_tile.h index 8d7ffc58..b6b323c1 100644 --- a/src/kernels/attention/mla_tile.h +++ b/src/kernels/attention/mla_tile.h @@ -22,6 +22,7 @@ struct MLATile { CUTE_HOST_DEVICE MLATile(const MLAParams& params) : params_(params) {} // return the query/output tile: (q_packed_len, head_dim) + // return q_rope tile: (q_packed_len, qk_rope_head_dim) template CUTE_HOST_DEVICE auto get_qo_tile(int batch_idx) const { // (batch, seq, head, dim) @@ -31,12 +32,20 @@ struct MLATile { make_tensor(make_gmem_ptr((const Element*)params_.q_ptr + q_offset), make_shape(q_packed_len, params_.head_dim), make_stride(get<2>(params_.q_stride), _1{})); + + // (batch, seq, head, rope_head_dim) + const auto q_rope_offset = batch_idx * get<0>(params_.q_rope_stride); + auto q_rope = make_tensor( + make_gmem_ptr((const Element*)params_.q_rope_ptr + q_rope_offset), + make_shape(q_packed_len, params_.rope_head_dim), + make_stride(get<2>(params_.q_rope_stride), _1{})); + // (batch, seq, head, dim) const auto o_offset = batch_idx * get<0>(params_.o_stride); auto o = make_tensor(make_gmem_ptr((Element*)params_.o_ptr + o_offset), make_shape(q_packed_len, params_.head_dim), make_stride(get<2>(params_.o_stride), _1{})); - return make_tuple(q, o); + return make_tuple(q, q_rope, o); } // return the key/value tile: (kv_len, head_dim) @@ -49,7 +58,14 @@ struct MLATile { make_tensor(make_gmem_ptr((const Element*)params_.kv_ptr + kv_offset), make_shape(params_.kv_len, params_.head_dim), make_stride(get<1>(params_.kv_stride), _1{})); - return kv; + + // (batch, seq, rope_head_dim) + const auto k_rope_offset = batch_idx * get<0>(params_.k_rope_stride); + auto k_rope = make_tensor( + make_gmem_ptr((const Element*)params_.k_rope_ptr + k_rope_offset), + make_shape(params_.kv_len, params_.rope_head_dim), + make_stride(get<1>(params_.k_rope_stride), _1{})); + return make_tuple(kv, k_rope); } }; diff --git a/src/kernels/attention/mla_traits_sm80.h b/src/kernels/attention/mla_traits_sm80.h index 5684991a..31d2a776 100644 --- a/src/kernels/attention/mla_traits_sm80.h +++ b/src/kernels/attention/mla_traits_sm80.h @@ -82,10 +82,18 @@ struct MLATraitsSM80 { using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_M, _HEAD_DIM>{})); + using SmemLayoutQRope = + decltype(tile_to_shape(SmemLayoutAtom{}, + Shape<_BLK_M, _ROPE_HEAD_DIM>{})); + // KV smem: (BLK_N, HEAD_DIM) using SmemLayoutKV = decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{})); + using SmemLayoutKRope = + decltype(tile_to_shape(SmemLayoutAtom{}, + Shape<_BLK_N, _ROPE_HEAD_DIM>{})); + // V^T smem: (HEAD_DIM, BLK_N) using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutKV{})); @@ -142,7 +150,8 @@ struct MLATraitsSM80 { // constexpr values for kernel launch static constexpr size_t kSmemSize = - (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{})) * sizeof(DType); + sizeof(DType) * (cosize(SmemLayoutQ{}) + cosize(SmemLayoutKV{}) + + cosize(SmemLayoutQRope{}) + cosize(SmemLayoutKRope{})); static constexpr size_t kThreadNum = size(TiledMma{}); }; diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp index 330fe0f0..195a2b51 100644 --- a/src/kernels/attention/mla_traits_test.cpp +++ b/src/kernels/attention/mla_traits_test.cpp @@ -18,8 +18,11 @@ void test_mla_traits() { using SmemLayoutQ = typename Traits::SmemLayoutQ; using SmemLayoutKV = typename Traits::SmemLayoutKV; + using SmemLayoutQRope = typename Traits::SmemLayoutQRope; + using 
SmemLayoutKRope = typename Traits::SmemLayoutKRope; using SmemLayoutVt = typename Traits::SmemLayoutVt; using SmemLayoutO = typename Traits::SmemLayoutO; + using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ; using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV; using GmemTiledCopyO = typename Traits::GmemTiledCopyO; @@ -34,9 +37,16 @@ void test_mla_traits() { Tensor sKV = make_tensor(counting_iterator(0), SmemLayoutKV{}); Tensor sVt = make_tensor(sKV.data(), SmemLayoutVt{}); - // print("sQ:"); print(sQ);print("\n"); - // print("sKV:"); print(sKV);print("\n"); - // print("sVt:"); print(sVt);print("\n"); + Tensor sQ_rope = make_tensor(counting_iterator(0), SmemLayoutQRope{}); + Tensor sKV_rope = make_tensor(counting_iterator(0), SmemLayoutKRope{}); + + print("sQ:"); print(sQ);print("\n"); + print("sKV:"); print(sKV);print("\n"); + + print("sQ_rope:"); print(sQ_rope);print("\n"); + print("sKV_rope:"); print(sKV_rope);print("\n"); + + print("sVt:"); print(sVt);print("\n"); TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(0); @@ -46,10 +56,10 @@ void test_mla_traits() { TEST(MLATraitsTest, TraitsSM80) { test_mla_traits>(); } From 624f55ac601c1debc68620bbde92486c91fc8c4a Mon Sep 17 00:00:00 2001 From: Michael Mi Date: Fri, 7 Feb 2025 22:04:35 -0800 Subject: [PATCH 8/9] rename --- src/kernels/attention/mla_kernel_sm80.cuh | 145 ++++++++---------- src/kernels/attention/mla_kernel_sm80_test.cu | 7 +- 2 files changed, 70 insertions(+), 82 deletions(-) diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh index 79cbfc30..7716b2b4 100644 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -152,8 +152,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { TiledMma tiled_mma; auto thr_mma = tiled_mma.get_slice(tidx); // GEMM-I: S = Q@K.T - auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) - auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + auto tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + auto tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) auto tSrQ_rope = thr_mma.partition_fragment_A(sQ_rope); // (MMA,MMA_M,MMA_K) auto tSrK_rope = thr_mma.partition_fragment_B(sK_rope); // (MMA,MMA_N,MMA_K) @@ -161,51 +161,43 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // s2r tiled copy for qkv SmemTiledCopyQ smem_tiled_copy_Q; auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); - auto tSsQ = smem_thr_copy_Q.partition_S(sQ); - auto tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + auto tCsQ = smem_thr_copy_Q.partition_S(sQ); + auto tCrQ = smem_thr_copy_Q.retile_D(tSrQ); - auto tSsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); - auto tSrQ_rope_copy_view = smem_thr_copy_Q.retile_D(tSrQ_rope); + auto tCsQ_rope = smem_thr_copy_Q.partition_S(sQ_rope); + auto tCrQ_rope = smem_thr_copy_Q.retile_D(tSrQ_rope); SmemTiledCopyK smem_tiled_copy_K; auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); - auto tSsK = smem_thr_copy_K.partition_S(sK); - auto tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); + auto tCsK = smem_thr_copy_K.partition_S(sK); + auto tCrK = smem_thr_copy_K.retile_D(tSrK); - auto tSsK_rope = smem_thr_copy_K.partition_S(sK_rope); - auto tSrK_rope_copy_view = smem_thr_copy_K.retile_D(tSrK_rope); + auto tCsK_rope = smem_thr_copy_K.partition_S(sK_rope); + auto tCrK_rope = smem_thr_copy_K.retile_D(tSrK_rope); // S = Q@K.T - // tSrAccS: (MMA,MMA_M,MMA_N) - auto 
compute_qk = [&](auto& tSrAccS) { + // tSrS: (MMA,MMA_M,MMA_N) + auto compute_qk = [&](auto& tSrS) { // prefetch kv - cute::copy(smem_tiled_copy_Q, tSsQ(_, _, _0{}), tSrQ_copy_view(_, _, _0{})); - cute::copy(smem_tiled_copy_K, tSsK(_, _, _0{}), tSrK_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_Q, tCsQ(_, _, _0{}), tCrQ(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tCsK(_, _, _0{}), tCrK(_, _, _0{})); CUTE_UNROLL for (int ki = 0; ki < size<2>(tSrQ); ++ki) { // prefetch next kv if (ki != size<2>(tSrQ) - 1) { const auto next_ki = ki + 1; - cute::copy(smem_tiled_copy_Q, - tSsQ(_, _, next_ki), - tSrQ_copy_view(_, _, next_ki)); - cute::copy(smem_tiled_copy_K, - tSsK(_, _, next_ki), - tSrK_copy_view(_, _, next_ki)); + cute::copy(smem_tiled_copy_Q, tCsQ(_, _, next_ki), tCrQ(_, _, next_ki)); + cute::copy(smem_tiled_copy_K, tCsK(_, _, next_ki), tCrK(_, _, next_ki)); } - cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrAccS); + cute::gemm(tiled_mma, tSrQ(_, _, ki), tSrK(_, _, ki), tSrS); } }; - auto compute_qk_rope = [&](auto& tSrAccS) { + auto compute_qk_rope = [&](auto& tSrS) { // prefetch qk_rope - cute::copy(smem_tiled_copy_Q, - tSsQ_rope(_, _, _0{}), - tSrQ_rope_copy_view(_, _, _0{})); - cute::copy(smem_tiled_copy_K, - tSsK_rope(_, _, _0{}), - tSrK_rope_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_Q, tCsQ_rope(_, _, _0{}), tCrQ_rope(_, _, _0{})); + cute::copy(smem_tiled_copy_K, tCsK_rope(_, _, _0{}), tCrK_rope(_, _, _0{})); CUTE_UNROLL for (int ki = 0; ki < size<2>(tSrQ_rope); ++ki) { @@ -213,13 +205,13 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { if (ki != size<2>(tSrQ_rope) - 1) { const auto next_ki = ki + 1; cute::copy(smem_tiled_copy_Q, - tSsQ_rope(_, _, next_ki), - tSrQ_rope_copy_view(_, _, next_ki)); + tCsQ_rope(_, _, next_ki), + tCrQ_rope(_, _, next_ki)); cute::copy(smem_tiled_copy_K, - tSsK_rope(_, _, next_ki), - tSrK_rope_copy_view(_, _, next_ki)); + tCsK_rope(_, _, next_ki), + tCrK_rope(_, _, next_ki)); } - cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrAccS); + cute::gemm(tiled_mma, tSrQ_rope(_, _, ki), tSrK_rope(_, _, ki), tSrS); } }; @@ -228,69 +220,69 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { SmemTiledCopyVt smem_tiled_copy_Vt; auto smem_thr_copy_Vt = smem_tiled_copy_Vt.get_thread_slice(tidx); - auto tOsVt = smem_thr_copy_Vt.partition_S(sVt); - auto tOrVt_copy_view = smem_thr_copy_Vt.retile_D(tOrVt); + auto tCsVt = smem_thr_copy_Vt.partition_S(sVt); + auto tCrVt = smem_thr_copy_Vt.retile_D(tOrVt); // O = softmax(S)*V - // tSrAccS: (MMA,MMA_M,MMA_N) + // tSrS: (MMA,MMA_M,MMA_N) // tOrAccO: (MMA,MMA_M,MMA_K) - auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) { + auto compute_sv = [&](const auto& tSrS, auto& tOrO) { // cast scores from Accumulator to Element - auto tSrS = make_tensor_like(tSrAccS); - fast_cast(tSrAccS, tSrS); + auto tSrS_ = make_tensor_like(tSrS); + fast_cast(tSrS, tSrS_); // convert layout from gemm-I C to gemm-II A - auto tOrS = make_tensor(tSrS.data(), Layout::to_mma_a(tSrS.layout())); + auto tOrS = make_tensor(tSrS_.data(), Layout::to_mma_a(tSrS_.layout())); // prefetch V^t - cute::copy( - smem_tiled_copy_Vt, tOsVt(_, _, _0{}), tOrVt_copy_view(_, _, _0{})); + cute::copy(smem_tiled_copy_Vt, tCsVt(_, _, _0{}), tCrVt(_, _, _0{})); CUTE_UNROLL for (int ki = 0; ki < size<2>(tOrS); ++ki) { // prefetch next V^t if (ki != size<2>(tOrS) - 1) { const auto next_ki = ki + 1; - cute::copy(smem_tiled_copy_Vt, - tOsVt(_, _, next_ki), - tOrVt_copy_view(_, _, 
next_ki)); + cute::copy( + smem_tiled_copy_Vt, tCsVt(_, _, next_ki), tCrVt(_, _, next_ki)); } - cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrAccO); + cute::gemm(tiled_mma, tOrS(_, _, ki), tOrVt(_, _, ki), tOrO); } }; - // tOrAccO: (MMA,MMA_M,MMA_K) - auto epilogue = [&](const auto& tOrAccO) { + // tOrO: (MMA,MMA_M,MMA_K) + auto epilogue = [&](const auto& tOrO) { // write output to gmem // 1> cast output from ElementAccumulator to Element - auto tOrO = make_tensor_like(tOrAccO); - fast_cast(tOrAccO, tOrO); + auto tOrO_ = make_tensor_like(tOrO); + fast_cast(tOrO, tOrO_); - // 2. copy output from reg to smem (reuse sQ) auto sO = make_tensor(sQ.data(), SmemLayoutO{}); - - SmemTiledCopyO smem_tiled_copy_O; - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); - auto taccOrO = smem_thr_copy_O.retile_S(tOrO); - auto taccOsO = smem_thr_copy_O.partition_D(sO); - cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + // 2. copy output from reg to smem (reuse sQ) + { + SmemTiledCopyO smem_tiled_copy_O; + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + auto tCrO = smem_thr_copy_O.retile_S(tOrO_); + auto tCsO = smem_thr_copy_O.partition_D(sO); + cute::copy(smem_tiled_copy_O, tCrO, tCsO); + } // 3. copy output from smem to gmem - GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + { + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); - auto tOsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) - auto tOgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) + auto tCsO = gmem_thr_copy_O.partition_S(sO); // (CPY,CPY_M,CPY_K) + auto tCgO = gmem_thr_copy_O.partition_D(gO); // (CPY,CPY_M,CPY_K) - // wait for smem copy done before gmem copy - __syncthreads(); - cute::copy(gmem_tiled_copy_O, tOsO, tOgO); + // wait for smem copy done before gmem copy + __syncthreads(); + cute::copy(gmem_tiled_copy_O, tCsO, tCgO); + } }; // output accumulator, (MMA,MMA_M,MMA_K) - auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); - auto tOrAccO_rc_view = - make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout())); - clear(tOrAccO); + auto tOrO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{}); + auto tOrO_mn = make_tensor(tOrO.data(), Layout::to_rowcol(tOrO.layout())); + clear(tOrO); const int n_block_min = 0; const int n_block_max = cute::ceil_div(kv_len, kBlockN); @@ -311,23 +303,22 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { CUTE_NO_UNROLL for (int ni = n_block_min; ni < n_block_max; ++ni) { // attention score accumulator, (MMA,MMA_M,MMA_N) - auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); - auto tSrAccS_rc_view = - make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout())); - clear(tSrAccS); + auto tSrS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{}); + auto tSrS_mn = make_tensor(tSrS.data(), Layout::to_rowcol(tSrS.layout())); + clear(tSrS); // wait key, queue: [q, q_rope, kv, k_rope] => [] cp_async_wait<0>(); __syncthreads(); // 1> S = Q@K.T - compute_qk(tSrAccS); + compute_qk(tSrS); - // 2> S = Q@K.T + Q_rope@K_rope.T - compute_qk_rope(tSrAccS); + // 2> S += Q_rope@K_rope.T + compute_qk_rope(tSrS); // 3> O = softmax(S)*V - compute_sv(tSrAccS, tOrAccO); + compute_sv(tSrS, tOrO); // produce next key: [] => [kv, k_rope] if (ni != n_block_max - 1) { @@ -339,7 +330,7 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // ############### 
Epilogue ############### // write output to gmem - epilogue(tOrAccO); + epilogue(tOrO); } template Date: Fri, 7 Feb 2025 22:28:12 -0800 Subject: [PATCH 9/9] added softmax support --- src/kernels/attention/mla_kernel_sm80.cuh | 15 +++++++++++++-- src/kernels/attention/mla_kernel_sm80_test.cu | 11 +++++------ src/kernels/attention/mla_params.h | 9 +++++++-- src/kernels/attention/mla_ref.h | 16 +++++++++++----- src/kernels/attention/mla_sm80_bench.cu | 2 ++ src/kernels/attention/mla_traits_test.cpp | 12 ++++++------ 6 files changed, 44 insertions(+), 21 deletions(-) diff --git a/src/kernels/attention/mla_kernel_sm80.cuh b/src/kernels/attention/mla_kernel_sm80.cuh index 7716b2b4..09e9a077 100644 --- a/src/kernels/attention/mla_kernel_sm80.cuh +++ b/src/kernels/attention/mla_kernel_sm80.cuh @@ -30,7 +30,7 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { constexpr int kBlockN = Traits::kBlockN; constexpr int kHeadDim = Traits::kHeadDim; constexpr int kRopeHeadDim = Traits::kRopeHeadDim; - // constexpr int kRowsPerMMA = Traits::kRowsPerMMA; + constexpr int kRowsPerMMA = Traits::kRowsPerMMA; using _BLK_M = Int; using _BLK_N = Int; @@ -63,6 +63,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { const int batch_idx = blockIdx.y; const int tidx = threadIdx.x; + const float sm_scale_log2 = params.sm_scale_log2; + MLATile tile(params); // ProblemShape @@ -300,6 +302,10 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { cp_async_fence(); // ############### Mainloop ############### + constexpr int kMMA_M = size<1>(tOrO); + using Softmax = OnlineSoftmax; + Softmax softmax(sm_scale_log2); + CUTE_NO_UNROLL for (int ni = n_block_min; ni < n_block_max; ++ni) { // attention score accumulator, (MMA,MMA_M,MMA_N) @@ -317,6 +323,8 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { // 2> S += Q_rope@K_rope.T compute_qk_rope(tSrS); + softmax.rescale(tSrS_mn, tOrO_mn); + // 3> O = softmax(S)*V compute_sv(tSrS, tOrO); @@ -329,6 +337,10 @@ __global__ void mla_kernel_sm80(__grid_constant__ const Params params) { } // ############### Epilogue ############### + + // normalize output: o /= rowsum + softmax.finalize(tOrO_mn); + // write output to gmem epilogue(tOrO); } @@ -344,7 +356,6 @@ void launch_mla_kernel_sm80(const Params& params, cudaStream_t stream) { const auto max_q_packed_len = params.max_q_len * params.n_heads; const auto smem_size = Traits::kSmemSize; - print("smem_size: %d \n", smem_size); auto mla_kernel = mla_kernel_sm80; diff --git a/src/kernels/attention/mla_kernel_sm80_test.cu b/src/kernels/attention/mla_kernel_sm80_test.cu index 5096c39b..fcec6024 100644 --- a/src/kernels/attention/mla_kernel_sm80_test.cu +++ b/src/kernels/attention/mla_kernel_sm80_test.cu @@ -3,6 +3,7 @@ #include #include +#include #include "cute/numeric/numeric_types.hpp" #include "mla_kernel_sm80.cuh" // IWYU pragma: keep @@ -51,7 +52,8 @@ torch::Tensor mla_sm80( params.kv_len = kv_len; params.head_dim = head_dim; params.rope_head_dim = rope_head_dim; - // params.sm_scale = sm_scale; + params.sm_scale = sm_scale; + params.normalize(); using Traits = MLATraitsSM80; - launch_mla_kernel_sm80(params, nullptr); return out; } @@ -104,13 +105,11 @@ TEST_P(MLAKernelTest, MLA) { const auto k_rope = torch::randn({batch_size, kv_len, rope_head_dim}, options); - const float sm_scale = 1.0 / sqrt(head_dim); + const float sm_scale = 1.0 / sqrt(head_dim + rope_head_dim); auto ref_out = mla_batch_ref(q, kv, q_rope, k_rope, sm_scale); auto out = 
mla_sm80(q, kv, q_rope, k_rope, sm_scale); - - std::cerr << "max diff: " << (ref_out - out).abs().max() << std::endl; - EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-1, /*atol=*/1e-1)); + EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-3, /*atol=*/1e-3)); } INSTANTIATE_TEST_SUITE_P( diff --git a/src/kernels/attention/mla_params.h b/src/kernels/attention/mla_params.h index bcc88d9c..1e3415eb 100644 --- a/src/kernels/attention/mla_params.h +++ b/src/kernels/attention/mla_params.h @@ -9,8 +9,9 @@ namespace llm { // common params for attention kernels struct MLAParamsCommon { const void* __restrict__ q_ptr = nullptr; - const void* __restrict__ q_rope_ptr = nullptr; const void* __restrict__ kv_ptr = nullptr; + + const void* __restrict__ q_rope_ptr = nullptr; const void* __restrict__ k_rope_ptr = nullptr; void* __restrict__ o_ptr = nullptr; @@ -22,7 +23,8 @@ struct MLAParamsCommon { int head_dim = 0; int rope_head_dim = 0; - // int v_head_dim = 0; + // softmax scaling + float sm_scale = 1.0; // used for scheduling // TODO: remove it after persistent kernel @@ -31,6 +33,7 @@ struct MLAParamsCommon { // private: // used for performance optimization, don't change it bool normalized = false; + float sm_scale_log2 = 0.0; // used to initialize the params that used for performance optimization void normalize() { @@ -38,6 +41,8 @@ struct MLAParamsCommon { // already normalized return; } + sm_scale_log2 = static_cast(sm_scale * M_LOG2E); + normalized = true; } }; diff --git a/src/kernels/attention/mla_ref.h b/src/kernels/attention/mla_ref.h index 9c0f4165..abe84113 100644 --- a/src/kernels/attention/mla_ref.h +++ b/src/kernels/attention/mla_ref.h @@ -19,17 +19,23 @@ inline torch::Tensor mla_batch_ref( const auto qk_rope_head_dim = q_rope.size(-1); assert(kv_len >= q_len); + // use float32 for better precision + auto q_ = q.to(torch::kFloat); + auto kv_ = kv.to(torch::kFloat); + auto q_rope_ = q_rope.to(torch::kFloat); + auto k_rope_ = k_rope.to(torch::kFloat); + // query * key => [batch, q_len, n_heads, kv_len] - auto scores = torch::einsum("bqhr,bkr->bqhk", {q, kv}) + - torch::einsum("bqhp,bkp->bqhk", {q_rope, k_rope}); + auto scores = torch::einsum("bqhr,bkr->bqhk", {q_, kv_}) + + torch::einsum("bqhp,bkp->bqhk", {q_rope_, k_rope_}); // apply scale - // scores *= sm_scale; + scores *= sm_scale; // safe softmax - // scores = scores.softmax(/*dim=*/-1, /*dtype=*/torch::kFloat).type_as(q); + scores = torch::softmax(scores, /*dim=*/-1); // score * value => [batch_size, q_len, n_heads, kv_lora_rank] - return torch::einsum("bqhk,bkr->bqhr", {scores, kv}); + return torch::einsum("bqhk,bkr->bqhr", {scores, kv_}).type_as(q); } } // namespace llm \ No newline at end of file diff --git a/src/kernels/attention/mla_sm80_bench.cu b/src/kernels/attention/mla_sm80_bench.cu index 7385ffff..75be7925 100644 --- a/src/kernels/attention/mla_sm80_bench.cu +++ b/src/kernels/attention/mla_sm80_bench.cu @@ -83,6 +83,8 @@ void mla_bench_sm80(nvbench::state& state) { params.kv_len = kv_len; params.head_dim = head_dim; params.rope_head_dim = rope_head_dim; + params.sm_scale = 1.0; + params.normalize(); state.exec([&](nvbench::launch& launch) { DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] { diff --git a/src/kernels/attention/mla_traits_test.cpp b/src/kernels/attention/mla_traits_test.cpp index 195a2b51..44f75614 100644 --- a/src/kernels/attention/mla_traits_test.cpp +++ b/src/kernels/attention/mla_traits_test.cpp @@ -22,7 +22,7 @@ void test_mla_traits() { using SmemLayoutKRope = typename Traits::SmemLayoutKRope; using 
SmemLayoutVt = typename Traits::SmemLayoutVt;
   using SmemLayoutO = typename Traits::SmemLayoutO;
-
+
   using GmemTiledCopyQ = typename Traits::GmemTiledCopyQ;
   using GmemTiledCopyKV = typename Traits::GmemTiledCopyKV;
   using GmemTiledCopyO = typename Traits::GmemTiledCopyO;
@@ -40,13 +40,13 @@ void test_mla_traits() {
   Tensor sQ_rope = make_tensor(counting_iterator(0), SmemLayoutQRope{});
   Tensor sKV_rope = make_tensor(counting_iterator(0), SmemLayoutKRope{});
 
-  print("sQ:"); print(sQ);print("\n");
-  print("sKV:"); print(sKV);print("\n");
+  // print("sQ:"); print(sQ);print("\n");
+  // print("sKV:"); print(sKV);print("\n");
 
-  print("sQ_rope:"); print(sQ_rope);print("\n");
-  print("sKV_rope:"); print(sKV_rope);print("\n");
+  // print("sQ_rope:"); print(sQ_rope);print("\n");
+  // print("sKV_rope:"); print(sKV_rope);print("\n");
 
-  print("sVt:"); print(sVt);print("\n");
+  // print("sVt:"); print(sVt);print("\n");
 
   TiledMma tiled_mma;
   auto thr_mma = tiled_mma.get_slice(0);
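
Note on the softmax math added in PATCH 9/9: params.normalize() precomputes sm_scale_log2 = sm_scale * M_LOG2E so the mainloop can stay in the base-2 exponent domain, where exp2f can be lowered to the hardware ex2 instruction. The OnlineSoftmax object (defined in online_softmax.cuh, which is not part of this series) keeps a per-row running max and running sum across kv blocks: rescale() folds each new score tile into those statistics and rescales the partial output accumulator, and finalize() applies o /= rowsum. The following is a minimal scalar sketch of that bookkeeping only; the names OnlineSoftmaxRow, scores and acc_o are illustrative stand-ins for the per-thread register fragments (kRowsPerMMA * kMMA_M rows per thread) that the real class operates on.

#include <algorithm>
#include <cmath>
#include <vector>

// Scalar stand-in for one row of the streaming softmax used in the mainloop.
// Illustrative only; not the OnlineSoftmax implementation from online_softmax.cuh.
struct OnlineSoftmaxRow {
  explicit OnlineSoftmaxRow(float scale_log2) : sm_scale_log2(scale_log2) {}

  // Fold one kv block of raw scores into the running (max, sum) statistics,
  // rescale the partial output row, and leave the unnormalized probabilities
  // exp2(score * sm_scale_log2 - row_max) in `scores` for the S@V gemm.
  void rescale(std::vector<float>& scores, std::vector<float>& acc_o) {
    float new_max = row_max;
    for (float s : scores) {
      new_max = std::max(new_max, s * sm_scale_log2);
    }
    const float correction = std::exp2(row_max - new_max);
    row_max = new_max;
    row_sum *= correction;
    for (float& o : acc_o) {
      o *= correction;  // previous partial O was accumulated under the old max
    }
    for (float& s : scores) {
      s = std::exp2(s * sm_scale_log2 - row_max);
      row_sum += s;
    }
  }

  // After the last kv block: normalize output, o /= rowsum.
  void finalize(std::vector<float>& acc_o) {
    const float inv_sum = 1.0f / row_sum;
    for (float& o : acc_o) {
      o *= inv_sum;
    }
  }

  float sm_scale_log2;        // sm_scale * M_LOG2E, see MLAParamsCommon::normalize()
  float row_max = -INFINITY;  // running max, in the base-2 exponent domain
  float row_sum = 0.0f;       // running sum of exp2(score - row_max)
};

Feeding two score blocks through rescale() and then calling finalize() should match an ordinary softmax over the concatenated scores; that is essentially what the mla_batch_ref / mla_sm80 comparison in mla_kernel_sm80_test.cu checks end to end. Separately, the kSmemSize formula in mla_traits_sm80.h is why the launch path wraps cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) in C10_CUDA_CHECK: assuming, for illustration, BLK_M = BLK_N = 64 with HEAD_DIM = 256, ROPE_HEAD_DIM = 64 and half-precision elements, the Q, KV, Q_rope and K_rope tiles together need (64*256 + 64*256 + 64*64 + 64*64) * 2 = 81,920 bytes of shared memory, well above the 48 KB default dynamic shared-memory limit.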