-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbert.h
124 lines (102 loc) · 3.19 KB
/
bert.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#pragma once
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <cmath>
#include <string>
#include <vector>
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml.h"
#include "tokenizer.h"
// GGUF metadata keys read from the model file header.
#define KEY_FTYPE "general.file_type"
#define KEY_NAME "general.name"
#define KEY_DESCRIPTION "general.description"
#define KEY_ARCHITECTURE "general.architecture"
// Architecture string that selects the XLM-RoBERTa tensor-naming variant.
#define ARCH_XLMROBERTA "XLMRobertaModel"
// Pooling strategies for turning per-token states into one sentence vector:
// mean over all tokens (default) vs. taking the [CLS] token only.
#define POOLING_METHOD_MEAN 0
#define POOLING_METHOD_CLS 1
namespace embeddings {
// Hyperparameters of the BERT encoder, populated from GGUF metadata.
// Members are zero-initialized so a default-constructed config is never
// read as garbage if a key is missing from the model file.
struct BertConfig {
  int32_t vocab_size = 0;             // number of entries in the token embedding table
  int32_t max_position_embedding = 0; // maximum sequence length supported
  int32_t hidden_size = 0;            // width of the token representations
  int32_t intermediate_size = 0;      // width of the feed-forward inner layer
  int32_t num_attention_heads = 0;
  int32_t num_hidden_layers = 0;
  float_t layer_norm_eps = 0.0f;      // epsilon added to LayerNorm variance
};
// Bundles all ggml/ggml-backend state needed to hold the weights and run
// forward passes. Pointers are non-owning handles into the ggml C API and
// start as nullptr until the model is loaded.
// NOTE(review): lifetime/teardown of these contexts is handled outside this
// header (presumably in BertModel) — confirm against the .cpp.
class BackendContext {
 public:
  // ggml context holding the model weight tensors
  struct ggml_context *ctx_data = nullptr;
  // backend (CPU/GPU) and the buffer the weights live in
  ggml_backend_t backend = nullptr;
  ggml_backend_buffer_t weights_buffer = nullptr;
  // context + buffer that input tokens are staged into before compute
  struct ggml_context *compute_ctx = nullptr;
  ggml_backend_buffer_t compute_buffer = nullptr;
  // per-forward-pass graph context and its graph allocator
  struct ggml_context *compute_graph_ctx = nullptr;
  ggml_gallocr_t compute_allocr = nullptr;
};
// Non-owning views of the embedding-layer weight tensors inside the loaded
// ggml context. Initialized to nullptr so an unloaded model is detectable
// instead of holding indeterminate pointers.
struct BertEmbedding {
  struct ggml_tensor *word_embeddings = nullptr;
  struct ggml_tensor *token_type_embeddings = nullptr;
  struct ggml_tensor *position_embeddings = nullptr;
  // LayerNorm applied after summing the three embeddings (weight/bias)
  struct ggml_tensor *ln_e_w = nullptr;
  struct ggml_tensor *ln_e_b = nullptr;
  // pooler dense layer (weight/bias)
  struct ggml_tensor *pooler_e_w = nullptr;
  struct ggml_tensor *pooler_e_b = nullptr;
};
// Non-owning views of one transformer encoder layer's weights.
// All pointers default to nullptr (see BertEmbedding) so partially-loaded
// layers fail loudly rather than dereferencing garbage.
struct EncoderBlock {
  // self-attention: query/key/value/output projections (weight + bias each)
  struct ggml_tensor *q_w = nullptr;
  struct ggml_tensor *q_b = nullptr;
  struct ggml_tensor *k_w = nullptr;
  struct ggml_tensor *k_b = nullptr;
  struct ggml_tensor *v_w = nullptr;
  struct ggml_tensor *v_b = nullptr;
  struct ggml_tensor *o_w = nullptr;
  struct ggml_tensor *o_b = nullptr;
  // LayerNorm after the attention sub-layer
  struct ggml_tensor *ln_att_w = nullptr;
  struct ggml_tensor *ln_att_b = nullptr;
  // feed-forward: intermediate and output projections (weight + bias each)
  struct ggml_tensor *ff_i_w = nullptr;
  struct ggml_tensor *ff_i_b = nullptr;
  struct ggml_tensor *ff_o_w = nullptr;
  struct ggml_tensor *ff_o_b = nullptr;
  // LayerNorm after the feed-forward sub-layer
  struct ggml_tensor *ln_out_w = nullptr;
  struct ggml_tensor *ln_out_b = nullptr;
};
// BERT-style transformer encoder loaded from a GGUF file; produces dense
// embedding vectors for pre-tokenized inputs.
class BertModel {
 public:
  // Loads weights and hyperparameters from the GGUF model file at the
  // given path. explicit: a file path must not implicitly convert to a model.
  explicit BertModel(const std::string &);

  // Runs a forward pass over one tokenized input and returns its embedding.
  // normalize: L2-normalize the output vector.
  // pooling_method: 0 == POOLING_METHOD_MEAN, 1 == POOLING_METHOD_CLS.
  // [[nodiscard]]: computing an embedding only to drop it is always a bug.
  [[nodiscard]] std::vector<float> Forward(const Encoding &, bool normalize = true,
                             int pooling_method = 0);
  // Batched variant of Forward; one embedding per input encoding,
  // in the same order.
  [[nodiscard]] std::vector<std::vector<float>> BatchForward(const std::vector<Encoding> &,
                                               bool normalize = true,
                                               int pooling_method = 0);

 private:
  // Builds the ggml compute graph for one batch (see BatchForward for the
  // meaning of normalize/pooling_method).
  struct ggml_cgraph *BuildGraph(const std::vector<Encoding> &batch,
                                 bool normalize = true, int pooling_method = 0);
  // Releases per-forward compute state so the next call starts clean.
  void Clear();

  std::string arch;                 // architecture string from GGUF metadata
  BertConfig hparams;               // model hyperparameters
  BackendContext ctx;               // ggml backend/weight/compute state
  BertEmbedding embeddings;         // embedding-layer weight views
  std::vector<EncoderBlock> layers; // one entry per encoder layer
};
// High-level text-embedding facade: tokenizes raw strings with a
// HuggingFace tokenizer spec, then embeds them with a BertModel.
class Embedding {
 public:
  // hf_token_json: path to the HuggingFace tokenizer.json
  // gguf_model:    path to the GGUF weights file
  Embedding(const std::string &hf_token_json, const std::string &gguf_model);

  // Tokenizes and embeds a single string.
  // pooling_method: 0 == POOLING_METHOD_MEAN, 1 == POOLING_METHOD_CLS.
  // [[nodiscard]]: computing an embedding only to drop it is always a bug.
  [[nodiscard]] std::vector<float> Encode(const std::string &, bool normalize = true,
                            int pooling_method = 0);
  // Batched variant of Encode; one embedding per input string, in order.
  [[nodiscard]] std::vector<std::vector<float>> BatchEncode(const std::vector<std::string> &,
                                              bool normalize = true,
                                              int pooling_method = 0);

 private:
  // NOTE(review): raw owning pointers with no visible destructor — these
  // look like they leak. Prefer std::unique_ptr; confirm ownership/teardown
  // against the .cpp before changing the member types.
  Tokenizer *tok = nullptr;
  BertModel *model = nullptr;
};
} // namespace embeddings