kernel: optimize attention kernel performance (#377)
1> use more static dispatch
2> use cutlass::FastDivmod for slot id calculation (see the sketch after the file summary below)
3> handle oob for k (head_dim)
guocuimi authored Jan 18, 2025
1 parent e030191 commit 65b3c53
Showing 16 changed files with 378 additions and 237 deletions.
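Item 2 of the commit message replaces the per-token division and modulo used to map a kv position to its physical slot in the paged KV cache with cutlass::FastDivmod. The file carrying that change (presumably attention_tile.h or the paged-KV indexing path) is not among the hunks shown below, so what follows is only a minimal sketch of the pattern; the struct and member names (PagedKVSlot, kv_idx_to_slot, block_table) are illustrative, not the repo's actual identifiers.

#include <cutlass/fast_math.h>  // cutlass::FastDivmod

// Illustrative helper: map a logical kv position to a physical slot id in a
// paged KV cache. FastDivmod does the expensive divisor setup once (on the
// host); each device-side lookup is then a multiply, shift, and subtract
// instead of an integer '/' and '%'.
struct PagedKVSlot {
  cutlass::FastDivmod block_size_divmod;  // constructed from block_size
  const int* block_table;                 // logical block idx -> physical block id

  PagedKVSlot(int block_size, const int* table)
      : block_size_divmod(block_size), block_table(table) {}

  CUTLASS_HOST_DEVICE
  int kv_idx_to_slot(int kv_idx) const {
    int block_idx, block_offset;
    // quotient = kv_idx / block_size, remainder = kv_idx % block_size
    block_size_divmod(block_idx, block_offset, kv_idx);
    return block_table[block_idx] * block_size_divmod.divisor + block_offset;
  }
};

The benefit shows up inside the attention loop, where the slot id is recomputed for every kv block: with FastDivmod the hardware integer divide disappears from that hot path.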
7 changes: 4 additions & 3 deletions src/kernels/attention/CMakeLists.txt
@@ -9,11 +9,12 @@ cc_library(
ptx.cuh
fast_cast.cuh
online_softmax.cuh
mask.h
static_dispatch.h
attention_params.h
attention_tile.h
attention_traits_sm80.h
attention_kernel_sm80.cuh
static_dispatch.h
attention_launch_sm80.cuh
DEPS
cutlass
@@ -53,13 +54,13 @@ cc_test(
NAME
attention_kernel_test
SRCS
attention_cpu_test.cpp
# attention_cpu_test.cpp
attention_traits_test.cpp
attention_kernel_sm80_test.cu
attention_kernel_sm80_varlen_test.cu
attention_kernel_sm80_pagedkv_test.cu
DEPS
:attention.template
:attention.kernel
absl::random_random
GTest::gtest_main
torch
20 changes: 16 additions & 4 deletions src/kernels/attention/attention_bench_sm80.cu
@@ -10,6 +10,19 @@

using namespace llm;

#define DISPATCH_HEAD_DIM_(HEAD_DIM_V, HEAD_DIM_NAME, ...) \
[&] { \
if (HEAD_DIM_V <= 64) { \
constexpr static int HEAD_DIM_NAME = 64; \
return __VA_ARGS__(); \
} else if (HEAD_DIM_V <= 128) { \
constexpr static int HEAD_DIM_NAME = 128; \
return __VA_ARGS__(); \
} else { \
assert(false); \
} \
}()

void attention_bench_sm80(nvbench::state& state) {
// Collect CUPTI metrics
state.collect_cupti_metrics();
@@ -60,10 +73,9 @@ void attention_bench_sm80(nvbench::state& state) {
params.sliding_window = -1;

state.exec([&](nvbench::launch& launch) {
DISPATCH_TORCH_DTYPE(query.dtype(), QTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<QTYPE, HEAD_DIM>(params, launch.get_stream());
});
DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<cute::half_t, HEAD_DIM>(params,
launch.get_stream());
});
});
}
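The new kernel template parameters (EVEN_K, ALIBI, SOFT_CAP, LOCAL in attention_kernel_sm80.cuh below) mean the runtime feature flags have to be lowered to compile-time constants on the host, in the same spirit as the DISPATCH_HEAD_DIM_ macro above. The actual dispatch lives in attention_launch_sm80.cuh, which is not part of this excerpt; the sketch below shows the usual pattern under that assumption, with DISPATCH_BOOL_, run_mha_instance, and dispatch_mha being illustrative names only.

#include <cstdio>

// Hypothetical boolean dispatch macro in the same immediately-invoked-lambda
// style as DISPATCH_HEAD_DIM_ (illustrative, not the repo's actual macro).
#define DISPATCH_BOOL_(COND, CONST_NAME, ...)       \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()

// Stand-in for the real launcher body; in the repo this would instantiate
// mha_kernel_sm80<Traits, Params, EVEN_K, ALIBI, SOFT_CAP, LOCAL>.
template <bool EVEN_K, bool ALIBI, bool SOFT_CAP, bool LOCAL>
void run_mha_instance() {
  std::printf("EVEN_K=%d ALIBI=%d SOFT_CAP=%d LOCAL=%d\n",
              int(EVEN_K), int(ALIBI), int(SOFT_CAP), int(LOCAL));
}

// Runtime flags -> compile-time template arguments.
void dispatch_mha(int head_dim, int padded_head_dim, const float* alibi_slopes,
                  float logits_soft_cap, int sliding_window) {
  DISPATCH_BOOL_(head_dim == padded_head_dim, EVEN_K, [&] {
    DISPATCH_BOOL_(alibi_slopes != nullptr, ALIBI, [&] {
      DISPATCH_BOOL_(logits_soft_cap != 0.f, SOFT_CAP, [&] {
        DISPATCH_BOOL_(sliding_window >= 0, LOCAL, [&] {
          run_mha_instance<EVEN_K, ALIBI, SOFT_CAP, LOCAL>();
        });
      });
    });
  });
}

The payoff is visible in the kernel below: branches such as apply_logits_soft_cap become if constexpr (SOFT_CAP) and compile away entirely in the instantiations that do not need them, at the cost of more kernel specializations.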
10 changes: 8 additions & 2 deletions src/kernels/attention/attention_cpu_test.cpp
@@ -51,7 +51,13 @@ class AttentionCPUTest
int64_t /*q_seq_len*/,
int64_t /*n_heads*/,
int64_t /*n_kv_heads*/,
int64_t /*head_dim*/>> {};
int64_t /*head_dim*/>> {
public:
void SetUp() override {
// Set random seed for test stability
torch::manual_seed(0);
}
};

TEST_P(AttentionCPUTest, MHA) {
const auto [seq_len, q_seq_len, n_heads, n_kv_heads, head_dim] = GetParam();
@@ -66,7 +72,7 @@ TEST_P(AttentionCPUTest, MHA) {

auto out = torch::empty_like(query);
mha(query, key, value, out);
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-5, /*atol=*/1e-5));
EXPECT_TRUE(torch::allclose(out, ref_out, /*rtol=*/1e-4, /*atol=*/1e-4));
}

INSTANTIATE_TEST_SUITE_P(
167 changes: 94 additions & 73 deletions src/kernels/attention/attention_kernel_sm80.cuh
@@ -7,33 +7,40 @@
#include <cute/tensor.hpp>

#include "attention_tile.h"
#include "cute/config.hpp"
#include "cute_extensions.cuh"
#include "fast_cast.cuh"
#include "mask.h"
#include "online_softmax.cuh"
#include "ptx.cuh"

namespace llm {

template <typename Traits, typename Params>
__global__ void mha_kernel_sm80(Params params) {
template <typename Traits,
typename Params,
bool EVEN_K,
bool ALIBI,
bool SOFT_CAP,
bool LOCAL>
__global__ void mha_kernel_sm80(__grid_constant__ const Params params) {
using namespace cute;

constexpr int kBlockM = Traits::kBlockM;
constexpr int kBlockN = Traits::kBlockN;
constexpr int kBlockK = Traits::kBlockK;
constexpr int kHeadDim = Traits::kHeadDim;
constexpr int kRowsPerMMA = Traits::kRowsPerMMA;

using _BLK_M = Int<kBlockM>;
using _BLK_N = Int<kBlockN>;
using _BLK_K = Int<kBlockK>;
using _HEAD_DIM = Int<kHeadDim>;

// type alias
using Element = typename Traits::Element;
using DType = typename Traits::DType;

using TiledMma = typename Traits::TiledMma;
using Layout = typename Traits::LayoutConvertor;
using Softmax = typename Traits::Softmax;
using Mask = typename Traits::Mask;

using SmemLayoutQ = typename Traits::SmemLayoutQ;
using SmemLayoutK = typename Traits::SmemLayoutK;
@@ -55,46 +62,13 @@ __global__ void mha_kernel_sm80(Params params) {

AttentionTile<Params> tile(params);

float logits_soft_cap = params.logits_soft_cap;
float sm_scale = params.sm_scale;
float sliding_window = params.sliding_window;

// preprocess input parameters
// TODO: Move following logic to the host side?
if (logits_soft_cap != 0.0) {
// Softmax(x * sm_scale) + apply_logits_soft_cap
// => Softmax(Tanh(x * sm_scale / soft_cap) * soft_cap)
// => Softmax(S' * sm_scale') where
// S' = Tanh(x * sm_scale / soft_cap)
// = Tanh(x * soft_cap')
// soft_cap' = sm_scale / soft_cap
// sm_scale' = soft_cap
const auto sm_scale_hat = logits_soft_cap;
logits_soft_cap = sm_scale * ptx::rcp(logits_soft_cap);
sm_scale = sm_scale_hat;
}
auto apply_logits_soft_cap = [&](auto& tSrAccS) {
CUTE_UNROLL
for (int i = 0; i < size(tSrAccS); ++i) {
tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
}
};

const float alibi_slope = params.alibi_slopes_ptr != nullptr
? (params.alibi_slopes_ptr[head_idx] / sm_scale)
: 0.0f;

const int group_size = params.n_heads / params.n_kv_heads;

// use exp2f instead of expf for better performance
sm_scale *= M_LOG2E;

// ProblemShape
// (q_len, HEAD_DIM)
auto [Q, O] = tile.template get_qo_tile<Element>(batch_idx, head_idx);
auto [Q, O] = tile.template get_qo_tile<DType>(batch_idx, head_idx);
// (kv_len, HEAD_DIM)
auto [K, V] =
tile.template get_kv_tile<Element>(batch_idx, head_idx / group_size);
tile.template get_kv_tile<DType>(batch_idx, head_idx / group_size);

const int q_len = size<0>(Q.shape());
const int kv_len = size<0>(K.shape());
@@ -104,10 +78,23 @@ __global__ void mha_kernel_sm80(Params params) {
return;
}

// adjust sliding window size
if (sliding_window < 0) {
sliding_window = kv_len;
}
const int head_dim = params.head_dim;
const float logits_soft_cap = params.logits_soft_cap;
const float sm_scale = params.sm_scale;
const float sm_scale_log2 = params.sm_scale_log2;
const float sliding_window = LOCAL ? params.sliding_window : kv_len;
const float alibi_slope =
ALIBI ? (params.alibi_slopes_ptr[head_idx] / sm_scale) : 0.0f;

// preprocess input parameters
auto apply_logits_soft_cap = [&](auto& tSrAccS) {
if constexpr (SOFT_CAP) {
CUTE_UNROLL
for (int i = 0; i < size(tSrAccS); ++i) {
tSrAccS(i) = ptx::tanh(tSrAccS(i) * logits_soft_cap);
}
}
};

// Gmem
// (BLK_M, HEAD_DIM)
@@ -121,9 +108,9 @@ __global__ void mha_kernel_sm80(Params params) {

// Smem
extern __shared__ char smem[];
Element* q_smem = (Element*)smem;
Element* k_smem = q_smem + cosize(SmemLayoutQ{});
Element* v_smem = k_smem + cosize(SmemLayoutK{});
DType* q_smem = (DType*)smem;
DType* k_smem = q_smem + cosize(SmemLayoutQ{});
DType* v_smem = k_smem + cosize(SmemLayoutK{});

// (BLK_M, BLK_K), k-major
Tensor sQ = make_tensor(make_smem_ptr(q_smem), SmemLayoutQ{});
@@ -140,33 +127,57 @@ __global__ void mha_kernel_sm80(Params params) {
GmemTiledCopyQKV gmem_tiled_copy_QKV;
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);

auto produce_q = [&]() {
// (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
auto cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
auto tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
// coordinate tensor for oob handling
// (BLK_M, HEAD_DIM) -> (blk_m, head_dim)
Tensor cQ = make_identity_tensor(Shape<_BLK_M, _HEAD_DIM>{});
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ);
// (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
Tensor tKcKV = gmem_thr_copy_QKV.partition_S(cKV);

auto produce_q = [&]() {
auto tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
auto tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
safe_copy</*ZERO_FILL=*/true>(
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, q_len - m_block * kBlockM);
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_QKV,
tQgQ,
tQsQ,
tQcQ,
make_coord(q_len - m_block * kBlockM, head_dim));
};

// (BLK_N, HEAD_DIM) -> (blk_n, head_dim)
Tensor cKV = make_identity_tensor(Shape<_BLK_N, _HEAD_DIM>{});
Tensor tKcKV = gmem_thr_copy_QKV.partition_S(cKV);

// TODO: separate mask iterations
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
auto produce_k = [&](int ni) {
auto tKgK = gmem_thr_copy_QKV.partition_S(gK(_, _, ni));
safe_copy</*ZERO_FILL=*/true>(
gmem_tiled_copy_QKV, tKgK, tKsK, tKcKV, kv_len - ni * kBlockN);
// skip zero fill oob for k since mask will mask out oob with -inf
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_QKV,
tKgK,
tKsK,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
auto produce_v = [&](int ni) {
auto tVgV = gmem_thr_copy_QKV.partition_S(gV(_, _, ni));
safe_copy</*ZERO_FILL=*/true>(
gmem_tiled_copy_QKV, tVgV, tVsV, tKcKV, kv_len - ni * kBlockN);
// TODO: skip zero fill oob for v, may have nan issue
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/true,
/*ZERO_FILL_K=*/true>(
gmem_tiled_copy_QKV,
tVgV,
tVsV,
tKcKV,
make_coord(kv_len - ni * kBlockN, head_dim));
};

TiledMma tiled_mma;
@@ -222,7 +233,7 @@ __global__ void mha_kernel_sm80(Params params) {
// tOrAccO: (MMA,MMA_M,MMA_K)
auto compute_sv = [&](const auto& tSrAccS, auto& tOrAccO) {
// cast scores from Accumulator to Element
auto tSrS = make_tensor_like<Element>(tSrAccS);
auto tSrS = make_tensor_like<DType>(tSrAccS);
fast_cast(tSrAccS, tSrS);

// convert layout from gemm-I C to gemm-II A
@@ -248,7 +259,7 @@ __global__ void mha_kernel_sm80(Params params) {
auto epilogue = [&](const auto& tOrAccO) {
// write output to gmem
// 1> cast output from ElementAccumulator to Element
auto tOrO = make_tensor_like<Element>(tOrAccO);
auto tOrO = make_tensor_like<DType>(tOrAccO);
fast_cast(tOrAccO, tOrO);

// 2. copy output from reg to smem (reuse sQ)
@@ -274,8 +285,15 @@ __global__ void mha_kernel_sm80(Params params) {

// wait for smem copy done before gmem copy
__syncthreads();
safe_copy</*ZERO_FILL=*/false>(
gmem_tiled_copy_O, tOsO, tOgO, tOcO, q_len - m_block * kBlockM);
safe_copy<EVEN_K,
/*EVEN_MN=*/false,
/*ZERO_FILL_MN=*/false,
/*ZERO_FILL_K=*/false>(
gmem_tiled_copy_O,
tOsO,
tOgO,
tOcO,
make_coord(q_len - m_block * kBlockM, head_dim));
};

// ############### Prologue ###############
@@ -293,20 +311,23 @@ __global__ void mha_kernel_sm80(Params params) {
auto tOrAccO = partition_fragment_C(tiled_mma, Shape<_BLK_M, _HEAD_DIM>{});
auto tOrAccO_rc_view =
make_tensor(tOrAccO.data(), Layout::to_rowcol(tOrAccO.layout()));
clear(tOrAccO);

Softmax softmax(sm_scale);
Mask mask(q_len, kv_len, sliding_window, alibi_slope);
// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrAccS_rc_view =
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));

OnlineSoftmax<kRowsPerMMA * size<1>(tOrAccO)> softmax(sm_scale_log2);
Mask<kBlockM, kBlockM, ALIBI, LOCAL> mask(
q_len, kv_len, sliding_window, alibi_slope);

// TODO: control block min/max precisely
const int n_block_min = 0;
const int n_block_max = cute::ceil_div(kv_len, kBlockN);

clear(tOrAccO);
CUTE_NO_UNROLL
for (int ni = n_block_min; ni < n_block_max; ++ni) {
// attention score accumulator, (MMA,MMA_M,MMA_N)
auto tSrAccS = partition_fragment_C(tiled_mma, Shape<_BLK_M, _BLK_N>{});
auto tSrAccS_rc_view =
make_tensor(tSrAccS.data(), Layout::to_rowcol(tSrAccS.layout()));
clear(tSrAccS);

// wait k, queue: [q, k] => []
Expand All @@ -321,7 +342,7 @@ __global__ void mha_kernel_sm80(Params params) {
compute_qk(tSrAccS);

// apply soft cap if needed
if (logits_soft_cap != 0.0) {
if constexpr (SOFT_CAP) {
apply_logits_soft_cap(tSrAccS);
}

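The removed device-side preprocessing above (rescaling sm_scale, folding the soft cap through tanh, multiplying by M_LOG2E, with the TODO about moving it to the host) is replaced by plain reads of params.sm_scale, params.sm_scale_log2, and params.logits_soft_cap. The host-side counterpart is not shown in this excerpt; the sketch below reconstructs the presumed precomputation from the algebra in the removed comment (Softmax(x * sm_scale) with a soft cap == Softmax(tanh(x * sm_scale / cap) * cap)). The struct and function names are assumptions for illustration.

// Illustrative host-side setup of the scale-related params fields.
struct SoftmaxScaleParams {
  float sm_scale;         // scale applied to scores (after tanh when capped)
  float sm_scale_log2;    // sm_scale * log2(e), so the kernel can use exp2f
  float logits_soft_cap;  // multiplier applied before tanh (0 disables capping)
};

inline SoftmaxScaleParams make_softmax_scale(float sm_scale,
                                             float logits_soft_cap) {
  SoftmaxScaleParams p;
  if (logits_soft_cap != 0.f) {
    p.logits_soft_cap = sm_scale / logits_soft_cap;  // soft_cap' = sm_scale / cap
    p.sm_scale = logits_soft_cap;                    // sm_scale' = cap
  } else {
    p.logits_soft_cap = 0.f;
    p.sm_scale = sm_scale;
  }
  p.sm_scale_log2 = p.sm_scale * 1.4426950408889634f;  // * M_LOG2E
  return p;
}

Doing this once per launch keeps the per-thread work out of the kernel, and with SOFT_CAP as a template flag the uncapped instantiation never touches logits_soft_cap at all.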
6 changes: 3 additions & 3 deletions src/kernels/attention/attention_kernel_sm80_pagedkv_test.cu
@@ -61,9 +61,9 @@ torch::Tensor attention_pagedkv_sm80(
params.block_cu_lens = block_cu_lens.const_data_ptr<int32_t>();
params.block_size = block_size;

DISPATCH_TORCH_DTYPE(query.dtype(), QTYPE, [&] {
DISPATCH_TORCH_DTYPE(query.dtype(), DTYPE, [&] {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, [&] {
run_attention_kernel_sm80<QTYPE, HEAD_DIM>(params);
run_attention_kernel_sm80<DTYPE, HEAD_DIM>(params);
});
});
return out;
@@ -224,7 +224,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(127, 1000), // max_kv_len
::testing::Values(6), // n_heads
::testing::Values(6 /*mha*/, 3 /*gqa*/, 1 /*mqa*/), // n_kv_heads
::testing::Values(64), // head_dim
::testing::Values(32, 64, 96, 128, 256), // head_dim
::testing::Values(0.0, 50.0), // logits_soft_cap
::testing::Values(false, true), // alibi slope
::testing::Values(-1, 0, 10) // sliding window
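Item 3 of the commit message (handle oob for k, i.e. head_dim) is what the widened head_dim coverage above (32 through 256) exercises: a head_dim such as 96 does not fill the 128-wide tile that DISPATCH_HEAD_DIM selects, so the copies must predicate on the head_dim coordinate as well as on the row. The real implementation is the safe_copy<EVEN_K, EVEN_MN, ZERO_FILL_MN, ZERO_FILL_K> helper in cute_extensions.cuh (not shown here); the plain-CUDA sketch below only illustrates the predication idea, with safe_copy_tile and its parameters being illustrative names.

// Conceptual sketch of a bounds-checked tile copy (not the CuTe version).
// EVEN_K=true lets the compiler drop the head_dim check when the tile is
// known to be fully populated; the ZERO_FILL_* flags control whether
// out-of-bounds smem entries are zeroed or left untouched.
template <bool EVEN_K, bool ZERO_FILL_MN, bool ZERO_FILL_K,
          int BLK_N, int HEAD_DIM, typename T>
__device__ void safe_copy_tile(const T* gmem, int gmem_stride,
                               T (&smem)[BLK_N][HEAD_DIM],
                               int residue_n,    // e.g. kv_len - ni * BLK_N
                               int residue_k) {  // e.g. head_dim
  for (int n = threadIdx.x; n < BLK_N; n += blockDim.x) {
    const bool row_ok = n < residue_n;
    for (int k = 0; k < HEAD_DIM; ++k) {
      const bool col_ok = EVEN_K || (k < residue_k);
      if (row_ok && col_ok) {
        smem[n][k] = gmem[n * gmem_stride + k];
      } else if ((!row_ok && ZERO_FILL_MN) || (!col_ok && ZERO_FILL_K)) {
        smem[n][k] = T(0);
      }
    }
  }
}

This matches the flag choices in the kernel diff: produce_k can skip zero-filling out-of-range rows (ZERO_FILL_MN=false) because the mask writes -inf over those scores, but out-of-range head_dim columns in the Q/K/V loads are always zero-filled (ZERO_FILL_K=true) since they feed straight into the Q*K dot product.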
(remaining 11 changed files not shown in this excerpt)