
Commit

Merge pull request #108 from yirongjie/main
feat: add clear_kvcache && fix: BUG in quantize.
yirongjie authored Aug 3, 2024
2 parents 1b1c3cc + 3fb20f9 commit 03c892d
Showing 11 changed files with 144 additions and 69 deletions.
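
The headline feature is a clear_kvcache() entry point on the model classes, wired down to every KVCache op. Below is a minimal sketch of the intended call pattern, assuming the scaffolding from examples/demo_llama.cpp (model, tokenizer, chatPostProcessing); `prompts` and the token-selection step are placeholders, and only model.clear_kvcache() itself comes from this commit:

    for (const auto &in_str : prompts) {                // prompts: placeholder list of user inputs
        auto input_tensor = tokenizer.tokenize(in_str); // assumed tokenizer call, as in the demos
        for (int step = 0; step < 100; step++) {
            auto result = model({input_tensor});                // forward pass
            auto outputs = tokenizer.detokenize(result[0]);     // as shown in the diffs below
            // ... pick the next token from `outputs`, print it, stop on EOS ...
            chatPostProcessing(out_token, input_tensor, {});    // out_token comes from the elided step
        }
        printf("\n");
        model.clear_kvcache();  // new in this commit: reset the KV cache before the next prompt
    }
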
69 changes: 34 additions & 35 deletions examples/demo_elastic_llama.cpp
@@ -41,42 +41,40 @@ int main(int argc, char **argv) {
std::cout << "[Q] " << in_str << std::endl;
std::cout << "[A] " << std::flush;
for (int step = 0; step < 100; step++) {
// vector<vector<int>> activate_dims = {{32*8,256}};
// 32*8 is attn_head*attn_hidden_dim(e.g. llama:32*128); 256 is ffn_hidden_dim(e.g. llama:11008)
float ratio = 1.0;//0.25; //0.5;
vector<vector<int>> activate_dims = {
// {(int)(32*128*0.5),(int)(11008*0.5)}, //0
{-1,-1}, //0
{-1,-1}, //1
{-1,-1}, //2
{-1,-1}, //3
{-1,-1}, //4
{-1,-1}, //5
{-1,-1}, //6
{-1,-1}, //7
{-1,-1}, //8
{-1,-1}, //9
{-1,-1}, //10
{-1,-1}, //11
{-1,-1}, //12
{-1,-1}, //13
{-1,-1}, //14
{-1,-1}, //15
{-1,-1}, //16
{-1,-1}, //17
{-1,-1}, //18
{-1,-1}, //19
{-1,-1}, //20
{-1,-1}, //21
{-1,-1}, //22
{-1,-1}, //23
{-1,-1}, //24
{-1,-1}, //25
{-1,-1}, //26
{-1,-1}, //27
{-1,-1}, //28
{-1,-1}, //29
{-1,-1}, //30
{-1,-1} //31
{(int)(32*ratio),(int)(11008*ratio)}, //0
{(int)(32*ratio),(int)(11008*ratio)}, //1
{(int)(32*ratio),(int)(11008*ratio)}, //2
{(int)(32*ratio),(int)(11008*ratio)}, //3
{(int)(32*ratio),(int)(11008*ratio)}, //4
{(int)(32*ratio),(int)(11008*ratio)}, //5
{(int)(32*ratio),(int)(11008*ratio)}, //6
{(int)(32*ratio),(int)(11008*ratio)}, //7
{(int)(32*ratio),(int)(11008*ratio)}, //8
{(int)(32*ratio),(int)(11008*ratio)}, //9
{(int)(32*ratio),(int)(11008*ratio)}, //10
{(int)(32*ratio),(int)(11008*ratio)}, //11
{(int)(32*ratio),(int)(11008*ratio)}, //12
{(int)(32*ratio),(int)(11008*ratio)}, //13
{(int)(32*ratio),(int)(11008*ratio)}, //14
{(int)(32*ratio),(int)(11008*ratio)}, //15
{(int)(32*ratio),(int)(11008*ratio)}, //16
{(int)(32*ratio),(int)(11008*ratio)}, //17
{(int)(32*ratio),(int)(11008*ratio)}, //18
{(int)(32*ratio),(int)(11008*ratio)}, //19
{(int)(32*ratio),(int)(11008*ratio)}, //20
{(int)(32*ratio),(int)(11008*ratio)}, //21
{(int)(32*ratio),(int)(11008*ratio)}, //22
{(int)(32*ratio),(int)(11008*ratio)}, //23
{(int)(32*ratio),(int)(11008*ratio)}, //24
{(int)(32*ratio),(int)(11008*ratio)}, //25
{(int)(32*ratio),(int)(11008*ratio)}, //26
{(int)(32*ratio),(int)(11008*ratio)}, //27
{(int)(32*ratio),(int)(11008*ratio)}, //28
{(int)(32*ratio),(int)(11008*ratio)}, //29
{(int)(32*ratio),(int)(11008*ratio)}, //30
{(int)(32*ratio),(int)(11008*ratio)} //31
};
auto result = model({input_tensor}, activate_dims);
auto outputs = tokenizer.detokenize(result[0]);
@@ -89,6 +87,7 @@ int main(int argc, char **argv) {
chatPostProcessing(out_token, input_tensor, {});
}
printf("\n");
model.clear_kvcache();
}

return 0;
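
With this change, each of the 32 per-layer entries in activate_dims is an {active attention heads, active FFN hidden dim} pair scaled by `ratio` (1.0 keeps the full model; -1 previously meant "use the full dimension"). A short sketch, equivalent to the 32 hand-written rows above and reusing the demo's using-declarations:

    // One {active_heads, active_ffn_dim} pair per decoder layer, scaled by `ratio`.
    const int num_layers = 32;        // LLaMA-7B decoder layers
    const int num_heads  = 32;        // attention heads per layer
    const int ffn_hidden = 11008;     // FFN hidden dimension
    float ratio = 1.0f;
    vector<vector<int>> activate_dims(num_layers,
        {static_cast<int>(num_heads * ratio), static_cast<int>(ffn_hidden * ratio)});
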
1 change: 1 addition & 0 deletions examples/demo_llama.cpp
@@ -52,6 +52,7 @@ int main(int argc, char **argv) {
chatPostProcessing(out_token, input_tensor, {});
}
printf("\n");
model.clear_kvcache();
model.profiling();
}

3 changes: 3 additions & 0 deletions src/Layer.hpp
@@ -753,6 +753,9 @@ class KVCache final : public Layer {
int getCacheSeqLen(){
return op_->getCacheSeqLen();
}
void clearCache(){
return op_->clearCache();
}
};

class LayerNorm final : public Layer {
4 changes: 4 additions & 0 deletions src/Op.hpp
@@ -120,6 +120,10 @@ class Op {
std::cout << "only for KVCache" << std::endl;
return -1;
}
virtual void clearCache(){
assert(type_ == OpType::KVCACHE);
std::cout << "only for KVCache" << std::endl;
}

private:
Backend *backend_;
3 changes: 3 additions & 0 deletions src/backends/cpu/CPUKVCache.hpp
@@ -23,6 +23,9 @@ class CPUKVCache final : public Op {
int getCacheSeqLen() override{
return cache_seq_len_;
}
void clearCache() override{
cache_seq_len_ = 0 ;
}

private:
int thread_count = 4;
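
Taken together, the three hunks above form a simple virtual-dispatch chain: the Layer wrapper forwards to its Op, the base Op asserts it is a KVCache, and the CPU backend rewinds the cache position. A standalone, simplified analogue (not the library's actual classes) that demonstrates the pattern:

    #include <cassert>
    #include <iostream>
    #include <memory>

    enum class OpType { KVCACHE, OTHER };

    struct Op {
        explicit Op(OpType t) : type_(t) {}
        virtual ~Op() = default;
        virtual void clearCache() {                 // base hook: only KVCache ops may be cleared
            assert(type_ == OpType::KVCACHE);
            std::cout << "only for KVCache" << std::endl;
        }
        OpType type_;
    };

    struct CPUKVCache : Op {
        CPUKVCache() : Op(OpType::KVCACHE) {}
        void clearCache() override { cache_seq_len_ = 0; }   // rewind the cache position
        int cache_seq_len_ = 0;
    };

    struct KVCacheLayer {                            // layer wrapper forwards to its op
        std::unique_ptr<Op> op_ = std::make_unique<CPUKVCache>();
        void clearCache() { op_->clearCache(); }
    };

    int main() {
        KVCacheLayer k_cache;
        k_cache.clearCache();                        // resets cache_seq_len_ to 0
    }
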
56 changes: 34 additions & 22 deletions src/backends/cpu/compute/Matmul.cpp
@@ -229,18 +229,6 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
to->setBackend(src0->backend());
to->setDtype(vec_dot_type);
to->alloc();
// void *row_src = src0->rawHostPtr();
// void *row_dst = to->rawHostPtr();
// auto row_size_src = row_size(src0_dtype, src0->dimension());
// auto row_size_dst = row_size(vec_dot_type, to->dimension());
// auto n_row = src0->batch() * src0->head() * src0->sequence();
// auto n_ele = src0->dimension();
// #pragma omp parallel for num_threads(thread_count)
// for(int i = 0;i < n_row;i++){ // copy row by row
// auto row1 = (char *)row_src + i * row_size_src;
// auto row2 = (char *)row_dst + i * row_size_dst;
// x_to_vec_dot_type(reinterpret_cast<const float *>(row1), row2, n_ele);
// }
int64_t i_processed = 0;
if (from_float_to_mat && gemv && dst->masterTensor()==nullptr){
for (int b = 0; b < src0->batch(); b++) {
@@ -272,7 +260,7 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
}

#ifdef LLAMAFILE_SGEMM
if (check_llamafile_sgemm(N, M, K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
if (check_llamafile_sgemm(N, M, K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&dst->dtypeAt(0,0,0,0) == MLLM_TYPE_F32){
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
@@ -294,11 +282,23 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
}
}
}
if(support_bias){
#pragma omp parallel for collapse(4) num_threads(thread_count)
for (int b = 0; b < dst->batch(); b++) {
for (int h = 0; h < dst->head(); h++) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
*dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
}
}
}
}
}
return MLLM_NO_ERROR;
}
#endif

if(gemv&&!support_bias){
if(gemv&&dst->dtypeAt(0,0,0,0) == MLLM_TYPE_F32){
int nth=thread_count;
#pragma omp parallel for collapse(1) num_threads(thread_count)
for (int ith = 0; ith < nth; ith++){
Expand All @@ -318,6 +318,18 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
1, N/nth);
}
}
if(support_bias){
#pragma omp parallel for collapse(4) num_threads(thread_count)
for (int b = 0; b < dst->batch(); b++) {
for (int h = 0; h < dst->head(); h++) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
*dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
}
}
}
}
}
return MLLM_NO_ERROR;
}

@@ -743,15 +755,15 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
}

#ifdef LLAMAFILE_SGEMM
if (check_llamafile_sgemm(N, M, use_K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
if (check_llamafile_sgemm(use_N, M, use_K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
const int ld_src1 = src1->sequence_skip_dim();
const int ld_src0 = src0->sequence_skip_dim();
const int ld_dst = dst->sequence_skip_dim();
#pragma omp parallel for collapse(3) num_threads(thread_count)
for (int64_t b = 0; b < dst->batch(); b++){
for (int64_t h = 0; h < dst->head(); h++){
for (int id = 0; id < thread_count; id++){
llamafile_sgemm(N, M, use_K/blck_size(src1->dtype()),
llamafile_sgemm(use_N, M, use_K/blck_size(src1->dtype()),
(char *)src1->rawHostPtr() + src1->offset(b, h, 0, 0) * src1_type_size / src1_blck_size,
ld_src1 / src1_blck_size,
(char *)src0->rawHostPtr() + src0->offset(b, h, 0, 0) * src0_type_size / src0_blck_size,
@@ -774,19 +786,19 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
#pragma omp parallel for collapse(1) num_threads(thread_count)
for (int ith = 0; ith < nth; ith++){
int64_t i_processed = 0;
int64_t seq_start = (ith * N) / nth;
int64_t seq_end = ((ith + 1) * N) / nth;
int64_t seq_start = (ith * use_N) / nth;
int64_t seq_end = ((ith + 1) * use_N) / nth;
if (gemm && (M > 3) && dst->masterTensor()==nullptr) {
gemm(use_K, dst->hostPtr<float>() + dst->offset(0, 0, 0, seq_start),
N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
(char *)src0->rawHostPtr(), M - M % 4, N/nth);
use_N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
(char *)src0->rawHostPtr(), M - M % 4, use_N/nth);
i_processed = M - M % 4;
}
for (int iter = i_processed; iter < M; iter++) { //M-M%4
gemv(use_K, dst->hostPtr<float>() + dst->offset(0, 0, iter, seq_start),
N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
use_N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
(char *)src0->rawHostPtr() + src0->offset(0, 0, iter, 0) * src0_type_size / src0_blck_size,
1, N/nth);
1, use_N/nth);
}
}
return MLLM_NO_ERROR;
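
The quantize fix above replaces the old `!support_bias` guard on the sgemm/gemv fast paths with a check that the destination is F32, then applies the bias in a separate broadcast pass over the last dimension. A minimal standalone sketch of that pass, using plain arrays instead of the Tensor API:

    #include <cstddef>
    #include <vector>

    // dst is laid out as [batch, head, M, N]; bias holds one value per output column n,
    // mirroring bias->dataAt<float>(0, 0, 0, n) in the loops above.
    void add_bias(std::vector<float> &dst, const std::vector<float> &bias,
                  int batch, int head, int M, int N) {
        for (int b = 0; b < batch; b++)
            for (int h = 0; h < head; h++)
                for (int m = 0; m < M; m++)
                    for (int n = 0; n < N; n++)
                        dst[((static_cast<std::size_t>(b) * head + h) * M + m) * N + n] += bias[n];
    }
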
40 changes: 28 additions & 12 deletions src/models/llama/modeling_elastic_llama.hpp
@@ -32,6 +32,7 @@ class ElasticMultiHeadAttention final : public Module {
ElasticMultiHeadAttention(int hidden_dim, int head_size,int kv_head_size, int attn_hidden_dim,
RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias,
const TransformerNameConfig &names, const string &base_name) {
assert(kv_head_size_ == head_size_);
attn_hidden_dim_ = attn_hidden_dim;
head_size_ = head_size;
kv_head_size_ = kv_head_size;
@@ -51,16 +52,16 @@ class ElasticMultiHeadAttention final : public Module {
o_proj = ElasticLinear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name);
}
vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
vector<int> activate_dims = std::any_cast<vector<int>>(args[0]);
int activate_dim = activate_dims[0];
int activate_hidden_dim = (activate_dim==-1)? attn_hidden_dim_: (activate_dim/head_size_);
vector<int> activate_head_dims = std::any_cast<vector<int>>(args[0]);
int activate_head_dim = activate_head_dims[0];
activate_head_dim = (activate_head_dim==-1)? kv_head_size_: (activate_head_dim);
Tensor q, k, v;
q = q_proj(inputs[0], -1, activate_dim);
k = k_proj(inputs[1], -1, activate_dim);
v = v_proj(inputs[2], -1, activate_dim);
q = q.view(-1, head_size_, -1, activate_hidden_dim);
k = k.view(-1, kv_head_size_, -1, activate_hidden_dim);
v = v.view(-1, kv_head_size_, -1, activate_hidden_dim);
q = q_proj(inputs[0], -1, activate_head_dim*attn_hidden_dim_);
k = k_proj(inputs[1], -1, activate_head_dim*attn_hidden_dim_);
v = v_proj(inputs[2], -1, activate_head_dim*attn_hidden_dim_);
q = q.view(-1, activate_head_dim, -1, attn_hidden_dim_);
k = k.view(-1, activate_head_dim, -1, attn_hidden_dim_);
v = v.view(-1, activate_head_dim, -1, attn_hidden_dim_);
if (q_rope.ready() && k_rope.ready()) {
q = q_rope(q);
k = k_rope(k);
@@ -71,17 +72,20 @@ class ElasticMultiHeadAttention final : public Module {
}
k = k.transpose(SEQUENCE, DIMENSION);
auto qk = Tensor::mm(q, k);
qk = qk / std::sqrt(activate_hidden_dim);//attn_hidden_dim_
qk = qk / std::sqrt(attn_hidden_dim_);//attn_hidden_dim_
if (k_cache.ready() && v_cache.ready()) {
qk = softmax(qk, k_cache.getCacheSeqLen());
}else{
qk = softmax(qk);
}
auto o = Tensor::mm(qk, v);
o = o.view(-1, 1, -1, activate_hidden_dim * head_size_);
o = o_proj(o, activate_dim, -1);
o = o.view(-1, 1, -1, attn_hidden_dim_ * activate_head_dim);
o = o_proj(o, activate_head_dim*attn_hidden_dim_, -1);
return {o};
}
vector<KVCache*> get_cache() {
return {&k_cache,&v_cache};
}
};

class ElasticLLaMAMLP final : public Module {
@@ -137,6 +141,9 @@ class ElasticLLaMABlock final : public Module {
x = x + tmp;
return {x};
}
ElasticMultiHeadAttention& get_attention() {
return attention;
}
};

class ElasticLLaMAModel final : public Module {
@@ -170,6 +177,15 @@ class ElasticLLaMAModel final : public Module {
x = lm_head(x);
return {x};
}

void clear_kvcache() {
for (auto &block : blocks) {
auto kvcache = block.get_attention().get_cache();
for (auto &cache : kvcache) {
cache->clearCache();
}
}
}
};

#endif // MODELING_LLAMA_HPP
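
Worth spelling out: the elastic attention above now shrinks the number of active heads (each kept at full per-head width attn_hidden_dim_) instead of shrinking each head's hidden size, and the score scaling uses sqrt(attn_hidden_dim_) accordingly. A small sketch of the dimension arithmetic with LLaMA-7B-style numbers (32 heads of width 128, FFN 11008) and ratio 0.5:

    #include <cstdio>

    int main() {
        const int head_size = 32, attn_hidden_dim = 128, ffn_hidden = 11008;
        const float ratio = 0.5f;

        const int activate_head_dim = static_cast<int>(head_size * ratio);   // 16 active heads
        const int qkv_width = activate_head_dim * attn_hidden_dim;           // 2048: width passed to q/k/v_proj
        const int ffn_width = static_cast<int>(ffn_hidden * ratio);          // 5504: width passed to the MLP

        // Q/K/V are then viewed as (batch, activate_head_dim, seq, attn_hidden_dim),
        // matching the view() calls in the diff above.
        std::printf("heads=%d qkv_width=%d ffn_width=%d\n",
                    activate_head_dim, qkv_width, ffn_width);
    }
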
13 changes: 13 additions & 0 deletions src/models/llama/modeling_llama.hpp
@@ -60,6 +60,10 @@ class LLaMABlock final : public Module {
x = x + tmp;
return {x};
}

MultiHeadAttention& get_attention() {
return attention;
}
};

class LLaMAModel final : public Module {
@@ -90,6 +94,15 @@ class LLaMAModel final : public Module {
x = lm_head(x);
return {x};
}

void clear_kvcache() {
for (auto &block : blocks) {
auto kvcache = block.get_attention().get_cache();
for (auto &cache : kvcache) {
cache->clearCache();
}
}
}
};

#endif // MODELING_LLAMA_HPP
17 changes: 17 additions & 0 deletions src/models/qwen/modeling_qwen.hpp
@@ -101,6 +101,9 @@ class QWenAttention final : public Module {
atten_output = o_proj(atten_output);
return {atten_output};
}
vector<KVCache*> get_cache() {
return {&k_cache,&v_cache};
}

private:
int hidden_size;
@@ -140,6 +143,9 @@ class QWenDecoder final : public Module {
x = x + tmp;
return {x};
}
QWenAttention& get_attention() {
return self_atten;
}

private:
QWenAttention self_atten;
@@ -165,6 +171,14 @@ class QWenModel final : public Module {
x = norm(x);
return {x};
}
void clear_kvcache() {
for (auto &block : blocks) {
auto kvcache = block.get_attention().get_cache();
for (auto &cache : kvcache) {
cache->clearCache();
}
}
}

private:
std::vector<QWenDecoder> blocks;
@@ -201,6 +215,9 @@ class QWenForCausalLM final : public Module {
}
return {outputs};
}
void clear_kvcache() {
model.clear_kvcache();
}

private:
int hidden_size;
