Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

kernel: added simple MLA kernel #396

Merged
merged 10 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/kernels/attention/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ cc_library(
mha_traits_sm80.h
mha_kernel_sm80.cuh
mha_dispatch_sm80.cuh
mla_params.h
mla_tile.h
mla_traits_sm80.h
mla_kernel_sm80.cuh
DEPS
cutlass
)
Expand Down Expand Up @@ -67,6 +71,19 @@ cc_test(
torch
)

cc_test(
NAME
mla_kernel_test
SRCS
mla_traits_test.cpp
mla_kernel_sm80_test.cu
DEPS
:attention.template
absl::random_random
GTest::gtest_main
torch
)

nvbench_binary(
NAME
mha_sm80_bench
Expand All @@ -86,4 +103,13 @@ nvbench_binary(
:attention.template
)

nvbench_binary(
NAME
mla_sm80_bench
SRCS
mla_sm80_bench.cu
DEPS
:attention.template
)

add_subdirectory(tools)
6 changes: 6 additions & 0 deletions src/kernels/attention/cute_extensions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ constexpr bool
.with(declval<bool>()))>> = true;
} // namespace detail

template <int... Is, int B, int M, int S, class Offset, class LayoutB>
CUTE_HOST_DEVICE constexpr auto permute(
const ComposedLayout<Swizzle<B, M, S>, Offset, LayoutB>& c) {
return composition(c.layout_a(), c.offset(), select<Is...>(c.layout_b()));
}

template <size_t I, class IntTupleA, class IntTupleB>
CUTE_HOST_DEVICE constexpr auto elem_less(IntTupleA const& a,
IntTupleB const& b) {
Expand Down
10 changes: 10 additions & 0 deletions src/kernels/attention/mha_dispatch_sm80.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ void run_mha_kernel_sm80(Params& params, cudaStream_t stream = nullptr) {
params.normalize();

// TODO: tune block shape MNK based on the head dim and smem size
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
// SM | 7.0 | 7.2 | 7.5 | 8.0 | 8.6 | 8.7 | 8.9 | 9.0 | 10.x | 12.0|
// Max SMEM (KB)| 96 | 64 | 164 | 100 | 164 | 100 | 228 | 100 |
// valid dynamic shared memory sizes for different compute capabilities:
// * 7.0 | 7.2 : 0, 8, 16, 32, 64, 96
// * 7.5 : 0, 32, 64
// * 8.0 | 8.7 : 0, 8, 16, 32, 64, 100, 132, 164
// * 8.6 | 8.9 : 0, 8, 16, 32, 64, 100
// * 9.0 | 10.x: 0, 8, 16, 32, 64, 100, 132, 164, 196, 228
// * 12.0 : 0, 8, 16, 32, 64, 100
if constexpr (HEAD_DIM == 64) {
using Traits = MHATraitsSM80<Dtype,
HEAD_DIM,
Expand Down
7 changes: 3 additions & 4 deletions src/kernels/attention/mha_traits_sm80.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include <cute/config.hpp>
#include <cute/tensor.hpp>
#include "cute_extensions.cuh"

namespace llm {
using namespace cute;
Expand Down Expand Up @@ -79,10 +80,8 @@ struct MHATraitsSM80 {
using SmemLayoutV =
decltype(tile_to_shape(SmemLayoutAtom{}, Shape<_BLK_N, _HEAD_DIM>{}));

// V^T smem: (HEAD_DIM, BLK_N) row-major
using SmemLayoutVt = decltype(composition(
SmemLayoutV{},
make_layout(Shape<_HEAD_DIM, _BLK_N>{}, GenRowMajor{})));
// V^T smem: (HEAD_DIM, BLK_N)
using SmemLayoutVt = decltype(permute<1, 0>(SmemLayoutV{}));

// Thr layout for gmem copy
using GmemCopyThrLayout =
Expand Down
Loading
Loading