diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 78d60985..5f8c4af2 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -27,6 +27,7 @@ macro(func_llm_add_executable target) ${DIR_SRC} ${PROJECT_SOURCE_DIR}/src/tokenizers/Tokenizer.cpp ${PROJECT_SOURCE_DIR}/src/tokenizers/BPE/Bpe.cpp + ${PROJECT_SOURCE_DIR}/src/tokenizers/WordPiece/WordPiece.cpp ${PROJECT_SOURCE_DIR}/src/tokenizers/Unicode.cpp ${PROJECT_SOURCE_DIR}/src/tokenizers/UnicodeData.cpp ${PROJECT_SOURCE_DIR}/src/processor/PreProcess.cpp @@ -74,8 +75,10 @@ func_llm_add_executable(demo_minicpm) func_llm_add_executable(demo_smollm) func_llm_add_executable(demo_openelm) func_llm_add_executable(demo_dclm) +func_llm_add_executable(demo_bert) func_llm_add_executable(demo_phonelm) + func_vlm_add_executable(demo_llava) func_vlm_add_executable(demo_fuyu) func_vlm_add_executable(demo_vit) diff --git a/examples/demo_bert.cpp b/examples/demo_bert.cpp new file mode 100644 index 00000000..dd4397bd --- /dev/null +++ b/examples/demo_bert.cpp @@ -0,0 +1,40 @@ +// +// Created by xwk on 24-10-23. +// +#include "models/bert/configuration_bert.hpp" +#include "models/bert/modeling_bert.hpp" +#include "models/bert/tokenization_bert.hpp" +#include "cmdline.h" + +/* + * an intent to support gte-small BertModel to do text embedding + * current implementation is just a very basic example with a simple WordPiece tokenizer and a simple BertModel + * not support batch embedding + * */ + +int main(int argc, char *argv[]) { + cmdline::parser cmdParser; + cmdParser.add("model", 'm', "specify mllm model path", false, "../models/gte-small-fp32.mllm"); + cmdParser.add("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/gte_vocab.mllm"); + cmdParser.add("thread", 't', "num of threads", false, 4); + cmdParser.parse_check(argc, argv); + + string model_path = cmdParser.get("model"); + string vocab_path = cmdParser.get("vocab"); + CPUBackend::cpu_threads = cmdParser.get("thread"); + + BertTokenizer tokenizer(vocab_path, true); + string text = "Help me set an alarm at 21:30"; + auto [token_ids, type_ids, position_ids] = tokenizer.process(text); + // token_ids.printData(); + + auto config = BertConfig(); + auto model = BertModel(config); + model.load(model_path); + + auto res = model({token_ids, type_ids, position_ids})[0]; + + res.printData(); + + return 0; +} diff --git a/examples/demo_qwen.cpp b/examples/demo_qwen.cpp index 49acdcf0..7e14e410 100644 --- a/examples/demo_qwen.cpp +++ b/examples/demo_qwen.cpp @@ -65,4 +65,4 @@ int main(int argc, char **argv) { }); std::cout << "\n"; } -} \ No newline at end of file +} diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 34b50bbc..493e6c4e 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -376,6 +376,8 @@ class Tensor { return shape[3] * shape[2]; } else if (master_tensor_->master_tensor_->ctype_ == BHDS) { return shape[3]; + } else if (master_tensor_->master_tensor_->ctype_ == BDHS) { + return shape[3] * shape_[2]; } else { std::cout << "sequenceSkipDim() only support for BSHD and BHDS" << std::endl; return -1; @@ -386,6 +388,8 @@ class Tensor { return shape[3] * shape[2]; } else if (master_tensor_->ctype_ == BHDS) { return shape[3]; + } else if (master_tensor_->ctype_ == BDHS) { + return shape[3] * shape_[2]; } else { std::cout << "sequenceSkipDim() only support for BSHD and BHDS" << std::endl; return -1; diff --git a/src/backends/cpu/compute/Matmul.cpp b/src/backends/cpu/compute/Matmul.cpp index 773d1ea2..b434d043 100644 --- a/src/backends/cpu/compute/Matmul.cpp +++ 
b/src/backends/cpu/compute/Matmul.cpp @@ -198,12 +198,11 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te auto src0_blck_size = blck_size(src0->dtype()); #ifdef LLAMAFILE_SGEMM - if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), - dst->dtype()) + int ld_src1 = src1->sequenceSkipDim(); + int ld_src0 = src0->sequenceSkipDim(); + int ld_dst = dst->sequenceSkipDim(); + if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype(), ld_src1 / src1_blck_size, ld_src0 / src0_blck_size, ld_dst / blck_size(dst->dtype())) && dst->aggregatedTensors().empty()) { - const int ld_src1 = src1->sequenceSkipDim(); - const int ld_src0 = src0->sequenceSkipDim(); - const int ld_dst = dst->sequenceSkipDim(); int is_0 = (src1->batch() == 1 && src1->head() == 1 && src1->batch() != src0->batch()) ? 0 : 1; #pragma omp parallel for collapse(3) num_threads(thread_count) @@ -277,13 +276,13 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te } #ifdef LLAMAFILE_SGEMM + ld_src1 = src1->sequenceSkipDim(); + ld_src0 = src0->sequenceSkipDim(); + ld_dst = dst->sequenceSkipDim(); if (check_llamafile_sgemm(N, M, K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), - dst->dtype()) + dst->dtype(), ld_src1 / src1_blck_size, ld_src0 / src0_blck_size, ld_dst / blck_size(dst->dtype())) && dst->dtypeAt(0, 0, 0, 0) == MLLM_TYPE_F32 && dst->ctype() == BSHD && dst->aggregatedTensors().empty()) { - const int ld_src1 = src1->sequenceSkipDim(); - const int ld_src0 = src0->sequenceSkipDim(); - const int ld_dst = dst->sequenceSkipDim(); #pragma omp parallel for collapse(3) num_threads(thread_count) for (int64_t b = 0; b < dst->batch(); b++) { for (int64_t h = 0; h < dst->head(); h++) { @@ -425,278 +424,6 @@ ErrorCode mat_mul(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Te } return MLLM_NO_ERROR; } -/* -ErrorCode mat_mul_fp32(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, -bool transpose0, bool transpose1, int thread_count) { const int M = transpose0 ? src0->dimension() : -src0->sequence(); const int K = transpose0 ? src0->sequence() : src0->dimension(); const int N = -transpose1 ? src1->sequence() : src1->dimension(); Tensor *src0_cal = src0; Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; -#pragma omp parallel for collapse(4) num_threads(thread_count) - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - for (int m = 0; m < M; m++) { - for (int block = 0; block < N / blck_0 + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) { - int s_1, d_1; - int s_0, d_0; - if (!transpose0 && transpose1) { - s_1 = n; d_1 = 0; s_0 = m; d_0 = 0; - } else if (!transpose0 && !transpose1) { - s_1 = 0; d_1 = n; s_0 = m; d_0 = 0; - } else { - s_1 = 0; d_1 = n; s_0 = 0; d_0 = m; - } - if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) { - vec_dot_fp32(K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b*is_0, -h*is_0, s_1, d_1), src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); if (support_bias) -{ *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - }else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_fp32(K, &tmp, - src1_cal->hostPtr() + src1_cal->offset(b*is_0, -h*is_0, s_1, d_1), src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); if (support_bias) -{ *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + bias->dataAt(0, 0, 0, n)); - } else { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp); - } - }else{std::cout<<"Not support type [Matmul]"<dtype() == MLLM_TYPE_F16); - assert(src0_->dtype() == MLLM_TYPE_F32); - Tensor src0_qf16(src0_->shape()); - src0_qf16.setBackend(src0_->backend()); - src0_qf16.setDtype(MLLM_TYPE_F16); - src0_qf16.alloc(); -#pragma omp parallel for collapse(3) num_threads(thread_count) - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { - for (int s = 0; s < src0_->sequence(); s++) { - mllm_fp32_to_fp16_row(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_qf16.hostPtr() + src0_qf16.offset(b, h, s, -0), src0_->dimension()); - } - } - } - auto *src0 = &src0_qf16; - const int M = transpose0 ? src0->dimension() : src0->sequence(); - const int K = transpose0 ? src0->sequence() : src0->dimension(); - const int N = transpose1 ? src1->sequence() : src1->dimension(); - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; -#pragma omp parallel for collapse(4) num_threads(thread_count) - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - for (int m = 0; m < M; m++) { - for (int block = 0; block < N / blck_0 + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) { - int s_1, d_1; - int s_0, d_0; - if (!transpose0 && transpose1) { - s_1 = n; d_1 = 0; s_0 = m; d_0 = 0; - } else if (!transpose0 && !transpose1) { - s_1 = 0; d_1 = n; s_0 = m; d_0 = 0; - } else { - s_1 = 0; d_1 = n; s_0 = 0; d_0 = m; - } - vec_dot_fp16(K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b*is_0, -h*is_0, s_1, d_1), src0_cal->hostPtr() + src0_cal->offset(b, h, s_0, d_0)); if -(support_bias) { *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } - } - } - } - } - return MLLM_NO_ERROR; -} - -ErrorCode mat_mul_fp32_q4_0(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor -*bias, int thread_count) { assert(src1->dtype() == MLLM_TYPE_Q4_0); assert(src0_->dtype() == -MLLM_TYPE_F32); Tensor src0_q8(src0_->shape()); src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_0); - src0_q8.alloc(); - if (src0_->dimension() % QK8_0 == 0) { -#pragma omp parallel for collapse(3) num_threads(thread_count) - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_0(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / -QK8_0, src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK8_0 << "!=0" << std::endl; - assert(src0_->dimension() % QK8_0 == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_0); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; -#pragma omp parallel for collapse(4) num_threads(thread_count) - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - for (int m = 0; m < M; m++) { - for (int block = 0; block < N / blck_0 + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) { - vec_dot_q4_0_q8_0(K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + src1_cal->offset(b*is_0, -h*is_0, n, 0) / QK4_0, src0_cal->hostPtr() + src0_cal->offset(b, h, m, 0) / QK8_0); if -(support_bias) { *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, n); - } - } - } - } - } - } - return MLLM_NO_ERROR; -} - -ErrorCode mat_mul_fp32_q4_K(Tensor *src0_, Tensor *src1, Tensor *dst, bool support_bias, Tensor -*bias, int thread_count) { assert(src1->dtype() == MLLM_TYPE_Q4_K); assert(src0_->dtype() == -MLLM_TYPE_F32); Tensor src0_q8(src0_->shape()); src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_K); - src0_q8.alloc(); - if (src0_->dimension() % QK_K == 0) { -#pragma omp parallel for collapse(3) num_threads(thread_count) - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / -QK_K, src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK_K << "!=0" << std::endl; - assert(src0_->dimension() % QK_K == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_K); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; -#pragma omp parallel for collapse(4) num_threads(thread_count) - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - for (int m = 0; m < M; m++) { - for (int block = 0; block < N / blck_0 + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) { - if(dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F32) { - vec_dot_q4_K_q8_K(K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + -src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K, src0_cal->hostPtr() + src0_cal->offset(b, -h, m, 0) / QK_K); if (support_bias) { *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, -n); - } - } else if (dst->dtypeAt(b,h,m,n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_q4_K_q8_K(K, &tmp, - src1_cal->hostPtr() + -src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K, src0_cal->hostPtr() + src0_cal->offset(b, -h, m, 0) / QK_K); if (support_bias) { *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + -bias->dataAt(0, 0, 0, n)); } else { *dst->ptrAt(b, h, m, n) = -MLLM_FP32_TO_FP16(tmp); - } - }else{std::cout<<"Not support type [Matmul]"<dtype() == MLLM_TYPE_Q6_K); assert(src0_->dtype() == -MLLM_TYPE_F32); Tensor src0_q8(src0_->shape()); src0_q8.setBackend(src0_->backend()); - src0_q8.setDtype(MLLM_TYPE_Q8_K); - src0_q8.alloc(); - if (src0_->dimension() % QK_K == 0) { -#pragma omp parallel for collapse(3) num_threads(thread_count) - for (int b = 0; b < src0_->batch(); b++) { - for (int h = 0; h < src0_->head(); h++) { - for (int s = 0; s < src0_->sequence(); s++) { - quantize_row_q8_K(src0_->hostPtr() + src0_->offset(b, h, s, 0), - src0_q8.hostPtr() + src0_q8.offset(b, h, s, 0) / -QK_K, src0_->dimension()); - } - } - } - } else { - std::cout << "[ERROR]: " << src0_->dimension() << "%" << QK_K << "!=0" << std::endl; - assert(src0_->dimension() % QK_K == 0); - } - auto *src0 = &src0_q8; - assert(src0->dtype() == MLLM_TYPE_Q8_K); - int M = src0->sequence(); - int K = src0->dimension(); - int N = src1->sequence(); - Tensor *src0_cal = src0; - Tensor *src1_cal = src1; - const int64_t blck_0 = 16; - int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; -#pragma omp parallel for collapse(4) num_threads(thread_count) - for (int b = 0; b < src0->batch(); b++) { - for (int h = 0; h < src0->head(); h++) { - for (int m = 0; m < M; m++) { - for (int block = 0; block < N / blck_0 + 1; block++) { - for (int n = block * blck_0; n < (block + 1) * blck_0 & n < N; n++) { - if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F32) { - vec_dot_q6_K_q8_K(K, dst->ptrAt(b, h, m, n), - src1_cal->hostPtr() + -src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K, src0_cal->hostPtr() + src0_cal->offset(b, -h, m, 0) / QK_K); if (support_bias) { *dst->ptrAt(b, h, m, n) += bias->dataAt(0, 0, 0, -n); - } - } else if (dst->dtypeAt(n, h, m, n) == MLLM_TYPE_F16) { - float tmp = 0; - vec_dot_q6_K_q8_K(K, &tmp, - src1_cal->hostPtr() + -src1_cal->offset(b*is_0, h*is_0, n, 0) / QK_K, src0_cal->hostPtr() + src0_cal->offset(b, -h, m, 0) / QK_K); - - if (support_bias) { - *dst->ptrAt(b, h, m, n) = MLLM_FP32_TO_FP16(tmp + -bias->dataAt(0, 0, 0, n)); } else { *dst->ptrAt(b, h, m, n) = -MLLM_FP32_TO_FP16(tmp); - } - } else { - std::cout << "Not support tupe [Matmul]" << std::endl; - } - } - } - } - } - } - return MLLM_NO_ERROR; -} -*/ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_bias, Tensor *bias, int activate_input_dim, int activate_output_dim, bool transpose0, @@ -725,12 +452,12 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_ int use_N = (activate_output_dim == -1) ? N : activate_output_dim; int use_K = (activate_input_dim == -1) ? K : activate_input_dim; - if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src0->dtype()), src1->dtype(), - src0->dtype(), dst->dtype()) +#ifdef LLAMAFILE_SGEMM + int ld_src1 = src1->sequenceSkipDim(); + int ld_src0 = src0->sequenceSkipDim(); + int ld_dst = dst->sequenceSkipDim(); + if (check_llamafile_sgemm(N, M, K / blck_size(src0->dtype()), src1->dtype(), src0->dtype(), dst->dtype(), ld_src1 / src1_blck_size, ld_src0 / src0_blck_size, ld_dst / blck_size(dst->dtype())) && dst->aggregatedTensors().empty()) { - const int ld_src1 = src1->sequenceSkipDim(); - const int ld_src0 = src0->sequenceSkipDim(); - const int ld_dst = dst->sequenceSkipDim(); int is_0 = (src1->batch() == 1 && src1->head() == 1) ? 
0 : 1; #pragma omp parallel for collapse(3) num_threads(thread_count) for (int64_t b = 0; b < dst->batch(); b++) { @@ -754,7 +481,7 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_ } return MLLM_NO_ERROR; } - +#endif auto not_vec_dot_type = src0_dtype != vec_dot_type; std::unique_ptr to; // later this tensor will be freed by ~Tensor if (not_vec_dot_type) { @@ -800,12 +527,12 @@ ErrorCode mat_mul_elastic(Tensor *src0, Tensor *src1, Tensor *dst, bool support_ } #ifdef LLAMAFILE_SGEMM - if (check_llamafile_sgemm(use_N, M, use_K / blck_size(src1->dtype()), src1->dtype(), - src0->dtype(), dst->dtype()) + ld_src1 = src1->sequenceSkipDim(); + ld_src0 = src0->sequenceSkipDim(); + ld_dst = dst->sequenceSkipDim(); + if (check_llamafile_sgemm(N, M, K / blck_size(src1->dtype()), src1->dtype(), src0->dtype(), + dst->dtype(), ld_src1 / src1_blck_size, ld_src0 / src0_blck_size, ld_dst / blck_size(dst->dtype())) && !support_bias && dst->ctype() == BSHD && dst->aggregatedTensors().empty()) { - const int ld_src1 = src1->sequenceSkipDim(); - const int ld_src0 = src0->sequenceSkipDim(); - const int ld_dst = dst->sequenceSkipDim(); #pragma omp parallel for collapse(3) num_threads(thread_count) for (int64_t b = 0; b < dst->batch(); b++) { for (int64_t h = 0; h < dst->head(); h++) { diff --git a/src/backends/cpu/compute/SGEMM.cpp b/src/backends/cpu/compute/SGEMM.cpp index 54c80c26..0cbcb438 100644 --- a/src/backends/cpu/compute/SGEMM.cpp +++ b/src/backends/cpu/compute/SGEMM.cpp @@ -1092,7 +1092,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (void)Ctype; } -bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataType Atype, DataType Btype, DataType Ctype) { +bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataType Atype, DataType Btype, DataType Ctype, int64_t lda, int64_t ldb, int64_t ldc) { int ith = 0; int nth = 1; assert(m >= 0); @@ -1101,6 +1101,13 @@ bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataType Atype, Data assert(nth > 0); assert(ith < nth); + if (lda < k) + return false; + if (ldb < k) + return false; + if (ldc < m) + return false; + if (Ctype != MLLM_TYPE_F32) return false; diff --git a/src/backends/cpu/compute/SGEMM.hpp b/src/backends/cpu/compute/SGEMM.hpp index 9588a8ea..710199b1 100644 --- a/src/backends/cpu/compute/SGEMM.hpp +++ b/src/backends/cpu/compute/SGEMM.hpp @@ -12,7 +12,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda int ith, int nth, DataType Atype, DataType Btype, DataType Ctype, void *bias = nullptr, DataType BiasType = DataType::MLLM_TYPE_F32); -bool check_llamafile_sgemm(int64_t, int64_t, int64_t, - DataType, DataType, DataType); +bool check_llamafile_sgemm(int64_t m, int64_t n, int64_t k, DataType Atype, DataType Btype, DataType Ctype, int64_t lda, int64_t ldb, int64_t ldc); #endif // MLLM_GEMM_HPP diff --git a/src/models/bert/configuration_bert.hpp b/src/models/bert/configuration_bert.hpp new file mode 100644 index 00000000..74e4164d --- /dev/null +++ b/src/models/bert/configuration_bert.hpp @@ -0,0 +1,63 @@ +#ifndef CONFIG_BERT_HPP +#define CONFIG_BERT_HPP +#include "Types.hpp" +#include "models/transformer/configuration_transformer.hpp" +#include +#include + +using namespace mllm; + +class BertNameConfig : public TransformerNameConfig { +public: + void init() { + embedding_base_name = "embeddings."; + + blk_name = "encoder.layer."; + _attn_base_name = "attention."; + _q_proj_name = "self.query"; + _k_proj_name = "self.key"; + 
_v_proj_name = "self.value"; + _o_proj_name = "output.dense"; + _up_proj_name = "intermediate.dense"; + _down_proj_name = "output.dense"; + _attn_norm_name = "output.LayerNorm"; + _ffn_norm_name = "output.LayerNorm"; + } + std::string embedding_base_name; + + std::string blk_name; +}; + +struct BertConfig : public TransformerConfig { + explicit BertConfig() { + hidden_act = "GELU"; + pooling_type = "mean"; + hidden_size = 384; + intermediate_size = 1536; + max_position_embeddings = 512; + model_type = "bert"; + num_attention_heads = 12; + num_hidden_layers = 12; + vocab_size = 30522; + names_config.init(); + }; + + int type_vocab_size = 2; + float layer_norm_eps = 1e-12; + + std::string hidden_act = "GELU"; + std::string pooling_type = "mean"; + int hidden_size = 1024; + int intermediate_size = 2816; + int max_position_embeddings = 32768; + std::string model_type = "bert"; + int num_attention_heads = 12; + int num_hidden_layers = 12; + + + int vocab_size = 151936; + + BertNameConfig names_config; +}; + +#endif //! CONFIG_BERT_HPP diff --git a/src/models/bert/modeling_bert.hpp b/src/models/bert/modeling_bert.hpp new file mode 100644 index 00000000..e7d6730f --- /dev/null +++ b/src/models/bert/modeling_bert.hpp @@ -0,0 +1,117 @@ +#ifndef MODELING_BERT_HPP +#define MODELING_BERT_HPP + +#include "Backend.hpp" +#include "Layer.hpp" +#include "Module.hpp" +#include "Tensor.hpp" +#include "configuration_bert.hpp" +#include "models/transformer/modeling_transformer.hpp" +using namespace mllm; + +class BertEmbeddings : public Module { +public: + BertEmbeddings() = default; + BertEmbeddings(int vocal_size, int hidden_size, int type_size, int max_position_embeddings, float eps, BertNameConfig &config) { + word_embeddings = Embedding(vocal_size, hidden_size, config.embedding_base_name + "word_embeddings"); + token_type_embeddings = Embedding(type_size, hidden_size, config.embedding_base_name + "token_type_embeddings"); + position_embeddings = Embedding(max_position_embeddings, hidden_size, config.embedding_base_name + "position_embeddings"); + layer_norm = LayerNorm(hidden_size, true, eps, config.embedding_base_name + "LayerNorm"); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto inputs_embeds = word_embeddings(inputs[0]); + auto type_embeds = token_type_embeddings(inputs[1]); + auto position_embeds = position_embeddings(inputs[2]); + auto embeddings = inputs_embeds + type_embeds + position_embeds; + return {layer_norm(embeddings)}; + } + +private: + Layer word_embeddings; + Layer token_type_embeddings; + Layer position_embeddings; + Layer layer_norm; +}; + +class BertLayer : public Module { +public: + BertLayer() = default; + BertLayer(const BertConfig &config, const string &base_name) { + // base_name: encoder.layer.n. 
+ attention = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_attention_heads, + config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, RoPEType::NONE, -1, -1, 0, false, true, config.names_config, + base_name + config.names_config._attn_base_name); + + feed_forward = FeedForward(config.hidden_size, config.intermediate_size, + config.hidden_act, true, config.names_config, base_name); + + attn_norm = LayerNorm(config.hidden_size, true, config.layer_norm_eps, + base_name + config.names_config._attn_base_name + config.names_config._attn_norm_name); + + ff_norm = LayerNorm(config.hidden_size, true, config.layer_norm_eps, + base_name + config.names_config._ffn_norm_name); + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto hidden_states = inputs[0]; + + auto attn_out = attention({hidden_states, hidden_states, hidden_states})[0]; + + hidden_states = attn_norm({hidden_states + attn_out}); + + auto ff_out = feed_forward({hidden_states})[0]; + + hidden_states = ff_norm({hidden_states + ff_out}); + + return {hidden_states}; + } + +private: + MultiHeadAttention attention; + FeedForward feed_forward; + + Layer attn_norm, ff_norm; +}; + +class BertAvgPooler : public Module { +public: + BertAvgPooler() = default; + std::vector Forward(std::vector inputs, std::vector args) override { + auto x = inputs[0]; + x = x.mean(SEQUENCE); + return {x}; + } +}; + +class BertModel : public Module { +public: + BertModel(BertConfig &config) { + embeddings = BertEmbeddings(config.vocab_size, config.hidden_size, config.type_vocab_size, config.max_position_embeddings, config.layer_norm_eps, config.names_config); + layers = List(config.num_hidden_layers, config, config.names_config.blk_name); + + if (config.pooling_type == "mean") { + pooler = BertAvgPooler(); + } else { + // print not support pooling type and exit + std::cout << "Not support pooling type: " << config.pooling_type << std::endl; + exit(0); + } + } + + std::vector Forward(std::vector inputs, std::vector args) override { + auto x = embeddings(inputs, args)[0]; + for (auto &layer : layers) { + x = layer({x})[0]; + } + x = pooler({x})[0]; + return {x}; + } + +private: + BertEmbeddings embeddings; + std::vector layers; + BertAvgPooler pooler; +}; + +#endif //! 
MODELING_BERT_HPP
\ No newline at end of file
diff --git a/src/models/bert/tokenization_bert.hpp b/src/models/bert/tokenization_bert.hpp
new file mode 100644
index 00000000..d4d03809
--- /dev/null
+++ b/src/models/bert/tokenization_bert.hpp
@@ -0,0 +1,52 @@
+#ifndef TOKENIZATION_BERT_HPP
+#define TOKENIZATION_BERT_HPP
+
+#include "tokenizers/BPE/Bpe.hpp"
+#include "tokenizers/Tokenizer.hpp"
+#include "tokenizers/Unicode.hpp"
+#include "tokenizers/WordPiece/WordPiece.hpp"
+#include
+#include
+
+// unicode
+#include
+
+using namespace mllm;
+
+
+class BertTokenizer final : public WordPieceTokenizer {
+public:
+    explicit BertTokenizer(const std::string &vocab_file, bool add_special_tokens = true) :
+        WordPieceTokenizer(vocab_file) {
+        Module::initBackend(MLLM_CPU);
+        _add_special_tokens = add_special_tokens;
+        this->add_special_tokens({"[PAD]", "[CLS]", "[SEP]", "[MASK]"});
+    }
+    std::tuple<Tensor, Tensor, Tensor> process(std::string text) {
+        if (_add_special_tokens) {
+            text = "[CLS] " + text + " [SEP]";
+        }
+        auto tokens_id = vector<token_id_t>();
+        WordPieceTokenizer::tokenize(text, tokens_id, false);
+        // printf("token: ");
+        // for (auto &token_id : tokens_id) {
+        //     printf("%d ", token_id);
+        // }
+        // printf("\n");
+        auto tokens_type = vector<token_id_t>(tokens_id.size(), 0);
+        auto position_ids = vector<token_id_t>(tokens_id.size());
+        for (size_t i = 0; i < tokens_id.size(); i++) {
+            position_ids[i] = i;
+        }
+        return {
+            tokens2Input(tokens_id, "input_tokens"),
+            tokens2Input(tokens_type, "input_tokens_type"),
+            tokens2Input(position_ids, "input_position_ids")
+        };
+    }
+
+private:
+    bool _add_special_tokens;
+};
+
+#endif //! TOKENIZATION_BERT_HPP
\ No newline at end of file
diff --git a/src/tokenizers/WordPiece/WordPiece.cpp b/src/tokenizers/WordPiece/WordPiece.cpp
new file mode 100644
index 00000000..844a0c5a
--- /dev/null
+++ b/src/tokenizers/WordPiece/WordPiece.cpp
@@ -0,0 +1,230 @@
+//
+// Created by xwk on 24-10-26.
+//
+
+#include "WordPiece.hpp"
+
+std::wstring utf8_to_wstring(const std::string& utf8str) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.from_bytes(utf8str);
+}
+
+std::string wstring_to_utf8(const std::wstring& wstr) {
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    return converter.to_bytes(wstr);
+}
+
+std::wstring trim(const std::wstring &s) {
+    auto wsfront = std::find_if_not(s.begin(), s.end(), [](int c) { return std::isspace(c); });
+    auto wsback = std::find_if_not(s.rbegin(), s.rend(), [](int c) { return std::isspace(c); }).base();
+    return (wsback <= wsfront ? std::wstring() : std::wstring(wsfront, wsback));
+}
+
+std::vector<std::wstring> split(const std::wstring &s) {
+    std::wistringstream iss(s);
+    std::vector<std::wstring> tokens;
+    std::wstring token;
+    while (iss >> token) {
+        tokens.push_back(token);
+    }
+    return tokens;
+}
+
+std::vector<std::wstring> whitespace_tokenize(const std::wstring &text) {
+    std::wstring trimmed_text = trim(text);
+    if (trimmed_text.empty()) {
+        return {};
+    }
+    return split(trimmed_text);
+}
+
+mllm::WordPieceTokenizer::WordPieceTokenizer(const string &vocab_file) :
+    Tokenizer(vocab_file), basic_tokenizer(true) {
+}
+
+bool mllm::BasicTokenizer::is_punctuation(wchar_t ch) {
+    // Simplified check based on the standard punctuation classification
+    return std::ispunct(ch) != 0;
+}
+
+std::wstring mllm::BasicTokenizer::clean_text(const std::wstring& text) {
+    std::wstring output;
+    for (wchar_t c : text) {
+        if (c == 0 || c == 0xFFFD || std::iscntrl(c)) continue;
+        if (std::iswspace(c))
+            output += L' ';
+        else
+            output += c;
+    }
+    return output;
+}
+
+std::wstring mllm::BasicTokenizer::strip_accents_from_text(const std::wstring& input) {
+    // A proper implementation depends on the locale / Unicode normalization
+    return input; // Simplified placeholder
+}
+
+std::vector<std::wstring> mllm::BasicTokenizer::split_on_punctuation(const std::wstring& text) {
+    std::vector<std::wstring> result;
+    std::wstring token;
+    for (wchar_t ch : text) {
+        if (is_punctuation(ch)) {
+            if (!token.empty()) {
+                result.push_back(token);
+                token.clear();
+            }
+            result.push_back(std::wstring(1, ch));
+        } else {
+            token += ch;
+        }
+    }
+    if (!token.empty())
+        result.push_back(token);
+    return result;
+}
+
+std::wstring mllm::BasicTokenizer::tokenize_chinese_chars(const std::wstring& text) {
+    std::wstring output;
+    for (wchar_t ch : text) {
+        if (is_chinese_char(ch)) {
+            output += L' ';
+            output += ch;
+            output += L' ';
+        } else {
+            output += ch;
+        }
+    }
+    return output;
+}
+
+bool mllm::BasicTokenizer::is_chinese_char(wchar_t cp) {
+    // Simplified check covering the main CJK ideograph ranges
+    return (cp >= 0x4E00 && cp <= 0x9FFF) ||
+           (cp >= 0x3400 && cp <= 0x4DBF) ||
+           (cp >= 0x20000 && cp <= 0x2A6DF);
+}
+
+std::vector<std::wstring> splitBySet(const std::wstring& text, const std::unordered_set<std::wstring>& words) {
+    std::vector<std::wstring> result;
+    size_t pos = 0;
+
+    while (pos < text.length()) {
+        size_t minPos = std::wstring::npos;
+        std::wstring foundWord;
+
+        // find the nearest match after the current position
+        for (const auto& word : words) {
+            size_t found = text.find(word, pos);
+            if (found != std::wstring::npos && (found < minPos)) {
+                minPos = found;
+                foundWord = word;
+            }
+        }
+
+        // if a match was found, handle the text before it and the match itself
+        if (minPos != std::wstring::npos) {
+            if (minPos > pos) {
+                // add the text before the match
+                result.push_back(text.substr(pos, minPos - pos));
+            }
+            // add the match
+            result.push_back(foundWord);
+            pos = minPos + foundWord.size();
+        } else {
+            // no more matches, add all remaining text
+            result.push_back(text.substr(pos));
+            break;
+        }
+    }
+
+    return result;
+}
+
+std::vector<std::wstring> mllm::BasicTokenizer::tokenize(const std::wstring& text) {
+    std::wstring cleaned = clean_text(text);
+    if (_tokenize_chinese_chars)
+        cleaned = tokenize_chinese_chars(cleaned);
+    std::vector<std::wstring> white_space_splited_tokens = whitespace_tokenize(cleaned);
+    std::vector<std::wstring> split_tokens;
+    for (const auto& token : white_space_splited_tokens) {
+        auto sub_tokens = splitBySet(token, never_split);
+        split_tokens.insert(split_tokens.end(), sub_tokens.begin(), sub_tokens.end());
+    }
+    std::vector<std::wstring> output_tokens;
+
+    for (auto& token : split_tokens) {
+        // std::cout << "token: " << wstring_to_utf8(token) << std::endl;
+        if (never_split.count(token)) {
+            output_tokens.push_back(token);
+        } else {
+            if (do_lower_case) {
+                std::transform(token.begin(), token.end(), token.begin(),
+                               [](wchar_t c) { return std::towlower(c); });
+            }
+
+            if (strip_accents)
+                token = strip_accents_from_text(token);
+
+            // split on punctuation and add to output
+            auto sub_tokens = split_on_punctuation(token);
+            // std::cout << "token: " << wstring_to_utf8(token) << std::endl;
+            for (auto& sub_token : sub_tokens) {
+                if (!sub_token.empty()) {
+                    // std::cout << "sub_token: " << wstring_to_utf8(sub_token) << std::endl;
+                    output_tokens.push_back(sub_token);
+                }
+            }
+        }
+    }
+
+    return output_tokens;
+}
+void mllm::BasicTokenizer::add_never_split(const std::wstring &token) {
+    never_split.insert(token);
+}
+
+void mllm::WordPieceTokenizer::tokenize(const string &text, vector<token_id_t> &tokens, bool bos) {
+    auto wstr = utf8_to_wstring(text);
+    auto basic_tokens = basic_tokenizer.tokenize(wstr);
+    auto token_strs = vector<string>();
+    auto token_ids = vector<token_id_t>();
+    for (const auto& token : basic_tokens) {
+        int start = 0;
+        while (start < token.size()) {
+            auto end = token.size();
+            string str;
+            while (start < end) {
+                auto sub_str = token.substr(start, end - start);
+                if (start > 0)
+                    sub_str = L"##" + sub_str;
+                auto utf8_str = wstring_to_utf8(sub_str);
+                // std::cout << "utf8_str: " << utf8_str << std::endl;
+                if (vocab_map_.count(utf8_str)) {
+                    str = utf8_str;
+                    break;
+                } else {
+                    end--;
+                }
+            }
+            if (str.empty()) {
+                token_strs.push_back("[UNK]");
+                break;
+            } else {
+                token_strs.push_back(str);
+                // printf("word: %s\n", str.c_str());
+            }
+            start = end;
+        }
+    }
+
+    for (const auto& token_str : token_strs) {
+        // std::cout << "token: " << token_str << std::endl;
+        tokens.push_back(vocab_map_[token_str]);
+    }
+}
+void mllm::WordPieceTokenizer::add_special_tokens(const vector<string> &special_tokens) {
+    // register special tokens as never-split in the basic tokenizer
+    for (const auto& token : special_tokens) {
+        basic_tokenizer.add_never_split(utf8_to_wstring(token));
+    }
+}
diff --git a/src/tokenizers/WordPiece/WordPiece.hpp b/src/tokenizers/WordPiece/WordPiece.hpp
new file mode 100644
index 00000000..4bbee119
--- /dev/null
+++ b/src/tokenizers/WordPiece/WordPiece.hpp
@@ -0,0 +1,66 @@
+//
+// Created by xwk on 24-10-26.
+// + +#ifndef MLLM_WORDPIECE_HPP +#define MLLM_WORDPIECE_HPP + +#include "tokenizers/Tokenizer.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mllm{ + +class BasicTokenizer { +public: + BasicTokenizer(bool do_lower_case = true, + std::vector const& never_split = {}, + bool tokenize_chinese_chars = true, + bool strip_accents = true, + bool do_split_on_punc = true) + : do_lower_case(do_lower_case), + _tokenize_chinese_chars(tokenize_chinese_chars), + strip_accents(strip_accents), + do_split_on_punc(do_split_on_punc), + never_split(never_split.begin(), never_split.end()) {} + + std::vector tokenize(const std::wstring& text); + + void add_never_split(const std::wstring& token); + +private: + bool do_lower_case; + bool _tokenize_chinese_chars; + bool strip_accents; + bool do_split_on_punc; + std::unordered_set never_split; + + std::wstring clean_text(const std::wstring& text); + std::wstring strip_accents_from_text(const std::wstring& input); + std::vector split_on_punctuation(const std::wstring& text); + std::wstring tokenize_chinese_chars(const std::wstring& text); + bool is_chinese_char(wchar_t cp); + static bool is_punctuation(wchar_t ch); +}; + + +class WordPieceTokenizer: public Tokenizer { +public: + BasicTokenizer basic_tokenizer; + + WordPieceTokenizer(const std::string &vocab_file); + void tokenize(const std::string &text, std::vector &tokens, bool bos) override; + + void add_special_tokens(const std::vector &special_tokens); +}; + +} + +#endif // MLLM_WORDPIECE_HPP diff --git a/vocab/gte_vocab.mllm b/vocab/gte_vocab.mllm new file mode 100644 index 00000000..bc4dbc64 Binary files /dev/null and b/vocab/gte_vocab.mllm differ
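The core of the new WordPieceTokenizer::tokenize() above is a greedy longest-match-first search: for each basic token it tries the longest remaining prefix, shrinks it until the prefix (with a "##" continuation marker when not word-initial) is found in the vocabulary, and emits "[UNK]" when nothing matches. The standalone sketch below restates just that matching loop; the toy vocabulary and the wordpiece_split helper name are illustrative and not part of the patch.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Greedy longest-match-first WordPiece split of a single pre-tokenized word,
// mirroring the loop in WordPieceTokenizer::tokenize(): continuation pieces
// carry a "##" prefix, and an unmatched position emits "[UNK]" and stops.
static std::vector<std::string> wordpiece_split(const std::string &word,
                                                const std::unordered_map<std::string, int> &vocab) {
    std::vector<std::string> pieces;
    size_t start = 0;
    while (start < word.size()) {
        size_t end = word.size();
        std::string found;
        while (start < end) {
            std::string piece = word.substr(start, end - start);
            if (start > 0) piece = "##" + piece;
            if (vocab.count(piece)) {
                found = piece;
                break;
            }
            --end;
        }
        if (found.empty()) {
            pieces.push_back("[UNK]"); // no vocabulary prefix matched at this position
            break;
        }
        pieces.push_back(found);
        start = end;
    }
    return pieces;
}

int main() {
    // Toy vocabulary for illustration only; a real run loads vocab/gte_vocab.mllm.
    std::unordered_map<std::string, int> vocab = {
        {"un", 0}, {"##aff", 1}, {"##able", 2}, {"[UNK]", 3}};
    for (const auto &piece : wordpiece_split("unaffable", vocab)) {
        std::cout << piece << ' '; // prints: un ##aff ##able
    }
    std::cout << '\n';
    return 0;
}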
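On the matmul side, the patch now passes each operand's row stride (sequenceSkipDim() divided by that dtype's block size) into check_llamafile_sgemm(), so the SGEMM fast path is rejected whenever a view's rows are packed closer than one logical row, as can happen with the new BDHS handling in Tensor::sequenceSkipDim(). The snippet below is a minimal restatement of the added guard only; the helper and example values are illustrative, not part of the patch.

#include <cstdint>
#include <iostream>

// Sketch of the leading-dimension guard added to check_llamafile_sgemm():
// reject the fast path if any stride cannot span a full row of its matrix.
static bool sgemm_strides_ok(int64_t m, int64_t k, int64_t lda, int64_t ldb, int64_t ldc) {
    if (lda < k) return false; // A's (src1) leading dimension must cover k entries
    if (ldb < k) return false; // B's (src0) leading dimension must cover k entries
    if (ldc < m) return false; // C's (dst) leading dimension must cover m entries
    return true;
}

int main() {
    // Example: a 64x64 matmul whose src0 view only advances 32 elements per row
    // fails the guard and falls back to the generic path.
    std::cout << sgemm_strides_ok(/*m=*/64, /*k=*/64, /*lda=*/64, /*ldb=*/32, /*ldc=*/64) << "\n"; // prints 0
    return 0;
}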