From 90f5b8fef7009a2cccf2993216da5c3d32db3fb1 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Fri, 20 Oct 2023 05:07:41 +0000
Subject: [PATCH] dispatch `cp.async`

---
 .../decoder_multihead_attention.cu            | 52 +++++++++++--------
 .../decoder_multihead_attention_params.h      |  1 +
 .../decoder_multihead_attention/iterator.h    | 23 ++++++--
 .../test_decoder_multihead_attention.cu       |  2 +
 .../llama/LlamaDecoderSelfAttentionLayer.cc   |  3 +-
 .../llama/LlamaDecoderSelfAttentionLayer.h    |  3 ++
 6 files changed, 57 insertions(+), 27 deletions(-)

diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
index b83de12bd3..709db6ebc0 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention.cu
@@ -2,6 +2,7 @@

 #include "decoder_multihead_attention_template.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"

 #include <iostream>

@@ -34,39 +35,46 @@ bool Print(size_t dynamic_smem_size)
 template<typename T, int HeadPerCta>
 void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    // cpasync_2048_32x6 ~ 64k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel;
+    auto invoke = [&](auto* type) {
+        using Attn = std::remove_reference_t<decltype(*type)>;

-    using MHAType = DecoderMultiHeadAttentionKernel;
+        static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();

-    // ld_kv16_2048_32x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel;
+        [[maybe_unused]] static const bool _ = Print(kDynSmemSize);

-    // ld_kv8_2048_64x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel;
+        const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
+        const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));

-    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
+        dim3 block(Attn::kWarpCount * WARP_SIZE);
+        dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);

-    [[maybe_unused]] static const bool _ = Print(kDynSmemSize);
+        // if (params.layer_offset == 0) {
+        //     std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
+        // }

-    const int slice_count = (params.max_seq_len + MHAType::kSliceLen - 1) / MHAType::kSliceLen;
-    const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
+        cudaFuncSetAttribute(
+            decoder_multihead_attention<Attn>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);

-    dim3 block(MHAType::kWarpCount * WARP_SIZE);
-    dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
+        decoder_multihead_attention<Attn><<<grid, block, kDynSmemSize, params.stream>>>(params);

-    // if (params.layer_offset == 0) {
-    //     std::cout << "max_split_k' = " << max_split_k << "\n";
-    // }
+        if (max_split_k > 1) {
+            dim3 grid(params.num_heads, params.batch_size);
+            decoder_multihead_attention_reduce<<<grid, block, 0, params.stream>>>(params);
+        }
+    };

-    cudaFuncSetAttribute(
-        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
+    if (params.arch >= 80) {
+        // DecoderMultiHeadAttentionKernel;  // 64k

-    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize, params.stream>>>(params);
+        using Type = DecoderMultiHeadAttentionKernel;
+        invoke((Type*)0);
+    }
+    else {
+        // DecoderMultiHeadAttentionKernel;  // 34k
+        // DecoderMultiHeadAttentionKernel;  // 34k

-    if (max_split_k > 1) {
-        dim3 grid(params.num_heads, params.batch_size);
-        decoder_multihead_attention_reduce<<<grid, block, 0, params.stream>>>(params);
+        using Type = DecoderMultiHeadAttentionKernel;
+        invoke((Type*)0);
     }
 }
diff --git a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
index 993001ee89..5f18b45216 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -63,6 +63,7 @@ struct DecoderMultiHeadAttentionParams {
     float* partial_M;
     float* partial_L;

+    int arch;
     cudaStream_t stream;
 };
diff --git a/src/turbomind/kernels/decoder_multihead_attention/iterator.h b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
index 006d1e5cc6..5e0ba7f885 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/iterator.h
+++ b/src/turbomind/kernels/decoder_multihead_attention/iterator.h
@@ -7,6 +7,12 @@

 namespace turbomind {

+#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
+#define L2_CACHEHINT(size) ".L2::" #size "B"
+#else
+#define L2_CACHEHINT(size)
+#endif
+
 struct BlockIterator {
     const void** ptrs_;
     const void*  prefetch_;
@@ -256,15 +262,20 @@ struct Iterator {
     {
         const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
         constexpr int cp_size      = sizeof(AccessType);
-        // static_assert(cp_size == 16);
+#if TURBOMIND_ARCH_SM80
+        // clang-format off
         asm volatile("{\n"
                      "  .reg .pred p;\n"
                      "  setp.ne.b32 p, %0, 0;\n"
-                     "  @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
+                     "  @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                      "}\n" ::"r"((int)mask),
                      "r"(smem_int_ptr),
                      "l"(src),
                      "n"(cp_size));
+        // clang-format on
+#else
+        assert(TURBOMIND_ARCH_SM80);
+#endif
     }

     static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
@@ -276,8 +287,12 @@ struct Iterator {

     __device__ void Prefetch(bool mask)
     {
-        CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
-        // Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        if constexpr (TURBOMIND_ARCH_SM80) {
+            CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
+        else {
+            Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
     }

     __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
diff --git a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
index 912bff1ae4..b5249f31c2 100644
--- a/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
+++ b/src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -252,6 +252,8 @@ int main(int argc, char* argv[])
     params.max_split_k = kMaxSplitK;
     params.max_seq_len = kContextLen;

+    params.arch = 80;
+
     std::vector> outputs;

     for (int i = 0; i < std::max(kTestIter, 10); ++i) {
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
index ba31d90a0d..78ced5dff8 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc
@@ -151,7 +151,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o

     const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
     FT_CHECK(avg_batch_size >= 1.f);
-    
+
     const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));

     // if (layer_id == 0) {
@@ -161,6 +161,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o

     params.max_split_k = max_split_k;
     params.max_seq_len = max_seq_len;
+    params.arch = arch_;
     params.stream = stream_;

     params.quant_policy = quant_policy_;
diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
index 9f80edc462..95556cd30b 100644
--- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
+++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h
@@ -24,6 +24,7 @@
 #include "src/turbomind/models/llama/LlamaLinear.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"

 namespace turbomind {
@@ -60,6 +61,7 @@ class LlamaDecoderSelfAttentionLayer {
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         quant_policy_(quant_policy)
     {
+        arch_ = getSMVersion();
     }

     ~LlamaDecoderSelfAttentionLayer()
@@ -96,6 +98,7 @@ class LlamaDecoderSelfAttentionLayer {
     float* workspace_ = nullptr;

     bool is_allocate_buffer_{};
+    int  arch_{};
 };

 }  // namespace turbomind
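
Note (illustration only, not part of the diff): the sketch below is a minimal, self-contained CUDA program showing the dispatch idea this patch implements — issue a predicated `cp.async` copy into shared memory on SM80+ and fall back to an ordinary synchronous copy elsewhere, with the target architecture also checked once on the host, the way `params.arch >= 80` / `getSMVersion()` are used above. The helper names (`cp_async_16`, `copy_rows`) and the sizes are assumptions made up for the example; only the PTX instructions and CUDA runtime calls are standard.

#include <cstdio>
#include <cuda_runtime.h>

// Assumption: mirror the TURBOMIND_ARCH_SM80 guard used in iterator.h.
#ifndef TURBOMIND_ARCH_SM80
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#define TURBOMIND_ARCH_SM80 1
#else
#define TURBOMIND_ARCH_SM80 0
#endif
#endif

// Predicated 16-byte global->shared copy; a no-op when mask is false.
// On SM80+ this is cp.async (the patch additionally appends an L2 cache hint
// via L2_CACHEHINT on CUDA 11.4+); elsewhere it degrades to a plain store.
__device__ void cp_async_16(void* smem_dst, const void* gmem_src, bool mask)
{
#if TURBOMIND_ARCH_SM80
    const unsigned smem_addr = (unsigned)__cvta_generic_to_shared(smem_dst);
    asm volatile("{\n"
                 "  .reg .pred p;\n"
                 "  setp.ne.b32 p, %0, 0;\n"
                 "  @p cp.async.ca.shared.global [%1], [%2], 16;\n"
                 "}\n" ::"r"((int)mask),
                 "r"(smem_addr),
                 "l"(gmem_src));
#else
    if (mask) {
        *reinterpret_cast<uint4*>(smem_dst) = *reinterpret_cast<const uint4*>(gmem_src);
    }
#endif
}

__global__ void copy_rows(const uint4* src, uint4* dst, int n)
{
    __shared__ uint4 smem[128];
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    cp_async_16(&smem[threadIdx.x], &src[i], i < n);
#if TURBOMIND_ARCH_SM80
    asm volatile("cp.async.commit_group;\n");  // group the async copies issued above
    asm volatile("cp.async.wait_group 0;\n");  // wait for them to land in shared memory
#endif
    __syncthreads();

    if (i < n) {
        dst[i] = smem[threadIdx.x];
    }
}

int main()
{
    // Host-side dispatch, analogous to params.arch >= 80 in the patch.
    int device = 0, major = 0, minor = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
    const int arch = major * 10 + minor;
    std::printf("sm_%d -> %s path\n", arch, arch >= 80 ? "cp.async" : "sync copy");

    const int n = 128;
    uint4 *src{}, *dst{};
    cudaMalloc((void**)&src, n * sizeof(uint4));
    cudaMalloc((void**)&dst, n * sizeof(uint4));
    copy_rows<<<1, n>>>(src, dst, n);
    cudaDeviceSynchronize();
    cudaFree(src);
    cudaFree(dst);
    return 0;
}

Built for several architectures (e.g. `nvcc -gencode arch=compute_70,code=sm_70 -gencode arch=compute_80,code=sm_80`), the same binary only takes the `cp.async` path on Ampere and newer, which is what the `params.arch` plumbing in this patch enables at kernel-selection time.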