diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index f12ecf1f42..93ca7a1817 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -387,7 +387,7 @@ int test_attention()
     attn_param.rope.base   = kRoPEBase;
     attn_param.rope.dim    = kRoPEDim;
     attn_param.rope.factor = 1.0f;
-    auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, kInputLen, nullptr, allocator.get());
+    auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, 0, nullptr, allocator.get());
 
     RotaryEmbeddingV2Param rotary_param;
     rotary_param.rope_theta = rope_base.data().get();
diff --git a/src/turbomind/models/llama/rotary_emb.cu b/src/turbomind/models/llama/rotary_emb.cu
index bb5e8c6197..8cec2ef5ae 100644
--- a/src/turbomind/models/llama/rotary_emb.cu
+++ b/src/turbomind/models/llama/rotary_emb.cu
@@ -109,9 +109,33 @@ __global__ void computeQ2P(const int* q2b, const int* q_len, const int* k_len, i
     }
 }
 
+template<int iterms_per_thread>
+__global__ void computeQ2PDynamic(int token_num, int* q2p)
+{
+    Array<int, iterms_per_thread> p;
+
+    size_t thread_id = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t index     = thread_id * iterms_per_thread;
+    for (int i = 0; i < iterms_per_thread; ++i) {
+        int qi = index + i;
+        if (qi < token_num) {
+            p[i] = qi;
+        }
+    }
+    if (index < token_num) {
+        Store(&q2p[index], p);
+    }
+}
+
 template<typename T, int items_per_thread>
-__global__ void rotaryEmbeddingDynamic(
-    const int* q2b, const int* q2p, const float* rope_base, int token_num, int dim, float factor, T* cos_sin)
+__global__ void rotaryEmbeddingDynamic(const int*   q2b,
+                                       const int*   q_len,
+                                       const int*   k_len,
+                                       const float* rope_base,
+                                       int          token_num,
+                                       int          dim,
+                                       float        factor,
+                                       T*           cos_sin)
 {
     int thread_id      = threadIdx.x + blockIdx.x * blockDim.x;
     int thread_per_tok = dim / items_per_thread;
@@ -121,9 +145,10 @@ __global__ void rotaryEmbeddingDynamic(
         return;
     }
 
-    int   bid  = q2b[qi] + 1;
-    float base = rope_base[bid - 1];
-    float ti   = (float)q2p[qi];
+    int   bid         = q2b[qi] + 1;
+    int   history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
+    float base        = rope_base[bid - 1];
+    float ti          = history_len + qi - q_len[bid - 1];
 
     Array<T, items_per_thread> cs;
     float                      c, s;
@@ -233,9 +258,11 @@ RotaryEmbeddingV2<T>::RotaryEmbeddingV2(const AttentionParam& param,
             break;
     }
 
-    cos_sin_ = (T*)allocator_->reMalloc(cos_sin_, sizeof(T) * session_len * dim_);
-    allocateBuffer(session_len);
-    computeCache(session_len);
+    if (session_len) {
+        cos_sin_ = (T*)allocator_->reMalloc(cos_sin_, sizeof(T) * session_len * dim_);
+        allocateBuffer(session_len);
+        computeCache(session_len);
+    }
 }
 
 template<typename T>
@@ -248,10 +275,11 @@ void RotaryEmbeddingV2<T>::computeCache(int session_len)
     switch (type_) {
         case RopeType::kDefault:
         case RopeType::kLinear:
-        case RopeType::kDynamic:
             rotaryEmbedding<T, items_per_thread>
                 <<<grid, block, 0, stream_>>>(rope_base_, session_len, dim_, inv_factor_, CosSinDefault{}, cos_sin_);
             break;
+        case RopeType::kDynamic:
+            break;
         case RopeType::kLlama3:
             rotaryEmbedding<T, items_per_thread><<<grid, block, 0, stream_>>>(
                 rope_base_, session_len, dim_, inv_factor_, CosSinLlama3{llama3_}, cos_sin_);
@@ -275,8 +303,8 @@ void RotaryEmbeddingV2<T>::updateCache(const RotaryEmbeddingV2Param& params)
         const int items_per_thread = 8;
         const int block            = 256;
         const int grid             = (dim_ / items_per_thread * params.token_num + block - 1) / block;
-        rotaryEmbeddingDynamic<T, items_per_thread>
-            <<<grid, block, 0, stream_>>>(q2b_, q2p_, params.rope_theta, params.token_num, dim_, inv_factor_, cos_sin_);
+        rotaryEmbeddingDynamic<T, items_per_thread><<<grid, block, 0, stream_>>>(
+            q2b_, params.q_len, params.k_len, params.rope_theta, params.token_num, dim_, inv_factor_, cos_sin_);
     }
     else {
         int sess_len = 0;
@@ -311,8 +339,13 @@ void RotaryEmbeddingV2<T>::updateMapping(const RotaryEmbeddingV2Param& params)
         const size_t block            = 256;
         const int    tokens_per_block = block * iterms_per_thread;
         const size_t grid             = (params.token_num + tokens_per_block - 1) / tokens_per_block;
-        computeQ2P<iterms_per_thread>
-            <<<grid, block, 0, stream_>>>(q2b_, params.q_len, params.k_len, params.token_num, q2p_);
+        if (type_ == RopeType::kDynamic) {
+            computeQ2PDynamic<iterms_per_thread><<<grid, block, 0, stream_>>>(params.token_num, q2p_);
+        }
+        else {
+            computeQ2P<iterms_per_thread>
+                <<<grid, block, 0, stream_>>>(q2b_, params.q_len, params.k_len, params.token_num, q2p_);
+        }
     }
 }
 
diff --git a/src/turbomind/models/llama/rotary_emb.h b/src/turbomind/models/llama/rotary_emb.h
index 8d04b29431..82338fffcd 100644
--- a/src/turbomind/models/llama/rotary_emb.h
+++ b/src/turbomind/models/llama/rotary_emb.h
@@ -13,7 +13,7 @@ struct RotaryEmbeddingV2Param {
     int* k_len;
     int* h_q_len;
     int* h_k_len;
-    int  dc_size;
+    int  dc_size{};
     int  batch_size;
     int  token_num;
 };
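For review, a small host-side sketch (plain C++, no CUDA) of the position arithmetic the reworked rotaryEmbeddingDynamic kernel now performs per token. It assumes, as the kernel's indexing implies, that q_len and k_len are inclusive prefix sums over the batch with a leading zero, so q_len[bid] - q_len[bid - 1] is the number of new query tokens of sequence bid - 1 and k_len[bid] - k_len[bid - 1] is its total (history plus new) length; rope_position is an illustrative helper, not part of the patch.

#include <cassert>
#include <vector>

// Mirrors the kernel's per-token math:
//   history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
//   ti          = history_len + qi - q_len[bid - 1];
int rope_position(const std::vector<int>& q_len,  // prefix sums of query lengths, q_len[0] == 0
                  const std::vector<int>& k_len,  // prefix sums of key lengths, k_len[0] == 0
                  int                     qi)     // flattened token index, qi < q_len.back()
{
    // Locate the sequence owning token qi (the kernel reads this from q2b_).
    int bid = 1;
    while (q_len[bid] <= qi) {
        ++bid;
    }
    const int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
    return history_len + (qi - q_len[bid - 1]);
}

int main()
{
    // Two sequences: seq 0 has 128 cached tokens plus 3 new queries,
    // seq 1 has no history and 2 new queries.
    const std::vector<int> q_len = {0, 3, 5};
    const std::vector<int> k_len = {0, 131, 133};

    assert(rope_position(q_len, k_len, 0) == 128);  // first new token of seq 0
    assert(rope_position(q_len, k_len, 2) == 130);  // last new token of seq 0
    assert(rope_position(q_len, k_len, 3) == 0);    // first token of seq 1
    assert(rope_position(q_len, k_len, 4) == 1);
    return 0;
}

Under this reading the dynamic path recomputes positions from q_len/k_len directly and no longer consumes q2p inside the embedding kernel, which appears to be why updateMapping can fill q2p_ with a plain identity mapping (computeQ2PDynamic) when type_ == RopeType::kDynamic.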