-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
604 additions
and
296 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// | ||
// Created by xwk on 24-10-23. | ||
// | ||
#include "models/bert/configuration_bert.hpp" | ||
#include "models/bert/modeling_bert.hpp" | ||
#include "models/bert/tokenization_bert.hpp" | ||
#include "cmdline.h" | ||
|
||
/* | ||
* an intent to support gte-small BertModel to do text embedding | ||
* current implementation is just a very basic example with a simple WordPiece tokenizer and a simple BertModel | ||
* not support batch embedding | ||
* */ | ||
|
||
int main(int argc, char *argv[]) { | ||
cmdline::parser cmdParser; | ||
cmdParser.add<string>("model", 'm', "specify mllm model path", false, "../models/gte-small-fp32.mllm"); | ||
cmdParser.add<string>("vocab", 'v', "specify mllm tokenizer model path", false, "../vocab/gte_vocab.mllm"); | ||
cmdParser.add<int>("thread", 't', "num of threads", false, 4); | ||
cmdParser.parse_check(argc, argv); | ||
|
||
string model_path = cmdParser.get<string>("model"); | ||
string vocab_path = cmdParser.get<string>("vocab"); | ||
CPUBackend::cpu_threads = cmdParser.get<int>("thread"); | ||
|
||
BertTokenizer tokenizer(vocab_path, true); | ||
string text = "Help me set an alarm at 21:30"; | ||
auto [token_ids, type_ids, position_ids] = tokenizer.process(text); | ||
// token_ids.printData<float>(); | ||
|
||
auto config = BertConfig(); | ||
auto model = BertModel(config); | ||
model.load(model_path); | ||
|
||
auto res = model({token_ids, type_ids, position_ids})[0]; | ||
|
||
res.printData<float>(); | ||
|
||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,4 +65,4 @@ int main(int argc, char **argv) { | |
}); | ||
std::cout << "\n"; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#ifndef CONFIG_BERT_HPP | ||
#define CONFIG_BERT_HPP | ||
#include "Types.hpp" | ||
#include "models/transformer/configuration_transformer.hpp" | ||
#include <cctype> | ||
#include <iterator> | ||
|
||
using namespace mllm; | ||
|
||
class BertNameConfig : public TransformerNameConfig { | ||
public: | ||
void init() { | ||
embedding_base_name = "embeddings."; | ||
|
||
blk_name = "encoder.layer."; | ||
_attn_base_name = "attention."; | ||
_q_proj_name = "self.query"; | ||
_k_proj_name = "self.key"; | ||
_v_proj_name = "self.value"; | ||
_o_proj_name = "output.dense"; | ||
_up_proj_name = "intermediate.dense"; | ||
_down_proj_name = "output.dense"; | ||
_attn_norm_name = "output.LayerNorm"; | ||
_ffn_norm_name = "output.LayerNorm"; | ||
} | ||
std::string embedding_base_name; | ||
|
||
std::string blk_name; | ||
}; | ||
|
||
struct BertConfig : public TransformerConfig { | ||
explicit BertConfig() { | ||
hidden_act = "GELU"; | ||
pooling_type = "mean"; | ||
hidden_size = 384; | ||
intermediate_size = 1536; | ||
max_position_embeddings = 512; | ||
model_type = "bert"; | ||
num_attention_heads = 12; | ||
num_hidden_layers = 12; | ||
vocab_size = 30522; | ||
names_config.init(); | ||
}; | ||
|
||
int type_vocab_size = 2; | ||
float layer_norm_eps = 1e-12; | ||
|
||
std::string hidden_act = "GELU"; | ||
std::string pooling_type = "mean"; | ||
int hidden_size = 1024; | ||
int intermediate_size = 2816; | ||
int max_position_embeddings = 32768; | ||
std::string model_type = "bert"; | ||
int num_attention_heads = 12; | ||
int num_hidden_layers = 12; | ||
|
||
|
||
int vocab_size = 151936; | ||
|
||
BertNameConfig names_config; | ||
}; | ||
|
||
#endif //! CONFIG_BERT_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#ifndef MODELING_BERT_HPP | ||
#define MODELING_BERT_HPP | ||
|
||
#include "Backend.hpp" | ||
#include "Layer.hpp" | ||
#include "Module.hpp" | ||
#include "Tensor.hpp" | ||
#include "configuration_bert.hpp" | ||
#include "models/transformer/modeling_transformer.hpp" | ||
using namespace mllm; | ||
|
||
class BertEmbeddings : public Module { | ||
public: | ||
BertEmbeddings() = default; | ||
BertEmbeddings(int vocal_size, int hidden_size, int type_size, int max_position_embeddings, float eps, BertNameConfig &config) { | ||
word_embeddings = Embedding(vocal_size, hidden_size, config.embedding_base_name + "word_embeddings"); | ||
token_type_embeddings = Embedding(type_size, hidden_size, config.embedding_base_name + "token_type_embeddings"); | ||
position_embeddings = Embedding(max_position_embeddings, hidden_size, config.embedding_base_name + "position_embeddings"); | ||
layer_norm = LayerNorm(hidden_size, true, eps, config.embedding_base_name + "LayerNorm"); | ||
} | ||
|
||
std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override { | ||
auto inputs_embeds = word_embeddings(inputs[0]); | ||
auto type_embeds = token_type_embeddings(inputs[1]); | ||
auto position_embeds = position_embeddings(inputs[2]); | ||
auto embeddings = inputs_embeds + type_embeds + position_embeds; | ||
return {layer_norm(embeddings)}; | ||
} | ||
|
||
private: | ||
Layer word_embeddings; | ||
Layer token_type_embeddings; | ||
Layer position_embeddings; | ||
Layer layer_norm; | ||
}; | ||
|
||
class BertLayer : public Module { | ||
public: | ||
BertLayer() = default; | ||
BertLayer(const BertConfig &config, const string &base_name) { | ||
// base_name: encoder.layer.n. | ||
attention = MultiHeadAttention(config.hidden_size, config.num_attention_heads, config.num_attention_heads, | ||
config.hidden_size / config.num_attention_heads, SPLIT_NONE, false, false, RoPEType::NONE, -1, -1, 0, false, true, config.names_config, | ||
base_name + config.names_config._attn_base_name); | ||
|
||
feed_forward = FeedForward(config.hidden_size, config.intermediate_size, | ||
config.hidden_act, true, config.names_config, base_name); | ||
|
||
attn_norm = LayerNorm(config.hidden_size, true, config.layer_norm_eps, | ||
base_name + config.names_config._attn_base_name + config.names_config._attn_norm_name); | ||
|
||
ff_norm = LayerNorm(config.hidden_size, true, config.layer_norm_eps, | ||
base_name + config.names_config._ffn_norm_name); | ||
} | ||
|
||
std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override { | ||
auto hidden_states = inputs[0]; | ||
|
||
auto attn_out = attention({hidden_states, hidden_states, hidden_states})[0]; | ||
|
||
hidden_states = attn_norm({hidden_states + attn_out}); | ||
|
||
auto ff_out = feed_forward({hidden_states})[0]; | ||
|
||
hidden_states = ff_norm({hidden_states + ff_out}); | ||
|
||
return {hidden_states}; | ||
} | ||
|
||
private: | ||
MultiHeadAttention attention; | ||
FeedForward feed_forward; | ||
|
||
Layer attn_norm, ff_norm; | ||
}; | ||
|
||
class BertAvgPooler : public Module { | ||
public: | ||
BertAvgPooler() = default; | ||
std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override { | ||
auto x = inputs[0]; | ||
x = x.mean(SEQUENCE); | ||
return {x}; | ||
} | ||
}; | ||
|
||
class BertModel : public Module { | ||
public: | ||
BertModel(BertConfig &config) { | ||
embeddings = BertEmbeddings(config.vocab_size, config.hidden_size, config.type_vocab_size, config.max_position_embeddings, config.layer_norm_eps, config.names_config); | ||
layers = List<BertLayer>(config.num_hidden_layers, config, config.names_config.blk_name); | ||
|
||
if (config.pooling_type == "mean") { | ||
pooler = BertAvgPooler(); | ||
} else { | ||
// print not support pooling type and exit | ||
std::cout << "Not support pooling type: " << config.pooling_type << std::endl; | ||
exit(0); | ||
} | ||
} | ||
|
||
std::vector<Tensor> Forward(std::vector<Tensor> inputs, std::vector<std::any> args) override { | ||
auto x = embeddings(inputs, args)[0]; | ||
for (auto &layer : layers) { | ||
x = layer({x})[0]; | ||
} | ||
x = pooler({x})[0]; | ||
return {x}; | ||
} | ||
|
||
private: | ||
BertEmbeddings embeddings; | ||
std::vector<BertLayer> layers; | ||
BertAvgPooler pooler; | ||
}; | ||
|
||
#endif //! MODELING_BERT_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#ifndef TOKENIZATION_BERT_HPP | ||
#define TOKENIZATION_BERT_HPP | ||
|
||
#include "tokenizers/BPE/Bpe.hpp" | ||
#include "tokenizers/Tokenizer.hpp" | ||
#include "tokenizers/Unicode.hpp" | ||
#include "tokenizers/WordPiece/WordPiece.hpp" | ||
#include <algorithm> | ||
#include <unordered_map> | ||
|
||
// unicode | ||
#include <codecvt> | ||
|
||
using namespace mllm; | ||
|
||
|
||
class BertTokenizer final : public WordPieceTokenizer { | ||
public: | ||
explicit BertTokenizer(const std::string &vocab_file, bool add_special_tokens = true) : | ||
WordPieceTokenizer(vocab_file) { | ||
Module::initBackend(MLLM_CPU); | ||
_add_special_tokens = add_special_tokens; | ||
this->add_special_tokens({"[PAD]", "[CLS]", "[SEP]", "[MASK]"}); | ||
} | ||
std::tuple<Tensor, Tensor, Tensor> process(std::string text){ | ||
if (_add_special_tokens) { | ||
text = "[CLS] " + text + " [SEP]"; | ||
} | ||
auto tokens_id = vector<token_id_t>(); | ||
WordPieceTokenizer::tokenize(text, tokens_id, false); | ||
// printf("token: "); | ||
// for (auto &token_id : tokens_id) { | ||
// printf("%d ", token_id); | ||
// } | ||
printf("\n"); | ||
auto tokens_type = vector<token_id_t>(tokens_id.size(), 0); | ||
auto position_ids = vector<token_id_t>(tokens_id.size()); | ||
for (size_t i = 0; i < tokens_id.size(); i++) { | ||
position_ids[i] = i; | ||
} | ||
return { | ||
tokens2Input(tokens_id, "input_tokens"), | ||
tokens2Input(tokens_type, "input_tokens_type"), | ||
tokens2Input(position_ids, "input_position_ids") | ||
}; | ||
} | ||
|
||
private: | ||
bool _add_special_tokens; | ||
}; | ||
|
||
#endif //! TOKENIZATION_BERT_HPP |
Oops, something went wrong.