Commit 90f5b8f: dispatch cp.async

lzhangzz committed Oct 20, 2023
1 parent: f8020e3

Showing 6 changed files with 57 additions and 27 deletions.
@@ -2,6 +2,7 @@
 
 #include "decoder_multihead_attention_template.h"
 #include "src/turbomind/models/llama/llama_utils.h"
+#include "src/turbomind/utils/cuda_utils.h"
 
 #include <iostream>
@@ -34,39 +35,46 @@ bool Print(size_t dynamic_smem_size)
 template<typename T, typename Tkv, int HeadDim, int HeadPerCta>
 void invokeDecoderMultiheadAttention(const DecoderMultiHeadAttentionParams<T>& params)
 {
-    // cpasync_2048_32x6 ~ 64k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;
-
-    using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
-
-    // ld_kv16_2048_32x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>;
-
-    // ld_kv8_2048_64x3 ~ 34k smem
-    // using MHAType = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;
-
-    static const size_t kDynSmemSize = MHAType::GetDynamicSmemSize();
-
-    [[maybe_unused]] static const bool _ = Print<MHAType>(kDynSmemSize);
-
-    const int slice_count = (params.max_seq_len + MHAType::kSliceLen - 1) / MHAType::kSliceLen;
-    const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
-
-    dim3 block(MHAType::kWarpCount * WARP_SIZE);
-    dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
-
-    // if (params.layer_offset == 0) {
-    //     std::cout << "max_split_k' = " << max_split_k << "\n";
-    // }
-
-    cudaFuncSetAttribute(
-        decoder_multihead_attention<MHAType>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
-
-    decoder_multihead_attention<MHAType><<<grid, block, kDynSmemSize, params.stream>>>(params);
-
-    if (max_split_k > 1) {
-        dim3 grid(params.num_heads, params.batch_size);
-        decoder_multihead_attention_reduce<MHAType><<<grid, block, 0, params.stream>>>(params);
-    }
+    auto invoke = [&](auto* type) {
+        using Attn = std::remove_reference_t<decltype(*type)>;
+
+        static const size_t kDynSmemSize = Attn::GetDynamicSmemSize();
+
+        [[maybe_unused]] static const bool _ = Print<Attn>(kDynSmemSize);
+
+        const int slice_count = (params.max_seq_len + Attn::kSliceLen - 1) / Attn::kSliceLen;
+        const int max_split_k = std::min(params.max_split_k, std::max(1, slice_count));
+
+        dim3 block(Attn::kWarpCount * WARP_SIZE);
+        dim3 grid(params.num_heads / HeadPerCta, params.batch_size, max_split_k);
+
+        // if (params.layer_offset == 0) {
+        //     std::cout << "max_split_k' = " << max_split_k << ", arch = " << params.arch << "\n";
+        // }
+
+        cudaFuncSetAttribute(
+            decoder_multihead_attention<Attn>, cudaFuncAttributeMaxDynamicSharedMemorySize, kDynSmemSize);
+
+        decoder_multihead_attention<Attn><<<grid, block, kDynSmemSize, params.stream>>>(params);
+
+        if (max_split_k > 1) {
+            dim3 grid(params.num_heads, params.batch_size);
+            decoder_multihead_attention_reduce<Attn><<<grid, block, 0, params.stream>>>(params);
+        }
+    };
+
+    if (params.arch >= 80) {
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 6>;  // 64k
+
+        using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 1024, 5, true>;
+        invoke((Type*)0);
+    }
+    else {
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 32, HeadDim, 2048, 3>;  // 34k
+        // DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 2048, 3>;  // 34k
+
+        using Type = DecoderMultiHeadAttentionKernel<T, Tkv, HeadPerCta, HeadDim, 64, HeadDim, 1024, 3, true>;
+        invoke((Type*)0);
+    }
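
The rewritten launcher folds the whole launch sequence into one generic lambda and picks the kernel configuration at runtime from params.arch; the (Type*)0 argument is only a type tag and is never dereferenced. A minimal standalone sketch of the same tag-dispatch pattern (all names below are illustrative, not from the repository):

// --- illustrative sketch, not part of the commit ---
#include <cstdio>
#include <type_traits>

// Hypothetical kernel-config tags; only the template arguments matter.
template<int Stages, bool UseCpAsync>
struct KernelConfig {
    static constexpr int  kStages     = Stages;
    static constexpr bool kUseCpAsync = UseCpAsync;
};

template<class Config>
void launch()  // stand-in for setting smem attributes and launching the real kernel
{
    std::printf("stages=%d, cp.async=%d\n", Config::kStages, int(Config::kUseCpAsync));
}

void dispatch(int arch)
{
    // Generic launcher: recovers the config type from the (never dereferenced) tag pointer.
    auto invoke = [&](auto* tag) {
        using Config = std::remove_reference_t<decltype(*tag)>;
        launch<Config>();
    };

    if (arch >= 80) {
        invoke((KernelConfig<5, true>*)0);   // SM80+: async copies, deeper pipeline
    }
    else {
        invoke((KernelConfig<3, false>*)0);  // pre-SM80: synchronous loads, shallower pipeline
    }
}

int main()
{
    dispatch(80);  // prints: stages=5, cp.async=1
    dispatch(75);  // prints: stages=3, cp.async=0
}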
@@ -63,6 +63,7 @@ struct DecoderMultiHeadAttentionParams {
     float* partial_M;
     float* partial_L;
 
+    int          arch;
     cudaStream_t stream;
 };
23 changes: 19 additions & 4 deletions src/turbomind/kernels/decoder_multihead_attention/iterator.h

@@ -7,6 +7,12 @@
 
 namespace turbomind {
 
+#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
+#define L2_CACHEHINT(size) ".L2::" #size "B"
+#else
+#define L2_CACHEHINT(size)
+#endif
+
 struct BlockIterator {
     const void** ptrs_;
     const void*  prefetch_;
@@ -256,15 +262,20 @@
     {
         const int     smem_int_ptr = cast_smem_ptr_to_uint(dst);
         constexpr int cp_size      = sizeof(AccessType);
-        // static_assert(cp_size == 16);
+#if TURBOMIND_ARCH_SM80
+        // clang-format off
         asm volatile("{\n"
                      " .reg .pred p;\n"
                      " setp.ne.b32 p, %0, 0;\n"
-                     " @p cp.async.ca.shared.global.L2::128B [%1], [%2], %3;\n"
+                     " @p cp.async.ca.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;\n"
                      "}\n" ::"r"((int)mask),
                      "r"(smem_int_ptr),
                      "l"(src),
                      "n"(cp_size));
+        // clang-format on
+#else
+        assert(TURBOMIND_ARCH_SM80);
+#endif
     }
 
     static __device__ void Copy(T* __restrict__ dst, const T* __restrict__ src, bool mask)
@@ -276,8 +287,12 @@
 
     __device__ void Prefetch(bool mask)
     {
-        CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
-        // Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        if constexpr (TURBOMIND_ARCH_SM80) {
+            CpAsync(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
+        else {
+            Copy(smem_ + dst_offset_, src_ + src_offset_, mask);
+        }
     }
 
     __device__ void Load(AccessType (&frag)[ThreadMap::kIterC])
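
Two things change in iterator.h: the .L2::128B cache hint is only emitted when the toolkit is new enough to understand it (L2_CACHEHINT expands to nothing otherwise), and Prefetch falls back to a plain copy when cp.async is unavailable. A rough standalone sketch of the same guarded copy, assuming the fallback is an ordinary 16-byte vector load/store (helper names are invented for illustration):

// --- illustrative sketch, not part of the commit ---
// Emit the L2 cache hint only when the compiler accepts it (assumed CUDA >= 11.4).
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define SKETCH_L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define SKETCH_L2_CACHEHINT(size)
#endif

// Copy 16 bytes from global to shared memory, asynchronously on SM80+.
__device__ void copy_16B(void* smem_dst, const void* gmem_src, bool mask)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    const unsigned smem_int_ptr = (unsigned)__cvta_generic_to_shared(smem_dst);
    // Predicated cp.async: the copy is skipped entirely when mask is false.
    asm volatile("{\n"
                 " .reg .pred p;\n"
                 " setp.ne.b32 p, %0, 0;\n"
                 " @p cp.async.ca.shared.global" SKETCH_L2_CACHEHINT(128) " [%1], [%2], 16;\n"
                 "}\n" ::"r"((int)mask),
                 "r"(smem_int_ptr),
                 "l"(gmem_src));
    // A real pipeline would later issue cp.async.commit_group / cp.async.wait_group
    // before reading smem_dst.
#else
    // Pre-SM80 fallback: a plain 16-byte load/store through registers.
    if (mask) {
        *reinterpret_cast<uint4*>(smem_dst) = *reinterpret_cast<const uint4*>(gmem_src);
    }
#endif
}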
@@ -252,6 +252,8 @@ int main(int argc, char* argv[])
     params.max_split_k = kMaxSplitK;
     params.max_seq_len = kContextLen;
 
+    params.arch = 80;
+
     std::vector<thrust::universal_vector<half>> outputs;
 
     for (int i = 0; i < std::max(kTestIter, 10); ++i) {
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc

@@ -151,7 +151,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
 
     const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
     FT_CHECK(avg_batch_size >= 1.f);
 
     const int max_split_k = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
 
     // if (layer_id == 0) {
@@ -161,6 +161,7 @@ void LlamaDecoderSelfAttentionLayer<T>::forward(TensorMap* o
     params.max_split_k = max_split_k;
     params.max_seq_len = max_seq_len;
 
+    params.arch   = arch_;
     params.stream = stream_;
 
     params.quant_policy = quant_policy_;
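
For context on the split-k heuristic above: avg_batch_size estimates how many sequences are being decoded per step, and max_split_k shrinks as the batch alone provides enough CTAs to fill the GPU. A small worked example (the value of kMaxSplitK is assumed here, not taken from the repository):

// --- illustrative sketch, not part of the commit ---
#include <algorithm>
#include <cmath>
#include <cstdio>

int main()
{
    const int kMaxSplitK = 4;  // assumed value, for illustration only

    // Case 1: a single long sequence -> split the sequence across CTAs.
    // Case 2: many active sequences  -> the batch already fills the GPU.
    const int sum_seq_lens[] = {2048, 8192};
    const int max_seq_len    = 2048;

    for (int sum_seq_len : sum_seq_lens) {
        const float avg_batch_size = max_seq_len ? (float)sum_seq_len / max_seq_len : 1;
        const int   max_split_k    = std::max(1, (int)std::ceil(kMaxSplitK / avg_batch_size));
        std::printf("sum_seq_len=%d  avg_batch_size=%.1f  max_split_k=%d\n",
                    sum_seq_len, avg_batch_size, max_split_k);
    }
    // Prints max_split_k=4 for the single long sequence and 1 for the full batch;
    // the kernel launcher then clamps this further by the per-sequence slice count.
}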
3 changes: 3 additions & 0 deletions src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h

@@ -24,6 +24,7 @@
 #include "src/turbomind/models/llama/LlamaLinear.h"
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"
 
 namespace turbomind {
@@ -60,6 +61,7 @@ class LlamaDecoderSelfAttentionLayer {
         is_free_buffer_after_forward_(is_free_buffer_after_forward),
         quant_policy_(quant_policy)
     {
+        arch_ = getSMVersion();
     }
 
     ~LlamaDecoderSelfAttentionLayer()
@@ -96,6 +98,7 @@ class LlamaDecoderSelfAttentionLayer {
     float* workspace_ = nullptr;
 
     bool is_allocate_buffer_{};
+    int  arch_{};
 
 };
 
 } // namespace turbomind