// jina_bert.h
#pragma once
#include "bert.h"
namespace embeddings {
// Hyperparameter record for a JinaBertModel, which keeps one as its `hparams`
// member. Plain aggregate: there are no default member initializers, so a
// default-initialized instance holds indeterminate values until every field
// is assigned. How the fields are filled in is not visible in this header.
struct JinaBertConfig {
// Integer settings; the names mirror the usual BERT configuration keys
// (presumably read from the model file — confirm against the loader).
int32_t vocab_size;
int32_t hidden_size;
int32_t num_hidden_layers;
int32_t num_attention_heads;
int32_t intermediate_size;
int32_t type_vocab_size;
int32_t pad_token_id;
// NOTE(review): float_t is the <cmath> typedef whose width is
// implementation-defined (at least as wide as float, possibly double or
// long double). Confirm that plain `float` is not what is intended here.
float_t layer_norm_eps;
};
// Handles to the tensors used by the embedding stage of JinaBertModel, which
// keeps one as its `embeddings` member. All members are raw ggml_tensor
// pointers with no initializers; who allocates and frees the tensors is not
// expressed in this header.
struct JinaBertEmbedding {
// Embedding tables (presumably indexed by token id and token-type id).
struct ggml_tensor *word_embeddings;
struct ggml_tensor *token_type_embeddings;
// Presumably the weight and bias of the layer norm applied to the embeddings
// — confirm against the graph-building code.
struct ggml_tensor *ln_e_w;
struct ggml_tensor *ln_e_b;
};
// Handles to the tensors of one encoder layer. JinaBertModel stores a
// std::vector of these in `layers`. All members are raw ggml_tensor pointers
// with no initializers; ownership is not expressed in this header.
// Naming convention used below: `_w` suffix for a weight, `_b` for a bias.
struct JinaEncoderBlock {
// attention
// Wqkv_* is a single weight/bias pair (presumably a fused Q/K/V projection —
// confirm), o_* the output projection, norm1_* the norm for this sub-block.
struct ggml_tensor *Wqkv_w;
struct ggml_tensor *Wqkv_b;
struct ggml_tensor *o_w;
struct ggml_tensor *o_b;
struct ggml_tensor *norm1_w;
struct ggml_tensor *norm1_b;
// glumlp
// Note: the gated layer has a weight only; no matching bias member is
// declared, unlike mlp_out_* and norm2_*.
struct ggml_tensor *mlp_gated_layers_w;
struct ggml_tensor *mlp_out_w;
struct ggml_tensor *mlp_out_b;
struct ggml_tensor *norm2_w;
struct ggml_tensor *norm2_b;
};
// Model object that turns Encoding inputs into float vectors. Only the
// declarations live here; every member function is defined elsewhere.
class JinaBertModel {
public:
// Constructs the model from a string argument (presumably a path to a model
// file — confirm in the implementation).
// NOTE(review): this single-argument constructor is not `explicit`, so a
// std::string converts implicitly to JinaBertModel.
JinaBertModel(const std::string &);
// Produces one float vector for a single Encoding.
// `normalize` defaults to true and `pooling_method` defaults to 0; the
// meaning of the pooling_method values is not defined in this header.
std::vector<float> Forward(const Encoding &, bool normalize = true,
int pooling_method = 0);
// Batch counterpart of Forward: takes a vector of Encoding and returns a
// vector of float vectors (presumably one per input — confirm). Same
// defaults as Forward.
std::vector<std::vector<float>> BatchForward(const std::vector<Encoding> &,
bool normalize = true,
int pooling_method = 0);
private:
// Returns a ggml compute graph for `batch`; takes the same normalize /
// pooling_method options (and defaults) as the public entry points.
struct ggml_cgraph *BuildGraph(const std::vector<Encoding> &batch,
bool normalize = true, int pooling_method = 0);
// Private cleanup/reset hook; what it releases is defined in the
// implementation file.
void Clear();
// Data members. Declaration order is also construction order, so keep it
// stable unless the constructor is reviewed alongside.
std::string arch;
JinaBertConfig hparams;
BackendContext ctx;
// Note: inside this class the name `embeddings` refers to this member and
// hides the enclosing namespace of the same name.
JinaBertEmbedding embeddings;
std::vector<JinaEncoderBlock> layers;
};
// Text-level front end: takes strings and returns float vectors. It holds a
// Tokenizer and a JinaBertModel by pointer (presumably tokenizing the text
// and forwarding the result to the model — confirm in the implementation).
class JinaEmbedding {
public:
// Constructs from two strings named `hf_token_json` and `gguf_model`
// (presumably a tokenizer JSON file and a GGUF model file — confirm whether
// these are paths or contents).
JinaEmbedding(const std::string &hf_token_json,
const std::string &gguf_model);
// Produces one float vector for a single input string.
// `normalize` defaults to true and `pooling_method` defaults to 0, matching
// the defaults declared on JinaBertModel::Forward.
std::vector<float> Encode(const std::string &, bool normalize = true,
int pooling_method = 0);
// Batch counterpart of Encode: takes a vector of strings and returns a
// vector of float vectors (presumably one per input — confirm). Same
// defaults as Encode.
std::vector<std::vector<float>> BatchEncode(const std::vector<std::string> &,
bool normalize = true,
int pooling_method = 0);
private:
// NOTE(review): raw pointers with no initializers, and no destructor or
// copy/move control is declared in this class. If the constructor allocates
// these with `new`, they are never freed and copies of JinaEmbedding would
// share them — confirm ownership in the implementation file.
Tokenizer *tok;
JinaBertModel *model;
};
} // namespace embeddings