From b5585b6559f4f4f163497d870ea0f6a6ce329a18 Mon Sep 17 00:00:00 2001 From: Lee <44310445+lx200916@users.noreply.github.com> Date: Tue, 31 Oct 2023 23:18:47 +0800 Subject: [PATCH 1/3] feat: Add Quantizer. Signed-off-by: Lee <44310445+lx200916@users.noreply.github.com> --- CMakeLists.txt | 26 ++++++++++++++++++- src/ParamLoader.cpp | 8 ++++++ src/ParamLoader.hpp | 5 +++- src/quantizer/ParamWriter.cpp | 47 +++++++++++++++++++++++++++++++++++ src/quantizer/ParamWriter.hpp | 45 +++++++++++++++++++++++++++++++++ src/quantizer/main.cpp | 32 ++++++++++++++++++++++++ 6 files changed, 161 insertions(+), 2 deletions(-) create mode 100644 src/quantizer/ParamWriter.cpp create mode 100644 src/quantizer/ParamWriter.hpp create mode 100644 src/quantizer/main.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e542eaba..fd15a036 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # temp executable name set(TEST_EXE main_test) option(TEST "test mode" ON) +option(QUANT "quantize tools" ON) if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") cmake_policy(SET CMP0135 NEW) endif () @@ -42,6 +43,8 @@ aux_source_directory(${PROJECT_SOURCE_DIR}/src/express DIR_SRC_EXP) #aux_source_directory(${PROJECT_SOURCE_DIR}/src/quantize DIR_SRC_QUANT) aux_source_directory(${PROJECT_SOURCE_DIR}/examples EMP_SRC) aux_source_directory(${PROJECT_SOURCE_DIR}/test TEST_SRC) +aux_source_directory(${PROJECT_SOURCE_DIR}/src/quantizer QUANT_SRC) + include_directories(${PROJECT_SOURCE_DIR}/src) include_directories(${PROJECT_SOURCE_DIR}/include) @@ -104,4 +107,25 @@ if(NNAPI) add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/src/backends/nnapi) add_executable(nnapi_test ${PROJECT_SOURCE_DIR}/demo/nnapi_test.cpp ${DIR_SRC_CPU} ${DIR_SRC_EXP} ${DIR_SRC} )#${DIR_SRC_QUANT}) target_link_libraries(nnapi_test MLLM_CPU MLLM_NNAPI) -endif() \ No newline at end of file +endif() + +if (QUANT) + include_directories(${PROJECT_SOURCE_DIR}/src/quantizer) + file(GLOB_RECURSE MLLM_QUANT + ${CMAKE_CURRENT_LIST_DIR}/src/quantizer/*.cpp + ${CMAKE_CURRENT_LIST_DIR}/src/quantizer/*.hpp + ${PROJECT_SOURCE_DIR}/src/backends/cpu/quantize/*.hpp + ${PROJECT_SOURCE_DIR}/src/backends/cpu/quantize/*.cpp + ) + + message(STATUS "MLLM_Quant: ${MLLM_QUANT}") + add_executable( + MLLM_QUANT + ${PROJECT_SOURCE_DIR}/src/quantizer/main.cpp + ${MLLM_QUANT} + ${DIR_SRC} + + ) + + +endif () \ No newline at end of file diff --git a/src/ParamLoader.cpp b/src/ParamLoader.cpp index 296525c2..2c0da5b0 100644 --- a/src/ParamLoader.cpp +++ b/src/ParamLoader.cpp @@ -96,4 +96,12 @@ ParamLoader::ParamLoader(std::string filename, bool use_mmap) : bool ParamLoader::load(std::shared_ptr tensor) { return load(tensor.get()); } +vector ParamLoader::getParamNames() { + // get keys of data_type_ + vector keys; + for (auto &iter : data_type_) { + keys.push_back(iter.first); + } + return keys; +} } // namespace mllm \ No newline at end of file diff --git a/src/ParamLoader.hpp b/src/ParamLoader.hpp index 274a832c..fa1cad1d 100644 --- a/src/ParamLoader.hpp +++ b/src/ParamLoader.hpp @@ -12,7 +12,7 @@ namespace mllm { class Tensor; static int readInt(FILE *fp_) { int tmp; - fread(&tmp, sizeof(int), 1, fp_); + fread(&tmp, sizeof(int32_t), 1, fp_); return tmp; } static uint64_t readu64(FILE *fp_) { @@ -44,6 +44,8 @@ static std::string readString(FILE *fp_) { } #define _MAGIC_NUMBER 20012 class ParamLoader { + friend class QuantWriter; + public: ParamLoader(std::string filename, bool use_mmap = false); #ifdef USE_MMAP @@ -52,6 +54,7 @@ class ParamLoader { ~ParamLoader(); bool load(mllm::Tensor *tensor); bool load(std::shared_ptr tensor); + vector getParamNames(); private: FILE *fp_; diff --git a/src/quantizer/ParamWriter.cpp b/src/quantizer/ParamWriter.cpp new file mode 100644 index 00000000..10532c9e --- /dev/null +++ b/src/quantizer/ParamWriter.cpp @@ -0,0 +1,47 @@ +// +// Created by lx on 23-10-30. +// + +#include "ParamWriter.hpp" + +ParamWriter::ParamWriter(std::string filename) : + path_(std::move(filename)) { + fp_ = fopen(path_.c_str(), "wb"); + writeInt(fp_, _MAGIC_NUMBER); +} + +int ParamWriter::calcIndexSize(const vector names) { + int size = 0; + for (const auto &name : names) { + // One Tensor Index Item Contains: Name_Len(Int)+Name(str)+Weights_Len(UInt64)+Offset(UInt64)+DataType(Int) + size += sizeof(int) + name.size() + sizeof(uint64_t) + sizeof(uint64_t) + sizeof(int); + } + return size; +} +void ParamWriter::writeIndex() { + fseek(fp_, sizeof(int32_t) + sizeof(uint64_t), SEEK_SET); + for (const auto ¶m : param_info_) { + writeString(fp_, param.name); + write_u64(fp_, param.size); + write_u64(fp_, param.offset); + writeInt(fp_, param.type); + } +} + +void ParamWriter::writeParam(string name, mllm_dtype type, void *data, uint64_t size) { + auto param = param_info_[index_]; + param.name = std::move(name); + param.type = type; + param.offset = ftell(fp_); + fwrite(data, 1, size, fp_); + param.size = ftell(fp_) - param.offset; + index_++; +} +void ParamWriter::paddingIndex(const vector names) { + param_info_.resize(names.size()); + // write 0 padding to preserve space for index + int index_size = calcIndexSize(names); + write_u64(fp_, index_size); + char i = '\0'; + fwrite(&i, 1, index_size, fp_); +} diff --git a/src/quantizer/ParamWriter.hpp b/src/quantizer/ParamWriter.hpp new file mode 100644 index 00000000..db1a79eb --- /dev/null +++ b/src/quantizer/ParamWriter.hpp @@ -0,0 +1,45 @@ +// +// Created by lx on 23-10-30. +// + +#ifndef MLLM_PARAMWRITER_HPP +#define MLLM_PARAMWRITER_HPP +#include "ParamLoader.hpp" +static void write_u64(FILE *fp, uint64_t val) { + fwrite(&val, sizeof(uint64_t), 1, fp); +} +static void writeInt(FILE *fp, int32_t val) { + fwrite(&val, sizeof(int32_t), 1, fp); +} +static void writeString(FILE *fp, const std::string &str) { + writeInt(fp, str.size()); + fwrite(str.c_str(), str.size(), 1, fp); +} +static void write_dtype(FILE *fp, mllm_dtype dtype) { + writeInt(fp, dtype); +} + +struct ParmInfo { + std::string name; + mllm_dtype type; + uint64_t offset; + uint64_t size; +}; +class ParamWriter { +public: + ParamWriter(std::string filename); + int calcIndexSize(vector names); + void writeIndex(); + void writeParam(string name, mllm_dtype type, void *data, uint64_t size); + +private: + uint64_t index_ = 0; + FILE *fp_; + std::string path_; + std::vector param_info_; + +protected: + void paddingIndex(vector names); +}; + +#endif // MLLM_PARAMWRITER_HPP diff --git a/src/quantizer/main.cpp b/src/quantizer/main.cpp new file mode 100644 index 00000000..465eae06 --- /dev/null +++ b/src/quantizer/main.cpp @@ -0,0 +1,32 @@ +// +// Created by lx on 23-10-31. +// +#include "ParamWriter.hpp" +#include "ParamLoader.hpp" +#include +namespace mllm { +class QuantWriter : public ParamWriter { + explicit QuantWriter(std::string output_path, std::string input_path); + int ReadParams(); + +private: + mllm::ParamLoader *param_loader_; + std::vector param_names_; + float *GetParam(std::string param_name); +}; +QuantWriter::QuantWriter(std::string output_path, std::string input_path) : + ParamWriter(std::move(output_path)) { + param_loader_ = new mllm::ParamLoader(std::move(input_path)); + if (param_loader_ == nullptr) { + exit(-1); + } +} +int QuantWriter::ReadParams() { + param_names_ = param_loader_->getParamNames(); + paddingIndex(param_names_); + return param_names_.size(); +} +float *QuantWriter::GetParam(std::string param_name) { +} + +} // namespace mllm From c67048942f475042588741c96bed1a94eba551c8 Mon Sep 17 00:00:00 2001 From: Lee <44310445+lx200916@users.noreply.github.com> Date: Wed, 1 Nov 2023 00:44:43 +0800 Subject: [PATCH 2/3] fix: Add Quant Types. Signed-off-by: Lee <44310445+lx200916@users.noreply.github.com> --- include/Types.hpp | 3 +++ src/ParamLoader.cpp | 6 ++++++ src/ParamLoader.hpp | 1 + src/quantizer/main.cpp | 41 +++++++++++++++++++++++++++++++++++++++-- 4 files changed, 49 insertions(+), 2 deletions(-) diff --git a/include/Types.hpp b/include/Types.hpp index 9f1383a4..26fe307b 100644 --- a/include/Types.hpp +++ b/include/Types.hpp @@ -150,6 +150,9 @@ struct BackendConfig { enum DataType { FP32 = 0, FP16, + INT8, + INT4, + DATA_TYPE_COUNT, }; } // namespace mllm diff --git a/src/ParamLoader.cpp b/src/ParamLoader.cpp index 2c0da5b0..48fd1fc6 100644 --- a/src/ParamLoader.cpp +++ b/src/ParamLoader.cpp @@ -104,4 +104,10 @@ vector ParamLoader::getParamNames() { } return keys; } +uint8_t *ParamLoader::load(string name) { + std::pair offset = offsets_[name]; + uint8_t *data = new uint8_t[offset.second]; + fseek(fp_, offset.first, SEEK_SET); + fread(data, sizeof(uint8_t), offset.second, fp_); +} } // namespace mllm \ No newline at end of file diff --git a/src/ParamLoader.hpp b/src/ParamLoader.hpp index fa1cad1d..a832f6b6 100644 --- a/src/ParamLoader.hpp +++ b/src/ParamLoader.hpp @@ -63,6 +63,7 @@ class ParamLoader { std::uint64_t size_; std::map> offsets_; // offsets,length std::map data_type_; + uint8_t *load(string name); bool use_mmap_; }; diff --git a/src/quantizer/main.cpp b/src/quantizer/main.cpp index 465eae06..1e807719 100644 --- a/src/quantizer/main.cpp +++ b/src/quantizer/main.cpp @@ -1,6 +1,12 @@ // // Created by lx on 23-10-31. // +#define NOT_IMPLEMENTED(x) \ + std::cout << "Quantize params to " << #x << " is not implemented\n"; \ + exit(-1); +#define UNREACHABLE() \ + std::cout << "Unreachable code\n"; \ + exit(-1); #include "ParamWriter.hpp" #include "ParamLoader.hpp" #include @@ -8,11 +14,13 @@ namespace mllm { class QuantWriter : public ParamWriter { explicit QuantWriter(std::string output_path, std::string input_path); int ReadParams(); + void QuantParams(DataType dataType); private: mllm::ParamLoader *param_loader_; + DataType quant_type_; std::vector param_names_; - float *GetParam(std::string param_name); + float *getParam(std::string param_name); }; QuantWriter::QuantWriter(std::string output_path, std::string input_path) : ParamWriter(std::move(output_path)) { @@ -26,7 +34,36 @@ int QuantWriter::ReadParams() { paddingIndex(param_names_); return param_names_.size(); } -float *QuantWriter::GetParam(std::string param_name) { +float *QuantWriter::getParam(std::string param_name) { + auto type = param_loader_->data_type_[param_name]; + if (type != mllm::DataType::FP32) { + return nullptr; + } + void *data = param_loader_->load(param_name); + return static_cast(data); +} +void QuantWriter::QuantParams(DataType dataType) { + quant_type_ = dataType; + for (const auto &name : param_names_) { + auto *param = getParam(name); + if (param == nullptr) { + exit(-1); + } + switch (dataType) { + case FP32: + std::cout << "No need to quantize FP32 params\n"; + break; + case FP16: + NOT_IMPLEMENTED(FP16); + break; + + case INT8: break; + case INT4: + + break; + case DATA_TYPE_COUNT: UNREACHABLE(); break; + } + } } } // namespace mllm From 6a58c0ca594dec9669b79c23501e54bc0fe68c8a Mon Sep 17 00:00:00 2001 From: Lee <44310445+lx200916@users.noreply.github.com> Date: Wed, 1 Nov 2023 17:14:22 +0800 Subject: [PATCH 3/3] chaos: Rename DataType. feat: Add Quantizer. Signed-off-by: Lee <44310445+lx200916@users.noreply.github.com> --- CMakeLists.txt | 2 +- src/Executor.hpp | 4 +- src/Graph.hpp | 5 +- src/Op.hpp | 10 +-- src/Tensor.hpp | 30 ++------ src/backends/cpu/CPUAttention.cpp | 2 +- src/backends/cpu/CPUAttention.hpp | 4 +- src/quantizer/ParamWriter.cpp | 9 ++- src/quantizer/ParamWriter.hpp | 7 +- src/quantizer/main.cpp | 122 +++++++++++++++++++++++++----- 10 files changed, 132 insertions(+), 63 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 81b45d81..34a69ca4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,7 +124,7 @@ if (QUANT) message(STATUS "MLLM_Quant: ${MLLM_QUANT}") add_executable( - MLLM_QUANT + quantize ${PROJECT_SOURCE_DIR}/src/quantizer/main.cpp ${MLLM_QUANT} ${DIR_SRC} diff --git a/src/Executor.hpp b/src/Executor.hpp index 201c2c40..4a0cc43c 100644 --- a/src/Executor.hpp +++ b/src/Executor.hpp @@ -76,8 +76,8 @@ class Executor { vector> result_; ParamLoader *data_loader_; - mllm_dtype weights_dtype_; - mllm_dtype activation_dtype_; + DataType weights_dtype_; + DataType activation_dtype_; }; } // namespace mllm diff --git a/src/Graph.hpp b/src/Graph.hpp index 3b61c4f4..698209a8 100644 --- a/src/Graph.hpp +++ b/src/Graph.hpp @@ -84,9 +84,8 @@ class Graph { unordered_map> ops_; // opname: op // unordered_map> external_tensors_; - - mllm_dtype weights_dtype_ = MLLM_TYPE_F32; - mllm_dtype activation_dtype_ = MLLM_TYPE_F32; + DataType weights_dtype_ = MLLM_TYPE_F32; + DataType activation_dtype_ = MLLM_TYPE_F32; }; } // namespace mllm diff --git a/src/Op.hpp b/src/Op.hpp index 70436607..f023fa84 100644 --- a/src/Op.hpp +++ b/src/Op.hpp @@ -106,15 +106,15 @@ class Op { return NO_ERROR; } - virtual ErrorCode setDtype(mllm_dtype weight_dtype, mllm_dtype activation_dtype) { + virtual ErrorCode setDtype(DataType weight_dtype, DataType activation_dtype) { weights_dtype_ = weight_dtype; activation_dtype_ = activation_dtype; return NO_ERROR; } - mllm_dtype weightsDtype() const { + DataType weightsDtype() const { return weights_dtype_; } - mllm_dtype activationDtype() const { + DataType activationDtype() const { return activation_dtype_; } /** @@ -143,8 +143,8 @@ class Op { // BackendType backend_type_; // tensor w // vector<> - mllm_dtype weights_dtype_ = MLLM_TYPE_F32; - mllm_dtype activation_dtype_ = MLLM_TYPE_F32; + DataType weights_dtype_ = MLLM_TYPE_F32; + DataType activation_dtype_ = MLLM_TYPE_F32; }; // unordered_map(Backend*)>> opMap; diff --git a/src/Tensor.hpp b/src/Tensor.hpp index 6cde10ee..25404882 100644 --- a/src/Tensor.hpp +++ b/src/Tensor.hpp @@ -41,7 +41,7 @@ class Tensor { void setBackend(Backend *bn) { backend_ = bn; }; - void setDtype(mllm_dtype dtype) { + void setDtype(DataType dtype) { dtype_ = dtype; } @@ -50,7 +50,7 @@ class Tensor { bool reshape(const vector &shape); void alloc(); - void alloc(mllm_dtype dtype) { + void alloc(DataType dtype) { dtype_ = dtype; alloc(); } @@ -281,32 +281,12 @@ class Tensor { } } - mllm_dtype dtype() const { + DataType dtype() const { return dtype_; } float dtypeSize() { - switch (dtype_) { - case MLLM_TYPE_F32: - return sizeof(float); - case MLLM_TYPE_F16: - return sizeof(short); - case MLLM_TYPE_I32: - return sizeof(int); - case MLLM_TYPE_I16: - return sizeof(short); - case MLLM_TYPE_I8: - return sizeof(char); - // TODO WRONG? - case MLLM_TYPE_Q4_0: - return (sizeof(block_q4_0)) / (QK4_0 / 2); - case MLLM_TYPE_Q4_K: - return (sizeof(block_q4_K)) / (QK_K / 2); - case MLLM_TYPE_Q8_0: - return (sizeof(block_q8_0)) / (QK8_0); - case MLLM_TYPE_Q8_K: - return (sizeof(block_q8_K)) / (QK_K); - } + return DataTypeSize(dtype_); } // // void setByteWidth(int bw) { @@ -362,7 +342,7 @@ class Tensor { string name_; // shared_ptr backend_; // int byte_width_; // 32/16/8/4 //enum - mllm_dtype dtype_; + DataType dtype_; Backend *backend_; void *host_ptr_; void *device_ptr_; diff --git a/src/backends/cpu/CPUAttention.cpp b/src/backends/cpu/CPUAttention.cpp index 43e54100..4ef53669 100644 --- a/src/backends/cpu/CPUAttention.cpp +++ b/src/backends/cpu/CPUAttention.cpp @@ -313,7 +313,7 @@ ErrorCode CPUAttention::free(vector> inputs, vectorfree({kqv_state_}, outputs); return Op::free(inputs, outputs); } -ErrorCode CPUAttention::setDtype(mllm_dtype weight_dtype, mllm_dtype activation_dtype) { +ErrorCode CPUAttention::setDtype(DataType weight_dtype, DataType activation_dtype) { Q_proj_->setDtype(weight_dtype, activation_dtype); K_proj_->setDtype(weight_dtype, activation_dtype); V_proj_->setDtype(weight_dtype, activation_dtype); diff --git a/src/backends/cpu/CPUAttention.hpp b/src/backends/cpu/CPUAttention.hpp index 13810b2a..06cfde49 100644 --- a/src/backends/cpu/CPUAttention.hpp +++ b/src/backends/cpu/CPUAttention.hpp @@ -25,8 +25,8 @@ class CPUAttention final : public Op { virtual ErrorCode execute(vector> inputs, vector> outputs) override; virtual ErrorCode reshapeOutputs(vector> inputs, vector> outputs) override; virtual ErrorCode free(vector> inputs, vector> outputs) override; - virtual ErrorCode setDtype(mllm_dtype weight_dtype, mllm_dtype activation_dtype) override; - + virtual ErrorCode setDtype(DataType weight_dtype, DataType activation_dtype) override; + virtual ErrorCode load(ParamLoader &loader) override; private: diff --git a/src/quantizer/ParamWriter.cpp b/src/quantizer/ParamWriter.cpp index 10532c9e..ecedc2bd 100644 --- a/src/quantizer/ParamWriter.cpp +++ b/src/quantizer/ParamWriter.cpp @@ -9,7 +9,10 @@ ParamWriter::ParamWriter(std::string filename) : fp_ = fopen(path_.c_str(), "wb"); writeInt(fp_, _MAGIC_NUMBER); } - +ParamWriter::~ParamWriter() { + if (fp_) + fclose(fp_); +} int ParamWriter::calcIndexSize(const vector names) { int size = 0; for (const auto &name : names) { @@ -28,12 +31,12 @@ void ParamWriter::writeIndex() { } } -void ParamWriter::writeParam(string name, mllm_dtype type, void *data, uint64_t size) { +void ParamWriter::writeParam(string name, DataType type, void *data, uint64_t size) { auto param = param_info_[index_]; param.name = std::move(name); param.type = type; param.offset = ftell(fp_); - fwrite(data, 1, size, fp_); + fwrite(data, sizeof(char), size, fp_); param.size = ftell(fp_) - param.offset; index_++; } diff --git a/src/quantizer/ParamWriter.hpp b/src/quantizer/ParamWriter.hpp index db1a79eb..e0956e31 100644 --- a/src/quantizer/ParamWriter.hpp +++ b/src/quantizer/ParamWriter.hpp @@ -15,22 +15,23 @@ static void writeString(FILE *fp, const std::string &str) { writeInt(fp, str.size()); fwrite(str.c_str(), str.size(), 1, fp); } -static void write_dtype(FILE *fp, mllm_dtype dtype) { +static void write_dtype(FILE *fp, DataType dtype) { writeInt(fp, dtype); } struct ParmInfo { std::string name; - mllm_dtype type; + DataType type; uint64_t offset; uint64_t size; }; class ParamWriter { public: + ~ParamWriter(); ParamWriter(std::string filename); int calcIndexSize(vector names); void writeIndex(); - void writeParam(string name, mllm_dtype type, void *data, uint64_t size); + void writeParam(string name, DataType type, void *data, uint64_t size); private: uint64_t index_ = 0; diff --git a/src/quantizer/main.cpp b/src/quantizer/main.cpp index 1e807719..05658dce 100644 --- a/src/quantizer/main.cpp +++ b/src/quantizer/main.cpp @@ -1,32 +1,53 @@ // // Created by lx on 23-10-31. // -#define NOT_IMPLEMENTED(x) \ - std::cout << "Quantize params to " << #x << " is not implemented\n"; \ - exit(-1); -#define UNREACHABLE() \ - std::cout << "Unreachable code\n"; \ - exit(-1); #include "ParamWriter.hpp" #include "ParamLoader.hpp" +#include "backends/cpu/quantize/QuantizeQ4.hpp" +#include "backends/cpu/quantize/QuantizeQ8.hpp" #include + +#define NOT_IMPLEMENTED(x) \ + std::cout << "Quantize params to " << DataTypeName(x) << " is not implemented\n"; \ + __exit(-1); +#define UNREACHABLE() \ + std::cout << "Unreachable code\n"; \ + __exit(-1); +#define __exit(status) \ + { \ + if (status != 0) { \ + std::cout << "Quantize failed\n"; \ + remove(output_path_.c_str()); \ + } \ + exit(status); \ + } +static std::pair alloc_quant_block(uint64_t count, DataType type) { + uint64_t size = DataTypeSize(type) * count; + if (size <= 0) { + return std::make_pair(nullptr, 0); + } + void *data = new char[size]; + return std::make_pair(data, size); +} namespace mllm { class QuantWriter : public ParamWriter { +public: explicit QuantWriter(std::string output_path, std::string input_path); int ReadParams(); void QuantParams(DataType dataType); private: + string output_path_; mllm::ParamLoader *param_loader_; DataType quant_type_; std::vector param_names_; float *getParam(std::string param_name); }; QuantWriter::QuantWriter(std::string output_path, std::string input_path) : - ParamWriter(std::move(output_path)) { + ParamWriter(output_path), output_path_(output_path) { param_loader_ = new mllm::ParamLoader(std::move(input_path)); if (param_loader_ == nullptr) { - exit(-1); + __exit(-1); } } int QuantWriter::ReadParams() { @@ -36,7 +57,7 @@ int QuantWriter::ReadParams() { } float *QuantWriter::getParam(std::string param_name) { auto type = param_loader_->data_type_[param_name]; - if (type != mllm::DataType::FP32) { + if (type != DataType::MLLM_TYPE_F32) { return nullptr; } void *data = param_loader_->load(param_name); @@ -45,25 +66,90 @@ float *QuantWriter::getParam(std::string param_name) { void QuantWriter::QuantParams(DataType dataType) { quant_type_ = dataType; for (const auto &name : param_names_) { + // int force_quant_type = -1; auto *param = getParam(name); if (param == nullptr) { - exit(-1); + __exit(-1); } + auto size = param_loader_->offsets_[name].second / sizeof(float); + void *quant_ptr = nullptr; + std::pair block_t; switch (dataType) { - case FP32: + case MLLM_TYPE_F32: std::cout << "No need to quantize FP32 params\n"; + __exit(-1); break; - case FP16: - NOT_IMPLEMENTED(FP16); + case MLLM_TYPE_Q4_0: + block_t = alloc_quant_block(size, dataType); + quant_ptr = block_t.first; + quantize_row_q4_0(param, quant_ptr, size); + size = block_t.second; break; - - case INT8: break; - case INT4: - + case MLLM_TYPE_Q8_0: + block_t = alloc_quant_block(size, dataType); + quant_ptr = block_t.first; + quantize_row_q8_0(param, quant_ptr, size); + size = block_t.second; + break; + case MLLM_TYPE_Q4_K: + block_t = alloc_quant_block(size, dataType); + quant_ptr = block_t.first; + quantize_row_q4_K(param, quant_ptr, size); + size = block_t.second; + break; + case MLLM_TYPE_Q8_K: + block_t = alloc_quant_block(size, dataType); + quant_ptr = block_t.first; + quantize_row_q8_K(param, quant_ptr, size); + size = block_t.second; + break; + case MLLM_TYPE_I8: + case MLLM_TYPE_Q4_1: + case MLLM_TYPE_Q8_1: + case MLLM_TYPE_I16: + case MLLM_TYPE_I32: + case MLLM_TYPE_F16: + NOT_IMPLEMENTED(dataType); + break; + case MLLM_TYPE_COUNT: + UNREACHABLE() break; - case DATA_TYPE_COUNT: UNREACHABLE(); break; + } + if (quant_ptr != nullptr) { + writeParam(name, quant_type_, quant_ptr, size); + delete[] (char *)quant_ptr; } } + writeIndex(); } } // namespace mllm +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << "Usage: ./quantize \n"; + return -1; + } + auto input_path = std::string(argv[1]); + auto output_path = std::string(argv[2]); + auto quant_type = std::string(argv[3]); + mllm::QuantWriter quant_writer(output_path, input_path); + int param_count = quant_writer.ReadParams(); + if (param_count <= 0) { + std::cout << "No params to quantize\n"; + return -1; + } + std::cout << "Quantize " << param_count << " params to " << quant_type << "\n"; + if (quant_type == "Q4_0") { + quant_writer.QuantParams(MLLM_TYPE_Q4_0); + } else if (quant_type == "Q8_0") { + quant_writer.QuantParams(MLLM_TYPE_Q8_0); + } else if (quant_type == "Q4_K") { + quant_writer.QuantParams(MLLM_TYPE_Q4_K); + } else if (quant_type == "Q8_K") { + quant_writer.QuantParams(MLLM_TYPE_Q8_K); + } else { + std::cout << "Quant type " << quant_type << " is not supported\n"; + return -1; + } + return 0; +} \ No newline at end of file