diff --git a/src/Module.hpp b/src/Module.hpp
index c0470e01..a696220a 100644
--- a/src/Module.hpp
+++ b/src/Module.hpp
@@ -145,7 +145,7 @@ class Module {
             if(input.batch() == 0){
                 Tensor::gph_[input.name()] = input;
             }
-            if(input.sequence()!=1 && !last_shape_bshd_.empty()){
+            if(inputs[0].sequence()!=1 && !last_shape_bshd_.empty()){
                 // if LLM/VLLM model, the `need_setup` should be `true`
                 if(input.batch() == last_shape_bshd_[i][0] &
                    input.sequence() == last_shape_bshd_[i][1] &
diff --git a/src/backends/cpu/compute/Matmul.cpp b/src/backends/cpu/compute/Matmul.cpp
index 3382721f..d5b49849 100644
--- a/src/backends/cpu/compute/Matmul.cpp
+++ b/src/backends/cpu/compute/Matmul.cpp
@@ -176,16 +176,13 @@ ErrorCode mat_mul_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bia
     Tensor *src0_cal = src0;
     Tensor *src1_cal = src1;
     const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int b = 0; b < src0->batch(); b++) {
         for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : h;
             for (int m = 0; m < M; m++) {
-                const int num_blocks = N / blck_0;
-                const int remainder = N % blck_0;
-#pragma omp parallel for num_threads(thread_count)
-                for (int block = 0; block < num_blocks + 1; block++) {
-                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
                         int s_1, d_1;
                         int s_0, d_0;
                         if (!transpose0 && transpose1) {
@@ -197,7 +194,7 @@ ErrorCode mat_mul_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bia
                         }
                         if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) {
                             vec_dot_fp32(K, dst->ptrAt(b, h, m, n),
-                                         src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1),
+                                         src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, s_1, d_1),
                                          src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0));
                             if (support_bias) {
                                 *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n);
@@ -205,7 +202,7 @@ ErrorCode mat_mul_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bia
                         }else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) {
                             float tmp = 0;
                             vec_dot_fp32(K, &tmp,
-                                         src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1),
+                                         src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, s_1, d_1),
                                          src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0));
                             if (support_bias) {
                                 *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n));
@@ -228,9 +225,9 @@ ErrorCode mat_mul_fp32_fp16(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     src0_qf16.setBackend(src0_->backend());
     src0_qf16.setDtype(MLLM_TYPE_F16);
     src0_qf16.alloc();
+#pragma omp parallel for collapse(3) num_threads(thread_count)
     for (int b = 0; b < src0_->batch(); b++) {
         for (int h = 0; h < src0_->head(); h++) {
-#pragma omp parallel for num_threads(thread_count)
             for (int s = 0; s < src0_->sequence(); s++) {
                 mllm_fp32_to_fp16_row(src0_->hostPtr() + src0_->offset(b, h, s, 0),
                                       src0_qf16.hostPtr() + src0_qf16.offset(b, h, s, 0),
@@ -239,30 +236,19 @@ ErrorCode mat_mul_fp32_fp16(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
         }
     }
     auto *src0 = &src0_qf16;
-    // for(int b=0; b<src0->dimension(); b++) {
-    //     std::cout<<MLLM_FP16_TO_FP32(*src0->ptrAt(0, 0, 0, b))<<" ";
-    // }
-    // std::cout<<std::endl;
-    // for(int b=0; b<src1->dimension(); b++) {
-    //     std::cout<<MLLM_FP16_TO_FP32(*src1->ptrAt(0, 0, 0, b))<<" ";
-    // }
-    // std::cout<<std::endl;
     const int M = transpose0 ? src0->dimension() : src0->sequence();
     const int K = transpose0 ? src0->sequence() : src0->dimension();
     const int N = transpose1 ? src1->sequence() : src1->dimension();
     Tensor *src0_cal = src0;
     Tensor *src1_cal = src1;
     const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int b = 0; b < src0->batch(); b++) {
         for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : h;
             for (int m = 0; m < M; m++) {
-                const int num_blocks = N / blck_0;
-                const int remainder = N % blck_0;
-#pragma omp parallel for num_threads(thread_count)
-                for (int block = 0; block < num_blocks + 1; block++) {
-                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
                         int s_1, d_1;
                         int s_0, d_0;
                         if (!transpose0 && transpose1) {
@@ -273,7 +259,7 @@ ErrorCode mat_mul_fp32_fp16(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
                             s_1 = 0; d_1 = n; s_0 = 0; d_0 = m;
                         }
                         vec_dot_fp16(K, dst->ptrAt(b, h, m, n),
-                                     src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, s_1, d_1),
+                                     src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, s_1, d_1),
                                      src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0));
                         if (support_bias) {
                             *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n);
@@ -294,9 +280,9 @@ ErrorCode mat_mul_fp32_q4_0(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     src0_q8.setDtype(MLLM_TYPE_Q8_0);
     src0_q8.alloc();
     if (src0_->dimension() % QK8_0 == 0) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
         for (int b = 0; b < src0_->batch(); b++) {
             for (int h = 0; h < src0_->head(); h++) {
-#pragma omp parallel for num_threads(thread_count)
                 for (int s = 0; s < src0_->sequence(); s++) {
                     quantize_row_q8_0(src0_->hostPtr() + src0_->offset(b, h, s, 0),
                                       src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK8_0,
@@ -316,18 +302,15 @@ ErrorCode mat_mul_fp32_q4_0(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     Tensor *src0_cal = src0;
     Tensor *src1_cal = src1;
     const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int b = 0; b < src0->batch(); b++) {
         for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : h;
             for (int m = 0; m < M; m++) {
-                int num_blocks = N / blck_0;
-                int remainder = N % blck_0;
-#pragma omp parallel for num_threads(thread_count)
-                for (int block = 0; block < num_blocks + 1; block++) {
-                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
                         vec_dot_q4_0_q8_0(K, dst->ptrAt(b, h, m, n),
-                                          src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK4_0,
+                                          src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, n, 0) / QK4_0,
                                           src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK8_0);
                         if (support_bias) {
                             *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n);
@@ -348,9 +331,9 @@ ErrorCode mat_mul_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     src0_q8.setDtype(MLLM_TYPE_Q8_K);
     src0_q8.alloc();
     if (src0_->dimension() % QK_K == 0) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
         for (int b = 0; b < src0_->batch(); b++) {
             for (int h = 0; h < src0_->head(); h++) {
-#pragma omp parallel for num_threads(thread_count)
                 for (int s = 0; s < src0_->sequence(); s++) {
                     quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0),
                                       src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK_K,
@@ -370,20 +353,16 @@ ErrorCode mat_mul_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     Tensor *src0_cal = src0;
     Tensor *src1_cal = src1;
     const int64_t blck_0 = 16;
-
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int b = 0; b < src0->batch(); b++) {
         for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : h;
             for (int m = 0; m < M; m++) {
-                int num_blocks = N / blck_0;
-                int remainder = N % blck_0;
-#pragma omp parallel for num_threads(thread_count)
-                for (int block = 0; block < num_blocks + 1; block++) {
-                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
                         if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) {
                             vec_dot_q4_K_q8_K(K, dst->ptrAt(b, h, m, n),
-                                              src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K,
+                                              src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K,
                                               src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K);
                             if (support_bias) {
                                 *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n);
@@ -391,7 +370,7 @@ ErrorCode mat_mul_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
                         } else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) {
                             float tmp = 0;
                             vec_dot_q4_K_q8_K(K, &tmp,
-                                              src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K,
+                                              src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K,
                                               src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K);
                             if (support_bias) {
                                 *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n));
@@ -415,9 +394,9 @@ ErrorCode mat_mul_fp32_q6_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     src0_q8.setDtype(MLLM_TYPE_Q8_K);
     src0_q8.alloc();
     if (src0_->dimension() % QK_K == 0) {
+#pragma omp parallel for collapse(3) num_threads(thread_count)
         for (int b = 0; b < src0_->batch(); b++) {
             for (int h = 0; h < src0_->head(); h++) {
-#pragma omp parallel for num_threads(thread_count)
                 for (int s = 0; s < src0_->sequence(); s++) {
                     quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0),
                                       src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / QK_K,
@@ -437,19 +416,16 @@ ErrorCode mat_mul_fp32_q6_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
     Tensor *src0_cal = src0;
     Tensor *src1_cal = src1;
     const int64_t blck_0 = 16;
+    int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 0 : 1;
+#pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int b = 0; b < src0->batch(); b++) {
         for (int h = 0; h < src0->head(); h++) {
-            const int b_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : b;
-            const int h_1 = (src1->batch() == 1 && src1->head() == 1) ? 0 : h;
             for (int m = 0; m < M; m++) {
-                int num_blocks = N / blck_0;
-                int remainder = N % blck_0;
-#pragma omp parallel for num_threads(thread_count)
-                for (int block = 0; block < num_blocks + 1; block++) {
-                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < num_blocks * blck_0 + remainder; n++) {
+                for (int block = 0; block < N / blck_0 + 1; block++) {
+                    for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) {
                         if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F32) {
                             vec_dot_q6_K_q8_K(K, dst->ptrAt(b, h, m, n),
-                                              src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K,
+                                              src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K,
                                               src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K);
                             if (support_bias) {
                                 *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n);
@@ -457,7 +433,7 @@ ErrorCode mat_mul_fp32_q6_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool suppo
                         } else if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F16) {
                             float tmp = 0;
                             vec_dot_q6_K_q8_K(K, &tmp,
-                                              src1_cal->hostPtr() + src1_cal->offset(b_1, h_1, n, 0) / QK_K,
+                                              src1_cal->hostPtr() + src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K,
                                               src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK_K);
                             if (support_bias) {
diff --git a/src/models/qwen/configuration_qwen.hpp b/src/models/qwen/configuration_qwen.hpp
index 8c14f8b6..518a988a 100644
--- a/src/models/qwen/configuration_qwen.hpp
+++ b/src/models/qwen/configuration_qwen.hpp
@@ -102,6 +102,20 @@ struct QWenConfig {
         sliding_window = 32768;
         vocab_size = 151936;
         tie_embedding_words = true;
+    } else if (billionsType == "1.8b") {
+        attention_dropout = 0.0;
+        std::string hidden_act = "silu";
+        hidden_size = 2048;
+        intermediate_size = 5504;
+        max_position_embeddings = 32768;
+        num_attention_heads = 16;
+        num_hidden_layers = 24;
+        num_key_value_heads = 16;
+        rms_norm_eps = 1e-6;
+        rope_theta = 1000000.0;
+        sliding_window = 32768;
+        vocab_size = 151936;
+        tie_embedding_words = false;
     } else {
         throw std::runtime_error("Unsupported model size");
     }
diff --git a/src/models/qwen/modeling_qwen.hpp b/src/models/qwen/modeling_qwen.hpp
index 8d424890..370d1891 100644
--- a/src/models/qwen/modeling_qwen.hpp
+++ b/src/models/qwen/modeling_qwen.hpp
@@ -184,6 +184,8 @@ class QWenForCausalLM final : public Module {
         // Others use nn.Linear()
         if (tie_embedding_words) {
             lm_head = Parameter(1, config.vocab_size, 1, config.hidden_size, names.token_embd_name + ".weight");
+        } else{
+            lm_head_layer = Linear(config.hidden_size, config.vocab_size, false, names.lm_head_name);
         }
     }
 
@@ -194,6 +196,8 @@ class QWenForCausalLM final : public Module {
         auto outputs = model({x})[0];
         if (tie_embedding_words) {
             outputs = Tensor::mm(outputs, lm_head().transpose(Chl::SEQUENCE, Chl::DIMENSION));
+        } else {
+            outputs = lm_head_layer(outputs);
         }
         return {outputs};
     }
@@ -203,6 +207,7 @@ class QWenForCausalLM final : public Module {
     bool tie_embedding_words;
     Layer embedding;
     Parameter lm_head;
+    Layer lm_head_layer;
     QWenModel model;
 };
 
diff --git a/src/quantizer/QuantWriter.cpp b/src/quantizer/QuantWriter.cpp
index 76ace46c..6383861a 100644
--- a/src/quantizer/QuantWriter.cpp
+++ b/src/quantizer/QuantWriter.cpp
@@ -40,6 +40,8 @@ vector<string> fp32_layers = {"norm", "rope", "bias","rotary_emb", "embed_tokens
                               "modality_preprocessors", "modality_heads", "modality_postprocessors", "pre_transformer_layer"};
 vector<string> q6_layers = {"w2", "wv", "dense_h_to_4h", "v_proj", "down_proj"};
 
+int tmp_hidden_dim = -1;
+
 bool find_names(const string &name, const vector<string> &layer_names) {
     for (const auto &layer : layer_names) {
         if (name.find(layer) != std::string::npos) {
@@ -58,8 +60,21 @@ void QuantWriter::quantParams(DataType dataType) {
             __exit(-1);
         }
         auto size = param_loader_->offsets_[name].second / sizeof(float);
+        if(find_names(name, {"input_layernorm"})) {
+            tmp_hidden_dim = size;
+        }
         void *quant_ptr = nullptr;
         std::pair block_t;
+        if (find_names(name, q6_layers)) {
+            if(tmp_hidden_dim>0 && (size/tmp_hidden_dim)%256!=0){
+                std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_F32) << "\t";
+                const auto s = param_loader_->offsets_[name].second / sizeof(float);
+                const auto tsize = alloc_quant_block(s, MLLM_TYPE_F32).second;
+                writeParam(name, MLLM_TYPE_F32, param, tsize);
+                std::cout << " size:" << tsize << std::endl;
+                continue;
+            }
+        }
         if(find_names(name, fp32_layers)) {
             std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_F32) << "\t";
             const auto s = param_loader_->offsets_[name].second / sizeof(float);
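// ---------------------------------------------------------------------------
// Editor's sketch (not part of the patch above): a minimal, self-contained
// illustration of the loop restructuring the Matmul.cpp hunks apply to every
// mat_mul_* variant. The per-(b, h) broadcast branch (the old b_1/h_1 locals)
// becomes a single multiplier `is_0`, the tail of the block loop is handled by
// the extra `n < N` bound, and the OpenMP pragma moves from the innermost
// block loop to the outer loops via collapse(). All names here
// (scale_rows_broadcast, B, H, M, N, src1_batch, src1_head) are hypothetical
// stand-ins rather than mllm APIs, and the body does a simple broadcast
// multiply instead of a vec_dot call.
#include <cstddef>
#include <cstdint>
#include <vector>

void scale_rows_broadcast(std::vector<float> &dst, const std::vector<float> &src1,
                          int B, int H, int M, int N,
                          int src1_batch, int src1_head, int thread_count) {
    const int64_t blck_0 = 16;
    // 0 when src1 is broadcast over batch and head (batch == 1 && head == 1), 1 otherwise.
    const int is_0 = (src1_batch == 1 && src1_head == 1) ? 0 : 1;
#pragma omp parallel for collapse(4) num_threads(thread_count)
    for (int b = 0; b < B; b++) {
        for (int h = 0; h < H; h++) {
            for (int m = 0; m < M; m++) {
                // N / blck_0 + 1 blocks; the `n < N` bound below covers the tail block.
                for (int block = 0; block < N / blck_0 + 1; block++) {
                    for (int n = block * blck_0; n < (block + 1) * blck_0 && n < N; n++) {
                        // Indexing src1 with b*is_0 / h*is_0 reuses the single broadcast
                        // slice for every (b, h) without a branch inside the hot loop.
                        const std::size_t s1 = (static_cast<std::size_t>(b * is_0) * src1_head + h * is_0) * N + n;
                        const std::size_t d = ((static_cast<std::size_t>(b) * H + h) * M + m) * N + n;
                        dst[d] *= src1[s1];
                    }
                }
            }
        }
    }
}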