Commit 22059d7: fix dynamic

irexyc committed Dec 10, 2024 (1 parent: 795c56f)

Showing 3 changed files with 48 additions and 15 deletions.
src/turbomind/kernels/attention/test_attention.cu (1 addition & 1 deletion)

@@ -387,7 +387,7 @@ int test_attention()
     attn_param.rope.base   = kRoPEBase;
     attn_param.rope.dim    = kRoPEDim;
     attn_param.rope.factor = 1.0f;
-    auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, kInputLen, nullptr, allocator.get());
+    auto rotary_emb = std::make_unique<RotaryEmbeddingV2<T>>(attn_param, 0, nullptr, allocator.get());

     RotaryEmbeddingV2Param rotary_param;
     rotary_param.rope_theta = rope_base.data().get();
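Why the test now passes 0 instead of kInputLen (my reading of the diff, the commit message does not say): with RopeType::kDynamic the effective base varies per sequence and is only known at batch time, so there is nothing useful to precompute at construction. The constructor change in rotary_emb.cu below makes a zero session_len skip the cache path entirely:

    // Quoted from the rotary_emb.cu hunk later in this commit: a zero
    // session_len now skips allocation and the cache fill altogether.
    if (session_len) {
        cos_sin_ = (T*)allocator_->reMalloc(cos_sin_, sizeof(T) * session_len * dim_);
        allocateBuffer(session_len);
        computeCache(session_len);
    }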
src/turbomind/models/llama/rotary_emb.cu (46 additions & 13 deletions)

@@ -109,9 +109,33 @@ __global__ void computeQ2P(const int* q2b, const int* q_len, const int* k_len, i
     }
 }

+template<int iterms_per_thread>
+__global__ void computeQ2PDynamic(int token_num, int* q2p)
+{
+    Array<int, iterms_per_thread> p;
+
+    size_t thread_id = threadIdx.x + blockIdx.x * blockDim.x;
+    size_t index     = thread_id * iterms_per_thread;
+    for (int i = 0; i < iterms_per_thread; ++i) {
+        int qi = index + i;
+        if (qi < token_num) {
+            p[i] = qi;
+        }
+    }
+    if (index < token_num) {
+        Store(&q2p[index], p);
+    }
+}
+
 template<typename T, int items_per_thread>
-__global__ void rotaryEmbeddingDynamic(
-    const int* q2b, const int* q2p, const float* rope_base, int token_num, int dim, float factor, T* cos_sin)
+__global__ void rotaryEmbeddingDynamic(const int* q2b,
+                                       const int* q_len,
+                                       const int* k_len,
+                                       const float* rope_base,
+                                       int token_num,
+                                       int dim,
+                                       float factor,
+                                       T* cos_sin)
 {
     int thread_id      = threadIdx.x + blockIdx.x * blockDim.x;
     int thread_per_tok = dim / items_per_thread;
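For the dynamic path the query-to-position map collapses to the identity, because positions are now recomputed inside the embedding kernel itself (next hunk). A serial C++ sketch of what computeQ2PDynamic writes, as I read the kernel (not repo code):

    // Serial equivalent of computeQ2PDynamic: slot qi simply holds qi.
    // The CUDA kernel writes iterms_per_thread entries per thread with one
    // vectorized Store; the tail Store writes a full Array even past
    // token_num, which appears to assume q2p is allocated with some slack.
    void computeQ2PDynamicHost(int token_num, int* q2p)
    {
        for (int qi = 0; qi < token_num; ++qi) {
            q2p[qi] = qi;
        }
    }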
@@ -121,9 +145,10 @@ __global__ void rotaryEmbeddingDynamic(
         return;
     }

-    int   bid  = q2b[qi] + 1;
-    float base = rope_base[bid - 1];
-    float ti   = (float)q2p[qi];
+    int   bid         = q2b[qi] + 1;
+    int   history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
+    float base        = rope_base[bid - 1];
+    float ti          = history_len + qi - q_len[bid - 1];

     Array<T, items_per_thread> cs;
     float c, s;
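This hunk is the heart of the fix: instead of reading a precomputed q2p table, the dynamic kernel derives each token's absolute position from the cumulative length arrays, which works even when no cache was built at construction. Assuming q_len and k_len are per-batch prefix sums with a leading zero and q2b maps a query token to its sequence id (which is what the bid = q2b[qi] + 1 indexing implies), a runnable worked example:

    #include <cstdio>

    // Toy batch (assumed layout): seq 0 has 5 cached + 3 new tokens,
    // seq 1 has 0 cached + 4 new tokens.
    int main()
    {
        int q_len[] = {0, 3, 7};   // prefix sums of new-query counts
        int k_len[] = {0, 8, 12};  // prefix sums of total (cached + new) counts
        int q2b[]   = {0, 0, 0, 1, 1, 1, 1};

        for (int qi = 0; qi < 7; ++qi) {
            int bid         = q2b[qi] + 1;
            int history_len = (k_len[bid] - k_len[bid - 1]) - (q_len[bid] - q_len[bid - 1]);
            int ti          = history_len + qi - q_len[bid - 1];
            printf("qi=%d -> rope position %d\n", qi, ti);  // seq 0: 5 6 7; seq 1: 0 1 2 3
        }
        return 0;
    }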
@@ -233,9 +258,11 @@ RotaryEmbeddingV2<T>::RotaryEmbeddingV2(const AttentionParam& param,
            break;
     }

-    cos_sin_ = (T*)allocator_->reMalloc(cos_sin_, sizeof(T) * session_len * dim_);
-    allocateBuffer(session_len);
-    computeCache(session_len);
+    if (session_len) {
+        cos_sin_ = (T*)allocator_->reMalloc(cos_sin_, sizeof(T) * session_len * dim_);
+        allocateBuffer(session_len);
+        computeCache(session_len);
+    }
 }

 template<typename T>
@@ -248,10 +275,11 @@ void RotaryEmbeddingV2<T>::computeCache(int session_len)
     switch (type_) {
         case RopeType::kDefault:
         case RopeType::kLinear:
-        case RopeType::kDynamic:
             rotaryEmbedding<T, items_per_thread>
                 <<<grid, block, 0, stream_>>>(rope_base_, session_len, dim_, inv_factor_, CosSinDefault{}, cos_sin_);
             break;
+        case RopeType::kDynamic:
+            break;
         case RopeType::kLlama3:
             rotaryEmbedding<T, items_per_thread><<<grid, block, 0, stream_>>>(
                 rope_base_, session_len, dim_, inv_factor_, CosSinLlama3{llama3_}, cos_sin_);
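A note on the kDynamic case above (my reading, the commit does not explain it): with dynamic NTK scaling the base differs per sequence, so a table built here from the single rope_base_ would be wrong anyway; the case now intentionally does nothing and leaves the cos/sin table to updateCache.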
@@ -275,8 +303,8 @@ void RotaryEmbeddingV2<T>::updateCache(const RotaryEmbeddingV2Param& params)
         const int items_per_thread = 8;
         const int block            = 256;
         const int grid             = (dim_ / items_per_thread * params.token_num + block - 1) / block;
-        rotaryEmbeddingDynamic<T, items_per_thread>
-            <<<grid, block, 0, stream_>>>(q2b_, q2p_, params.rope_theta, params.token_num, dim_, inv_factor_, cos_sin_);
+        rotaryEmbeddingDynamic<T, items_per_thread><<<grid, block, 0, stream_>>>(
+            q2b_, params.q_len, params.k_len, params.rope_theta, params.token_num, dim_, inv_factor_, cos_sin_);
     }
     else {
         int sess_len = 0;
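A quick sanity check on the launch shape above, with assumed toy values (dim_ = 128 and token_num = 1000 are not from the diff):

    // items_per_thread = 8, dim_ = 128 -> 128 / 8 = 16 threads per token
    // token_num = 1000                 -> 16 * 1000 = 16000 threads in total
    // block = 256                      -> grid = (16000 + 255) / 256 = 63 blocks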
@@ -311,8 +339,13 @@ void RotaryEmbeddingV2<T>::updateMapping(const RotaryEmbeddingV2Param& params)
         const size_t block            = 256;
         const int    tokens_per_block = block * iterms_per_thread;
         const size_t grid             = (params.token_num + tokens_per_block - 1) / tokens_per_block;
-        computeQ2P<iterms_per_thread>
-            <<<grid, block, 0, stream_>>>(q2b_, params.q_len, params.k_len, params.token_num, q2p_);
+        if (type_ == RopeType::kDynamic) {
+            computeQ2PDynamic<iterms_per_thread><<<grid, block, 0, stream_>>>(params.token_num, q2p_);
+        }
+        else {
+            computeQ2P<iterms_per_thread>
+                <<<grid, block, 0, stream_>>>(q2b_, params.q_len, params.k_len, params.token_num, q2p_);
+        }
     }
 }

src/turbomind/models/llama/rotary_emb.h (1 addition & 1 deletion)

@@ -13,7 +13,7 @@ struct RotaryEmbeddingV2Param {
     int* k_len;
     int* h_q_len;
     int* h_k_len;
-    int  dc_size;
+    int  dc_size{};
     int  batch_size;
     int  token_num;
 };
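The header fix is one character with real consequences: int dc_size; leaves the member uninitialized whenever the struct is default-initialized, so any code that branches on it (presumably the dynamic-cache path in updateCache) could read garbage; int dc_size{}; adds a default member initializer that zeroes it. A minimal illustration, hypothetical usage rather than repo code:

    RotaryEmbeddingV2Param p;  // default-initialized, not value-initialized
    // Before: p.dc_size is indeterminate and reading it is undefined behavior.
    // After:  the default member initializer guarantees p.dc_size == 0.
    if (p.dc_size) {
        // Previously this branch could be taken at random.
    }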