feat: Stablelm 2 1.6b support #94

Merged · 4 commits · Jul 17, 2024
11 changes: 10 additions & 1 deletion CMakeLists.txt
@@ -318,7 +318,16 @@ else ()
    target_link_libraries(main_imagebind MLLM_CPU)
endif ()


add_executable(demo_stablelm ${PROJECT_SOURCE_DIR}/examples/demo_stablelm.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
        src/tokenizers/Tokenizer.cpp
        src/tokenizers/BPE/Bpe.cpp
)
if (ARM AND NOT APK)
    target_compile_options(demo_stablelm PRIVATE -fopenmp)
    target_link_libraries(demo_stablelm PUBLIC MLLM_CPU -fopenmp -static-openmp)
else ()
    target_link_libraries(demo_stablelm MLLM_CPU)
endif ()

add_executable(demo_llama ${PROJECT_SOURCE_DIR}/examples/demo_llama.cpp ${DIR_SRC_CPU} ${DIR_SRC_MEM_MANAGER} ${DIR_SRC_EXP} ${DIR_SRC}
        src/tokenizers/Tokenizer.cpp
8 changes: 5 additions & 3 deletions README.md
@@ -13,6 +13,7 @@ Wait.. why on-device multimodal LLM? - It's a key building block for [intelligen

## Recent update
- [🔥🔥Coming soon] Supporting Qualcomm NPU: [>1000 tokens/second prefilling!](https://arxiv.org/pdf/2407.05858v1)
- [2024 July 17] Support new model: StableLM V2 1.6B https://github.com/UbiquitousLearning/mllm/pull/94
- [2024 July 2] Support new model: Yi V1.5 6B https://github.com/UbiquitousLearning/mllm/pull/88
- [2024 May 29] Support new model: Mistral V0.2 7B https://github.com/UbiquitousLearning/mllm/pull/83
- [2024 May 4] Support new model: QWen V1.5 0.5B https://github.com/UbiquitousLearning/mllm/pull/79
@@ -74,9 +75,10 @@ Wait.. why on-device multimodal LLM? - It's a key building block for [intelligen
| [ImageBind](https://github.com/facebookresearch/ImageBind) (3 modalities) | [✔️](https://huggingface.co/mllmTeam/imagebind_huge-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/imagebind_huge-mllm/tree/main) |
| [LLaVA 7B](https://github.com/haotian-liu/LLaVA) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/llava-1.5-7b-mllm/tree/main) |
| [Gemma 2B](https://github.com/google/gemma_pytorch) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/gemma-2b-mllm/tree/main) |
| [Qwen 0.5B](https://github.com/QwenLM/Qwen) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/qwen-1.5-0.5b-mllm/tree/main) |
| [Mistral 7B](https://github.com/mistralai/mistral-src) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/mistral-7b-instruct-v0.2-mllm/tree/main) |
| [Yi 6B](https://huggingface.co/01-ai/Yi-1.5-6B) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/yi-1.5-6b-chat-mllm/tree/main) |
| [StableLM 1.6B](https://github.com/Stability-AI/StableLM) | [✔️](https://huggingface.co/mllmTeam/stablelm-2-1.6b-chat-mllm/tree/main) | [✔️](https://huggingface.co/mllmTeam/stablelm-2-1.6b-chat-mllm/tree/main) |

## Quick Start

68 changes: 68 additions & 0 deletions examples/demo_stablelm.cpp
@@ -0,0 +1,68 @@
#include <iostream>
#include "cmdline.h"
#include "models/stablelm/modeling_stablelm.hpp"
#include "models/stablelm/tokenization_stablelm.hpp"
#include "processor/PostProcess.hpp"

using namespace mllm;

int main(int argc, char **argv) {
    cmdline::parser cmdParser;
    cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/stablelm_vocab.mllm");
    cmdParser.add<string>("merge", 'm', "specify mllm merge path", false, "../vocab/stablelm_merges.txt");
    cmdParser.add<string>("model", 'o', "specify mllm model path", false, "../models/stablelm-2-1.6b-chat-q4_k.mllm");
    cmdParser.add<int>("limits", 'l', "max KV cache size", false, 400);
    cmdParser.add<int>("thread", 't', "num of threads", false, 4);
    cmdParser.parse_check(argc, argv);

    string vocab_path = cmdParser.get<string>("vocab");
    string merge_path = cmdParser.get<string>("merge");
    string model_path = cmdParser.get<string>("model");
    int tokens_limit = cmdParser.get<int>("limits");
    CPUBackend::cpu_threads = cmdParser.get<int>("thread");

    auto tokenizer = StableLMTokenizer(vocab_path, merge_path);

    // ChatML-style prompt template used by the StableLM 2 chat model.
    string system_prompt_start = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n";
    string system_prompt_end = "<|im_end|>\n<|im_start|>assistant\n";

    StableLMConfig config(tokens_limit, "1.6B", HFHUBROPE);
    auto model = StableLMModel(config);
    model.load(model_path);

    vector<string> in_strs = {
        " Hello, who are you?",
        " What can you do?",
        "Please introduce Beijing University of Posts and Telecommunications."};

    for (int i = 0; i < in_strs.size(); ++i) {
        const auto &in_str_origin = in_strs[i];
        auto in_str = system_prompt_start + in_str_origin + system_prompt_end;
        std::cout << "[Q] " << in_str_origin << std::endl;
        auto input_tensor = tokenizer.tokenize(in_str, i);
        std::cout << "[A] " << std::flush;
        for (int step = 0; step < 100; step++) {
            auto result = model({input_tensor});
            auto outputs = tokenizer.detokenize(result[0]);
            auto out_string = outputs.first;
            auto out_token = outputs.second;
            if (out_token == 100278) { // stop on the chat end-of-turn special token
                break;
            }
            // Strip byte-level BPE markers ("Ċ" and "Ġ" are two-byte UTF-8 sequences)
            // and print a space in their place.
            size_t pos = 0;
            while ((pos = out_string.find("Ċ", pos)) != std::string::npos) {
                out_string.replace(pos, 2, " ");
            }
            pos = 0;
            while ((pos = out_string.find("Ġ", pos)) != std::string::npos) {
                out_string.replace(pos, 2, " ");
            }

            std::cout << out_string << std::flush;
            // Append the generated token so the next step decodes autoregressively.
            chatPostProcessing(out_token, input_tensor, {});
        }
        printf("\n");
    }

    return 0;
}
7 changes: 7 additions & 0 deletions src/Layer.hpp
@@ -645,6 +645,13 @@ class RoPE final : public Layer {
param_["max_position_embeddings"] = max_position_embeddings;
init(std::move(name), OpType::ROPE);
}
explicit RoPE(int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, std::string name) {
param_["pose_type"] = pose_type;
param_["rope_theta"] = rope_theta;
param_["max_position_embeddings"] = max_position_embeddings;
param_["partial_rotary_factor"] = partial_rotary_factor;
init(std::move(name), OpType::ROPE);
}
Tensor &operator()(Tensor &input) {
return _1I1O_OP(input);
}
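For context, this overload is what a model definition calls to request a partially rotated RoPE. A minimal sketch, assuming StableLM 2 1.6B's published values (rope theta 10000, partial rotary factor 0.25, 4096 max positions) and a hypothetical helper name; only the five-argument constructor itself comes from this diff:

```cpp
#include <string>
#include "Layer.hpp"

using namespace mllm;

// Hypothetical helper, for illustration only: the five-argument RoPE constructor is
// the piece added by this PR; the concrete values mirror StableLM 2 1.6B's config.
void build_partial_rope_example(const std::string &base_name) {
    // pose_type, rope_theta, partial_rotary_factor, max_position_embeddings, name
    RoPE q_rope(HFHUBROPE, 10000.0f, 0.25f, 4096, base_name + "q_rope");
    (void)q_rope; // a real model keeps this as a member and applies it to the Q/K tensors
}
```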
43 changes: 34 additions & 9 deletions src/backends/cpu/CPURoPE.cpp
@@ -76,12 +76,21 @@ CPURoPE::CPURoPE(Backend *bn, string opName, int pose_type, float rope_theta, in
pos_max_ = max_position_embeddings;
}

CPURoPE::CPURoPE(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, int threadCount) :
thread_count(threadCount),
Op(bn, opName) {
pose_type_ = pose_type;
rope_theta_ = rope_theta;
partial_rotary_factor_ = partial_rotary_factor;
pos_max_ = max_position_embeddings;
}

ErrorCode CPURoPE::reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) {
// std::cout << name() << " CPURoPE reshape" << std::endl;
assert(inputs.size() == 1);
assert(outputs.size() == 1);
outputs[0]->reshape(inputs[0]->batch(), inputs[0]->head(), inputs[0]->sequence(), inputs[0]->dimension());
- ishape = inputs[0]->dimension();
+ ishape = inputs[0]->dimension() * partial_rotary_factor_;
// pos_max_ = 16384;
if (sin_.empty() || ishape_old < ishape || global_pose_type_ != pose_type_) {
global_pose_type_ = pose_type_;
@@ -102,11 +111,12 @@ ErrorCode CPURoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<
auto &input = inputs[0];
auto &output = outputs[0];
auto out_dtype = output->dtype();
int partial_dimension = (input->dimension()) * partial_rotary_factor_;
for (int n = 0; n < input->batch(); ++n) {
for (int h = 0; h < input->head(); ++h) {
for (int s = 0; s < input->sequence(); ++s) { // sequance
#pragma omp parallel for num_threads(thread_count)
- for (int d = 0; d < input->dimension(); ++d) {
+ for (int d = 0; d < partial_dimension; ++d) {
if (pose_type_ == LLAMAROPE) {
float in_value = input->dataAt<float>(n, h, s, d);
float in_value_2;
@@ -128,16 +138,16 @@ ErrorCode CPURoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<
float in_value_2;
float sin_value = sin_[s + h_cnt_][d];
float cos_value = cos_[s + h_cnt_][d];
- if (d < input->dimension() / 4) {
- in_value_2 = -input->dataAt<float>(n, h, s, d + input->dimension() / 4);
+ if (d < partial_dimension / 4) {
+ in_value_2 = -input->dataAt<float>(n, h, s, d + partial_dimension / 4);
auto value = in_value * cos_value + in_value_2 * sin_value;
if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) {
output->setDataAt<float>(n, h, s, d, value);
} else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) {
output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
}
- } else if (d < input->dimension() / 2) {
- in_value_2 = input->dataAt<float>(n, h, s, d - input->dimension() / 4);
+ } else if (d < (partial_dimension / 2)) {
+ in_value_2 = input->dataAt<float>(n, h, s, d - partial_dimension / 4);
auto value = in_value * cos_value + in_value_2 * sin_value;
if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) {
output->setDataAt<float>(n, h, s, d, value);
@@ -154,10 +164,10 @@ ErrorCode CPURoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<
} else if (pose_type_ == HFHUBROPE) {
float in_value = input->dataAt<float>(n, h, s, d);
float in_value_2;
- if (d < input->dimension() / 2) {
- in_value_2 = -input->dataAt<float>(n, h, s, d + input->dimension() / 2);
+ if (d < (partial_dimension / 2)) {
+ in_value_2 = -input->dataAt<float>(n, h, s, d + partial_dimension / 2);
} else {
- in_value_2 = input->dataAt<float>(n, h, s, d - input->dimension() / 2);
+ in_value_2 = input->dataAt<float>(n, h, s, d - partial_dimension / 2);
}
float sin_value = sin_[s + h_cnt_][d];
float cos_value = cos_[s + h_cnt_][d];
@@ -201,6 +211,21 @@ ErrorCode CPURoPE::execute(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<
if (h_cnt_ > pos_max_) {
h_cnt_ = 0;
}

for (int n = 0; n < input->batch(); ++n) {
for (int h = 0; h < input->head(); ++h) {
for (int s = 0; s < input->sequence(); ++s) {
#pragma omp parallel for num_threads(thread_count)
for (int d = partial_dimension; d < input->dimension(); ++d) {
if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) {
output->setDataAt<float>(n, h, s, d, input->dataAt<float>(n, h, s, d));
} else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) {
output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(input->dataAt<float>(n, h, s, d)));
}
}
}
}
}
return Op::execute(inputs, outputs);
}

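To make the effect of partial_rotary_factor concrete: only the first dimension * partial_rotary_factor channels of each head are rotated, and the new trailing loop copies the remaining channels through unchanged. Below is a minimal single-vector sketch of that behaviour for the HFHUBROPE layout, assuming the usual Hugging Face frequency convention; the real kernel uses precomputed sin_/cos_ tables, FP16 paths, and batch/head/sequence loops:

```cpp
#include <cmath>
#include <vector>

// Sketch only: plain-float, single-position version of partial rotary embedding.
// Channels [0, partial_dim) are rotated; channels [partial_dim, dim) pass through,
// mirroring the new copy loop in CPURoPE::execute.
std::vector<float> partial_rope(const std::vector<float> &x, int pos,
                                float partial_rotary_factor, float theta = 10000.0f) {
    const int dim = static_cast<int>(x.size());
    const int partial_dim = static_cast<int>(dim * partial_rotary_factor);
    const int half = partial_dim / 2;
    std::vector<float> out(x); // unrotated channels are already copied here
    if (half == 0) return out;
    for (int d = 0; d < partial_dim; ++d) {
        // One frequency per rotated pair, shared by channels d and d + half.
        const float freq = std::pow(theta, -2.0f * static_cast<float>(d % half) / partial_dim);
        const float s = std::sin(pos * freq);
        const float c = std::cos(pos * freq);
        const float paired = (d < half) ? -x[d + half] : x[d - half];
        out[d] = x[d] * c + paired * s;
    }
    return out;
}
```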
8 changes: 7 additions & 1 deletion src/backends/cpu/CPURoPE.hpp
@@ -10,6 +10,7 @@ class CPURoPE final : public Op {
public:
CPURoPE(Backend *bn, string opName, int pose_type, int threadCount);
CPURoPE(Backend *bn, string opName, int pose_type, float rope_theta, int max_position_embeddings, int threadCount);
CPURoPE(Backend *bn, string opName, int pose_type, float rope_theta, float partial_rotary_factor, int max_position_embeddings, int threadCount);
virtual ~CPURoPE() = default;
virtual ErrorCode reshape(vector<shared_ptr<Tensor>> inputs, vector<shared_ptr<Tensor>> outputs) override;
virtual ErrorCode load(AbstructLoader &loader) override;
@@ -30,6 +31,7 @@ class CPURoPE final : public Op {
int pose_type_ = 4;
int ishape;
int thread_count = 4;
float partial_rotary_factor_ = 1;
};

class CPURoPECreator : public CPUBackend::Creator {
@@ -41,7 +43,11 @@ class CPURoPECreator : public CPUBackend::Creator {
}
float rope_theta = op_param["rope_theta"];
int max_position_embeddings = op_param["max_position_embeddings"];
- return new CPURoPE(bn, name, pose_type, rope_theta, max_position_embeddings, threadCount);
+ if (op_param.find("partial_rotary_factor") == op_param.end()) {
+ return new CPURoPE(bn, name, pose_type, rope_theta, max_position_embeddings, threadCount);
+ }
+ float partial_rotary_factor = op_param["partial_rotary_factor"];
+ return new CPURoPE(bn, name, pose_type, rope_theta, partial_rotary_factor, max_position_embeddings, threadCount);
}
};
} // namespace mllm
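The creator keeps older exported models working: if op_param has no "partial_rotary_factor" entry, the existing constructor is used and the member default of 1 (rotate the full head dimension) applies. A tiny illustration of that lookup-with-default pattern, using a plain std::map as a stand-in for the real OpParam type:

```cpp
#include <map>
#include <string>

// Stand-in for OpParam; the real type is whatever the creator indexes with op_param["..."].
using FakeOpParam = std::map<std::string, float>;

float partial_rotary_factor_or_default(const FakeOpParam &op_param) {
    auto it = op_param.find("partial_rotary_factor");
    return it == op_param.end() ? 1.0f : it->second; // 1.0 == rotate the full head dimension
}
```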
69 changes: 69 additions & 0 deletions src/models/stablelm/configuration_stablelm.hpp
@@ -0,0 +1,69 @@
#ifndef CONFIG_STABLELM_HPP
#define CONFIG_STABLELM_HPP
#include "models/transformer/configuration_transformer.hpp"

using namespace mllm;

class stablelmNameConfig : public TransformerNameConfig {
public:
    std::string blk_name;
    std::string token_embd_name;
    std::string post_norm_name;
    std::string lm_head_name;
    std::string _gate_proj_name;

    void init(RoPEType type = HFHUBROPE) {
        switch (type) {
        case HFHUBROPE: {
            blk_name = "model.layers.";
            _attn_base_name = "self_attn.";
            _ffn_base_name = "mlp.";
            _q_proj_name = "q_proj";
            _k_proj_name = "k_proj";
            _v_proj_name = "v_proj";
            _o_proj_name = "o_proj";
            _gate_proj_name = "gate_proj";
            _up_proj_name = "up_proj";
            _down_proj_name = "down_proj";
            _attn_norm_name = "input_layernorm";
            _ffn_norm_name = "post_attention_layernorm";
            token_embd_name = "model.embed_tokens";
            post_norm_name = "model.norm";
            lm_head_name = "lm_head";
            break;
        }
        default: {
            throw std::runtime_error("Unsupported stablelm type");
        }
        }
    }
};

class StableLMConfig {
public:
    int vocab_size{};
    int hidden_dim{};
    int head_size{};
    int ffn_hidden{};
    int block_num{};
    RoPEType RoPE_type;
    int cache_limit{};
    stablelmNameConfig names_config;

    explicit StableLMConfig(int token_limit, string billions = "1.6B", RoPEType type = HFHUBROPE, int vocab = 100352) {
        names_config.init(type);
        vocab_size = vocab;
        if (billions == "1.6B" || billions == "1.6b") {
            hidden_dim = 2048;
            head_size = 32;
            ffn_hidden = 5632;
            block_num = 24;
        } else {
            throw std::runtime_error("Unsupported model size");
        }
        RoPE_type = type;
        cache_limit = token_limit;
    }
};

#endif // CONFIG_STABLELM_HPP
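A quick arithmetic check tying this config to the RoPE change above: hidden_dim 2048 across head_size = 32 heads gives 64 channels per head, so with StableLM 2's published partial rotary factor of 0.25 (set in modeling_stablelm.hpp, which is not part of this excerpt, and assumed here) only the first 16 channels of each head are rotated:

```cpp
#include <iostream>

int main() {
    const int hidden_dim = 2048;                 // StableLMConfig "1.6B"
    const int num_heads = 32;                    // head_size in StableLMConfig
    const float partial_rotary_factor = 0.25f;   // assumed from StableLM 2's HF config
    const int head_dim = hidden_dim / num_heads;                                // 64
    const int rotary_dims = static_cast<int>(head_dim * partial_rotary_factor); // 16
    std::cout << "head_dim=" << head_dim << ", rotary_dims=" << rotary_dims << '\n';
    return 0;
}
```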