
Commit 03c892d

Merge pull request #108 from yirongjie/main
feat: add clear_kvcache && fix: BUG in quantize.
2 parents 1b1c3cc + 3fb20f9 commit 03c892d

File tree

11 files changed (+144, -69 lines)


examples/demo_elastic_llama.cpp

Lines changed: 34 additions & 35 deletions
@@ -41,42 +41,40 @@ int main(int argc, char **argv) {
         std::cout << "[Q] " << in_str << std::endl;
         std::cout << "[A] " << std::flush;
         for (int step = 0; step < 100; step++) {
-            // vecor<vector<int>> activate_dims = {{32*8,256}};
-            // 32*8 is attn_head*attn_hidden_dim(e.g. llama:32*128); 256 is ffn_hidden_dim(e.g. llama:11008)
+            float ratio = 1.0;//0.25; //0.5;
             vector<vector<int>> activate_dims = {
-                // {(int)(32*128*0.5),(int)(11008*0.5)}, //0
-                {-1,-1}, //0
-                {-1,-1}, //1
-                {-1,-1}, //2
-                {-1,-1}, //3
-                {-1,-1}, //4
-                {-1,-1}, //5
-                {-1,-1}, //6
-                {-1,-1}, //7
-                {-1,-1}, //8
-                {-1,-1}, //9
-                {-1,-1}, //10
-                {-1,-1}, //11
-                {-1,-1}, //12
-                {-1,-1}, //13
-                {-1,-1}, //14
-                {-1,-1}, //15
-                {-1,-1}, //16
-                {-1,-1}, //17
-                {-1,-1}, //18
-                {-1,-1}, //19
-                {-1,-1}, //20
-                {-1,-1}, //21
-                {-1,-1}, //22
-                {-1,-1}, //23
-                {-1,-1}, //24
-                {-1,-1}, //25
-                {-1,-1}, //26
-                {-1,-1}, //27
-                {-1,-1}, //28
-                {-1,-1}, //29
-                {-1,-1}, //30
-                {-1,-1} //31
+                {(int)(32*ratio),(int)(11008*ratio)}, //0
+                {(int)(32*ratio),(int)(11008*ratio)}, //1
+                {(int)(32*ratio),(int)(11008*ratio)}, //2
+                {(int)(32*ratio),(int)(11008*ratio)}, //3
+                {(int)(32*ratio),(int)(11008*ratio)}, //4
+                {(int)(32*ratio),(int)(11008*ratio)}, //5
+                {(int)(32*ratio),(int)(11008*ratio)}, //6
+                {(int)(32*ratio),(int)(11008*ratio)}, //7
+                {(int)(32*ratio),(int)(11008*ratio)}, //8
+                {(int)(32*ratio),(int)(11008*ratio)}, //9
+                {(int)(32*ratio),(int)(11008*ratio)}, //10
+                {(int)(32*ratio),(int)(11008*ratio)}, //11
+                {(int)(32*ratio),(int)(11008*ratio)}, //12
+                {(int)(32*ratio),(int)(11008*ratio)}, //13
+                {(int)(32*ratio),(int)(11008*ratio)}, //14
+                {(int)(32*ratio),(int)(11008*ratio)}, //15
+                {(int)(32*ratio),(int)(11008*ratio)}, //16
+                {(int)(32*ratio),(int)(11008*ratio)}, //17
+                {(int)(32*ratio),(int)(11008*ratio)}, //18
+                {(int)(32*ratio),(int)(11008*ratio)}, //19
+                {(int)(32*ratio),(int)(11008*ratio)}, //20
+                {(int)(32*ratio),(int)(11008*ratio)}, //21
+                {(int)(32*ratio),(int)(11008*ratio)}, //22
+                {(int)(32*ratio),(int)(11008*ratio)}, //23
+                {(int)(32*ratio),(int)(11008*ratio)}, //24
+                {(int)(32*ratio),(int)(11008*ratio)}, //25
+                {(int)(32*ratio),(int)(11008*ratio)}, //26
+                {(int)(32*ratio),(int)(11008*ratio)}, //27
+                {(int)(32*ratio),(int)(11008*ratio)}, //28
+                {(int)(32*ratio),(int)(11008*ratio)}, //29
+                {(int)(32*ratio),(int)(11008*ratio)}, //30
+                {(int)(32*ratio),(int)(11008*ratio)} //31
             };
             auto result = model({input_tensor}, activate_dims);
             auto outputs = tokenizer.detokenize(result[0]);
@@ -89,6 +87,7 @@ int main(int argc, char **argv) {
             chatPostProcessing(out_token, input_tensor, {});
         }
         printf("\n");
+        model.clear_kvcache();
     }

     return 0;

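The elastic demo now derives every layer's budget from a single `ratio` instead of a hand-written table of `{-1,-1}` pairs; after this commit the first entry is interpreted as an active head count and the second as the FFN hidden dimension (see the modeling_elastic_llama.hpp change below), with `-1` meaning "use the full dimension". A minimal illustrative helper, not part of the commit, that builds the same table programmatically:

// Illustrative helper (not in the commit): build one {active_heads, active_ffn_dim}
// entry per decoder layer from a single ratio, mirroring the table written out
// by hand in the demo. -1 in either slot keeps the full dimension.
#include <vector>

std::vector<std::vector<int>> make_activate_dims(int num_layers, int num_heads,
                                                 int ffn_hidden_dim, float ratio) {
    std::vector<std::vector<int>> dims;
    for (int i = 0; i < num_layers; ++i) {
        dims.push_back({static_cast<int>(num_heads * ratio),
                        static_cast<int>(ffn_hidden_dim * ratio)});
    }
    return dims;
}

// e.g. make_activate_dims(32, 32, 11008, 1.0f) reproduces the demo's table for LLaMA-7B.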
examples/demo_llama.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ int main(int argc, char **argv) {
             chatPostProcessing(out_token, input_tensor, {});
         }
         printf("\n");
+        model.clear_kvcache();
         model.profiling();
     }

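Both demos place the reset at the same point: after the decoding loop for one prompt finishes and before the next prompt is processed, so each question starts from an empty KV cache. A condensed outline of that control flow (names follow the demo sources; omitted parts are marked in comments, so this is a shape sketch, not compilable code):

// Condensed outline of the demo's outer loop after this change.
for (auto &in_str : in_strs) {                 // one user prompt per iteration
    // ... tokenize in_str into input_tensor ...
    for (int step = 0; step < 100; step++) {   // autoregressive decoding
        auto result = model({input_tensor});
        auto outputs = tokenizer.detokenize(result[0]);
        // ... print the new token, break on end-of-sequence,
        //     then chatPostProcessing(out_token, input_tensor, {}) ...
    }
    printf("\n");
    model.clear_kvcache();                     // new: drop cached K/V before the next prompt
}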
src/Layer.hpp

Lines changed: 3 additions & 0 deletions
@@ -753,6 +753,9 @@ class KVCache final : public Layer {
     int getCacheSeqLen(){
         return op_->getCacheSeqLen();
     }
+    void clearCache(){
+        return op_->clearCache();
+    }
 };

 class LayerNorm final : public Layer {

src/Op.hpp

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,10 @@ class Op {
         std::cout << "only for KVCache" << std::endl;
         return -1;
     }
+    virtual void clearCache(){
+        assert(type_ == OpType::KVCACHE);
+        std::cout << "only for KVCache" << std::endl;
+    }

 private:
     Backend *backend_;

src/backends/cpu/CPUKVCache.hpp

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ class CPUKVCache final : public Op {
     int getCacheSeqLen() override{
         return cache_seq_len_;
     }
+    void clearCache() override{
+        cache_seq_len_ = 0 ;
+    }

 private:
     int thread_count = 4;

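Taken together, the three small additions above form one dispatch chain: the front-end KVCache layer forwards to its op, the Op base class provides a guarded default, and the CPU backend's override simply rewinds the cache write position (the buffer itself is kept and overwritten on the next prompt). A simplified, self-contained sketch of that pattern, with the surrounding classes reduced to the forwarding chain only:

// Simplified sketch of the clearCache() dispatch added above; the real classes
// carry much more state, only the forwarding is shown here.
#include <cassert>
#include <iostream>

enum class OpType { KVCACHE, OTHER };

struct Op {
    OpType type_ = OpType::OTHER;
    virtual void clearCache() {                        // default: only KV caches may clear
        assert(type_ == OpType::KVCACHE);
        std::cout << "only for KVCache" << std::endl;
    }
    virtual ~Op() = default;
};

struct CPUKVCache : Op {
    int cache_seq_len_ = 0;                            // tokens currently cached
    CPUKVCache() { type_ = OpType::KVCACHE; }
    void clearCache() override { cache_seq_len_ = 0; } // rewind; storage is reused
};

struct KVCacheLayer {                                  // stands in for Layer.hpp's KVCache
    Op *op_;
    void clearCache() { op_->clearCache(); }           // forward to the backend op
};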
src/backends/cpu/compute/Matmul.cpp

Lines changed: 34 additions & 22 deletions
@@ -229,18 +229,6 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
     to->setBackend(src0->backend());
     to->setDtype(vec_dot_type);
     to->alloc();
-    // void *row_src = src0->rawHostPtr();
-    // void *row_dst = to->rawHostPtr();
-    // auto row_size_src = row_size(src0_dtype, src0->dimension());
-    // auto row_size_dst = row_size(vec_dot_type, to->dimension());
-    // auto n_row = src0->batch() * src0->head() * src0->sequence();
-    // auto n_ele = src0->dimension();
-    // #pragma omp parallel for num_threads(thread_count)
-    // for(int i = 0;i < n_row;i++){ // copy row by row
-    //     auto row1 = (char *)row_src + i * row_size_src;
-    //     auto row2 = (char *)row_dst + i * row_size_dst;
-    //     x_to_vec_dot_type(reinterpret_cast<const float *>(row1), row2, n_ele);
-    // }
     int64_t i_processed = 0;
     if (from_float_to_mat && gemv && dst->masterTensor()==nullptr){
         for (int b = 0; b < src0->batch(); b++) {
@@ -272,7 +260,7 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
     }

 #ifdef LLAMAFILE_SGEMM
-    if (check_llamafile_sgemm(N, M, K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
+    if (check_llamafile_sgemm(N, M, K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&dst->dtypeAt(0,0,0,0) == MLLM_TYPE_F32){
         const int ld_src1 = src1->sequence_skip_dim();
         const int ld_src0 = src0->sequence_skip_dim();
         const int ld_dst = dst->sequence_skip_dim();
@@ -294,11 +282,23 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
             }
         }
     }
+        if(support_bias){
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+            for (int b = 0; b < dst->batch(); b++) {
+                for (int h = 0; h < dst->head(); h++) {
+                    for (int m = 0; m < M; m++) {
+                        for (int n = 0; n < N; n++) {
+                            *dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
+                        }
+                    }
+                }
+            }
+        }
     return MLLM_NO_ERROR;
     }
 #endif

-    if(gemv&&!support_bias){
+    if(gemv&&dst->dtypeAt(0,0,0,0) == MLLM_TYPE_F32){
         int nth=thread_count;
 #pragma omp parallel for collapse(1) num_threads(thread_count)
         for (int ith = 0; ith < nth; ith++){
@@ -318,6 +318,18 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te
                  1, N/nth);
         }
         }
+        if(support_bias){
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+            for (int b = 0; b < dst->batch(); b++) {
+                for (int h = 0; h < dst->head(); h++) {
+                    for (int m = 0; m < M; m++) {
+                        for (int n = 0; n < N; n++) {
+                            *dst->ptrAt<float>(b, h, m, n) += bias->dataAt<float>(0, 0, 0, n);
+                        }
+                    }
+                }
+            }
+        }
     return MLLM_NO_ERROR;
     }

@@ -743,15 +755,15 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
     }

 #ifdef LLAMAFILE_SGEMM
-    if (check_llamafile_sgemm(N, M, use_K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
+    if (check_llamafile_sgemm(use_N, M, use_K/blck_size(src1->dtype()),src1->dtype(),src0->dtype(),dst->dtype())&&!support_bias){
         const int ld_src1 = src1->sequence_skip_dim();
         const int ld_src0 = src0->sequence_skip_dim();
         const int ld_dst = dst->sequence_skip_dim();
 #pragma omp parallel for collapse(3) num_threads(thread_count)
         for (int64_t b = 0; b < dst->batch(); b++){
             for (int64_t h = 0; h < dst->head(); h++){
                 for (int id = 0; id < thread_count; id++){
-                    llamafile_sgemm(N, M, use_K/blck_size(src1->dtype()),
+                    llamafile_sgemm(use_N, M, use_K/blck_size(src1->dtype()),
                                     (char *)src1->rawHostPtr() + src1->offset(b, h, 0, 0) * src1_type_size / src1_blck_size,
                                     ld_src1 / src1_blck_size,
                                     (char *)src0->rawHostPtr() + src0->offset(b, h, 0, 0) * src0_type_size / src0_blck_size,
@@ -774,19 +786,19 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_
 #pragma omp parallel for collapse(1) num_threads(thread_count)
         for (int ith = 0; ith < nth; ith++){
             int64_t i_processed = 0;
-            int64_t seq_start = (ith * N) / nth;
-            int64_t seq_end = ((ith + 1) * N) / nth;
+            int64_t seq_start = (ith * use_N) / nth;
+            int64_t seq_end = ((ith + 1) * use_N) / nth;
             if (gemm && (M > 3) && dst->masterTensor()==nullptr) {
                 gemm(use_K, dst->hostPtr<float>() + dst->offset(0, 0, 0, seq_start),
-                     N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
-                     (char *)src0->rawHostPtr(), M - M % 4, N/nth);
+                     use_N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
+                     (char *)src0->rawHostPtr(), M - M % 4, use_N/nth);
                 i_processed = M - M % 4;
             }
             for (int iter = i_processed; iter < M; iter++) { //M-M%4
                 gemv(use_K, dst->hostPtr<float>() + dst->offset(0, 0, iter, seq_start),
-                     N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
+                     use_N, (char *)src1->rawHostPtr()+ src1->offset(0, 0, seq_start, 0) * src1_type_size / src1_blck_size,
                      (char *)src0->rawHostPtr() + src0->offset(0, 0, iter, 0) * src0_type_size / src0_blck_size,
-                     1, N/nth);
+                     1, use_N/nth);
             }
         }
         return MLLM_NO_ERROR;

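The quantize-related fix in mat_mul is twofold: the llamafile sgemm and gemv fast paths are now gated on the destination actually holding MLLM_TYPE_F32 data rather than on the absence of a bias, and when a bias is present it is applied in a separate broadcast pass after the matmul, so biased layers can still take the fast paths. In mat_mul_elastic, the partitioning and leading-dimension arguments now use the reduced use_N instead of the full N. The bias pass amounts to adding one bias row to every output row; a tiny standalone illustration (plain arrays stand in for mllm's Tensor):

// Standalone illustration of the bias broadcast added after the matmul fast path:
// bias has shape [N] and is added to every row of the [M x N] result.
#include <vector>

void add_bias_rows(std::vector<float> &dst, const std::vector<float> &bias,
                   int M, int N) {
    for (int m = 0; m < M; ++m)        // every output row
        for (int n = 0; n < N; ++n)    // bias is indexed only by the output column
            dst[m * N + n] += bias[n];
}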
src/models/llama/modeling_elastic_llama.hpp

Lines changed: 28 additions & 12 deletions
@@ -32,6 +32,7 @@ class ElasticMultiHeadAttention final : public Module {
     ElasticMultiHeadAttention(int hidden_dim, int head_size,int kv_head_size, int attn_hidden_dim,
                               RoPEType RoPE_type, int cache_limit, bool do_mask, bool bias,
                               const TransformerNameConfig &names, const string &base_name) {
+        assert(kv_head_size_ == head_size_);
         attn_hidden_dim_ = attn_hidden_dim;
         head_size_ = head_size;
         kv_head_size_ = kv_head_size;
@@ -51,16 +52,16 @@ class ElasticMultiHeadAttention final : public Module {
         o_proj = ElasticLinear(head_size * attn_hidden_dim, hidden_dim, bias, base_name + names._o_proj_name);
     }
     vector<Tensor> Forward(vector<Tensor> inputs, vector<std::any> args) override {
-        vector<int> activate_dims = std::any_cast<vector<int>>(args[0]);
-        int activate_dim = activate_dims[0];
-        int activate_hidden_dim = (activate_dim==-1)? attn_hidden_dim_: (activate_dim/head_size_);
+        vector<int> activate_head_dims = std::any_cast<vector<int>>(args[0]);
+        int activate_head_dim = activate_head_dims[0];
+        activate_head_dim = (activate_head_dim==-1)? kv_head_size_: (activate_head_dim);
         Tensor q, k, v;
-        q = q_proj(inputs[0], -1, activate_dim);
-        k = k_proj(inputs[1], -1, activate_dim);
-        v = v_proj(inputs[2], -1, activate_dim);
-        q = q.view(-1, head_size_, -1, activate_hidden_dim);
-        k = k.view(-1, kv_head_size_, -1, activate_hidden_dim);
-        v = v.view(-1, kv_head_size_, -1, activate_hidden_dim);
+        q = q_proj(inputs[0], -1, activate_head_dim*attn_hidden_dim_);
+        k = k_proj(inputs[1], -1, activate_head_dim*attn_hidden_dim_);
+        v = v_proj(inputs[2], -1, activate_head_dim*attn_hidden_dim_);
+        q = q.view(-1, activate_head_dim, -1, attn_hidden_dim_);
+        k = k.view(-1, activate_head_dim, -1, attn_hidden_dim_);
+        v = v.view(-1, activate_head_dim, -1, attn_hidden_dim_);
         if (q_rope.ready() && k_rope.ready()) {
             q = q_rope(q);
             k = k_rope(k);
@@ -71,17 +72,20 @@ class ElasticMultiHeadAttention final : public Module {
         }
         k = k.transpose(SEQUENCE, DIMENSION);
         auto qk = Tensor::mm(q, k);
-        qk = qk / std::sqrt(activate_hidden_dim);//attn_hidden_dim_
+        qk = qk / std::sqrt(attn_hidden_dim_);//attn_hidden_dim_
         if (k_cache.ready() && v_cache.ready()) {
             qk = softmax(qk, k_cache.getCacheSeqLen());
         }else{
             qk = softmax(qk);
         }
         auto o = Tensor::mm(qk, v);
-        o = o.view(-1, 1, -1, activate_hidden_dim * head_size_);
-        o = o_proj(o, activate_dim, -1);
+        o = o.view(-1, 1, -1, attn_hidden_dim_ * activate_head_dim);
+        o = o_proj(o, activate_head_dim*attn_hidden_dim_, -1);
         return {o};
     }
+    vector<KVCache*> get_cache() {
+        return {&k_cache,&v_cache};
+    }
 };

 class ElasticLLaMAMLP final : public Module {
@@ -137,6 +141,9 @@ class ElasticLLaMABlock final : public Module {
         x = x + tmp;
         return {x};
     }
+    ElasticMultiHeadAttention& get_attention() {
+        return attention;
+    }
 };

 class ElasticLLaMAModel final : public Module {
@@ -170,6 +177,15 @@ class ElasticLLaMAModel final : public Module {
         x = lm_head(x);
         return {x};
     }
+
+    void clear_kvcache() {
+        for (auto &block : blocks) {
+            auto kvcahce =block.get_attention().get_cache();
+            for (auto &cache : kvcahce) {
+                cache->clearCache();
+            }
+        }
+    }
 };

 #endif // MODELING_LLAMA_HPP

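The attention module's per-layer knob changes meaning here: the first entry of activate_dims is now an active head count rather than a flattened attention dimension, so the Q/K/V projections are sized as activate_head_dim * attn_hidden_dim_, the views keep the full per-head dimension, and the softmax divisor stays sqrt(attn_hidden_dim_) regardless of how many heads are active. A small illustrative helper (not part of the header) that spells out that arithmetic:

// Illustrative helper (not in the commit) showing how the new head-count knob
// maps to projection width and softmax scaling in the elastic attention.
#include <cmath>

struct ElasticAttnShape {
    int proj_out_dim;   // output width for q_proj/k_proj/v_proj
    float qk_divisor;   // divisor applied to q * k^T before softmax
};

ElasticAttnShape resolve_elastic_attn(int activate_head_dim, int kv_head_size,
                                      int attn_hidden_dim) {
    int heads = (activate_head_dim == -1) ? kv_head_size : activate_head_dim; // -1 => all heads
    return {heads * attn_hidden_dim,
            std::sqrt(static_cast<float>(attn_hidden_dim))};  // per-head dim, not the active width
}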
src/models/llama/modeling_llama.hpp

Lines changed: 13 additions & 0 deletions
@@ -60,6 +60,10 @@ class LLaMABlock final : public Module {
         x = x + tmp;
         return {x};
     }
+
+    MultiHeadAttention& get_attention() {
+        return attention;
+    }
 };

 class LLaMAModel final : public Module {
@@ -90,6 +94,15 @@ class LLaMAModel final : public Module {
         x = lm_head(x);
         return {x};
     }
+
+    void clear_kvcache() {
+        for (auto &block : blocks) {
+            auto kvcahce =block.get_attention().get_cache();
+            for (auto &cache : kvcahce) {
+                cache->clearCache();
+            }
+        }
+    }
 };

 #endif // MODELING_LLAMA_HPP

src/models/qwen/modeling_qwen.hpp

Lines changed: 17 additions & 0 deletions
@@ -101,6 +101,9 @@ class QWenAttention final : public Module {
         atten_output = o_proj(atten_output);
         return {atten_output};
     }
+    vector<KVCache*> get_cache() {
+        return {&k_cache,&v_cache};
+    }

 private:
     int hidden_size;
@@ -140,6 +143,9 @@ class QWenDecoder final : public Module {
         x = x + tmp;
         return {x};
     }
+    QWenAttention& get_attention() {
+        return self_atten;
+    }

 private:
     QWenAttention self_atten;
@@ -165,6 +171,14 @@ class QWenModel final : public Module {
         x = norm(x);
         return {x};
     }
+    void clear_kvcache() {
+        for (auto &block : blocks) {
+            auto kvcahce =block.get_attention().get_cache();
+            for (auto &cache : kvcahce) {
+                cache->clearCache();
+            }
+        }
+    }

 private:
     std::vector<QWenDecoder> blocks;
@@ -201,6 +215,9 @@ class QWenForCausalLM final : public Module {
         }
         return {outputs};
     }
+    void clear_kvcache() {
+        model.clear_kvcache();
+    }

 private:
     int hidden_size;
