```diff
@@ -26,20 +26,7 @@ Forked with gratitude from:
 
 #define BERT_MAX_NODES 4096
 
-// model keys
-
-#define KEY_FTYPE          "general.file_type"
-#define KEY_NAME           "general.name"
-#define KEY_DESCRIPTION    "general.description"
-
-#define KEY_PAD_ID         "tokenizer.ggml.padding_token_id"
-#define KEY_UNK_ID         "tokenizer.ggml.unknown_token_id"
-#define KEY_BOS_ID         "tokenizer.ggml.bos_token_id"
-#define KEY_EOS_ID         "tokenizer.ggml.eos_token_id"
-#define KEY_SUBWORD_PREFIX "tokenizer.ggml.subword_prefix"
-#define KEY_TOKEN_LIST     "tokenizer.ggml.tokens"
-
-const int verbosity = 0;
+const int verbosity = 1;
 
 //
 // utilities to get data from a gguf file
```
```diff
@@ -73,9 +60,11 @@ static float get_f32(const gguf_context * ctx, const std::string & key) {
     return gguf_get_val_f32(ctx, i);
 }
 
-static std::string get_str(const gguf_context * ctx, const std::string & key) {
-    const int i = get_key_idx(ctx, key.c_str());
-
+static std::string get_str(const gguf_context * ctx, const std::string & key, const std::string & def = "") {
+    const int i = gguf_find_key(ctx, key.c_str());
+    if (i == -1) {
+        return def;
+    }
     return gguf_get_val_str(ctx, i);
 }
 
```
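The rewritten `get_str` makes a missing key non-fatal: `gguf_find_key` returns -1 when the key is absent, and the function falls back to a caller-supplied default rather than resolving the key through `get_key_idx`. This is what lets the loader below treat `tokenizer.ggml.word_prefix` as optional for older model files. A minimal sketch of the same lookup-with-default pattern, with a `std::map` standing in for the `gguf_context` (the store and the values in `main` are illustrative, not part of the codebase):

```cpp
#include <cstdio>
#include <map>
#include <string>

// Illustrative stand-in for the gguf key/value store; a real gguf_context
// would be queried with gguf_find_key / gguf_get_val_str instead.
using kv_store = std::map<std::string, std::string>;

static std::string get_str(const kv_store & kv, const std::string & key,
                           const std::string & def = "") {
    auto it = kv.find(key);  // plays the role of gguf_find_key
    if (it == kv.end()) {
        return def;          // absent key -> default instead of a hard error
    }
    return it->second;       // plays the role of gguf_get_val_str
}

int main() {
    kv_store kv = {{"tokenizer.ggml.subword_prefix", "##"}};
    // present key -> stored value; absent key -> default ("")
    printf("subword_prefix = '%s'\n", get_str(kv, "tokenizer.ggml.subword_prefix").c_str());
    printf("word_prefix    = '%s'\n", get_str(kv, "tokenizer.ggml.word_prefix").c_str());
    return 0;
}
```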
```diff
@@ -325,16 +314,22 @@ bert_tokens bert_tokenize(struct bert_ctx * ctx, bert_string text, uint64_t n_ma
     return tokens;
 }
 
-bert_string bert_detokenize(struct bert_ctx * ctx, bert_tokens tokens, bool debug) {
+bert_string bert_detokenize(struct bert_ctx * ctx, bert_tokens tokens, bool debug = false) {
     const bert_token bos_id = ctx->vocab.bos_id;
     const bert_token eos_id = ctx->vocab.eos_id;
+
+    const std::string word_prefix = ctx->vocab.word_prefix;
     const std::string subword_prefix = ctx->vocab.subword_prefix;
-    const std::string prefix = subword_prefix + subword_prefix;
+    const uint32_t word_prefix_len = word_prefix.size();
+    const uint32_t subword_prefix_len = subword_prefix.size();
 
     bert_string str = "";
     for (const uint64_t &t : tokens) {
         std::string token = bert_vocab_id_to_token(ctx, t);
-        bool subword = token.find(prefix) == 0;
+        bool subword = (
+            (subword_prefix_len > 0 && token.find(subword_prefix) == 0) ||
+            (word_prefix_len > 0 && token.find(word_prefix) != 0)
+        );
         if (debug) {
             if ((str.size() > 0) && !subword) {
                 str += " ";
@@ -345,12 +340,12 @@ bert_string bert_detokenize(struct bert_ctx * ctx, bert_tokens tokens, bool debu
             continue;
         }
         if (subword) {
-            str += token.substr(2);
+            str += token.substr(subword_prefix_len);
         } else {
             if (str.size() > 0) {
                 str += " ";
             }
-            str += token;
+            str += token.substr(word_prefix_len);
         }
     }
 }
```
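`bert_detokenize` previously assumed a doubled subword marker (`prefix = subword_prefix + subword_prefix`, stripped with a hard-coded `substr(2)`). The new predicate covers both marking conventions: a token is a subword when it starts with a non-empty `subword_prefix`, or when a non-empty `word_prefix` exists and the token lacks it, and the matching prefix length is stripped instead of a fixed 2. A self-contained sketch of that predicate, with illustrative WordPiece-style (`##` on subwords) and SentencePiece-style (`▁` on word starts) prefixes:

```cpp
#include <cstdio>
#include <string>

// A token is a subword if it carries the subword marker, or if word starts
// are marked and this token is not marked as one. Empty prefixes are skipped
// because std::string::find("") == 0 for every token.
static bool is_subword(const std::string & token,
                       const std::string & word_prefix,
                       const std::string & subword_prefix) {
    return (subword_prefix.size() > 0 && token.find(subword_prefix) == 0) ||
           (word_prefix.size()    > 0 && token.find(word_prefix)    != 0);
}

int main() {
    // WordPiece-style: words are bare, subwords carry "##"
    printf("%d\n", is_subword("##ing", "", "##"));  // 1
    printf("%d\n", is_subword("walk",  "", "##"));  // 0
    // SentencePiece-style: word starts carry "▁" (3 bytes in UTF-8), subwords are bare
    printf("%d\n", is_subword("▁walk", "▁", ""));   // 0
    printf("%d\n", is_subword("ing",   "▁", ""));   // 1
    return 0;
}
```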
```diff
@@ -462,8 +457,11 @@ struct bert_ctx * bert_load_from_file(const char *fname, bool use_cpu) {
     vocab.unk_id = get_i32(ctx_gguf, KEY_UNK_ID);
     vocab.bos_id = get_i32(ctx_gguf, KEY_BOS_ID);
     vocab.eos_id = get_i32(ctx_gguf, KEY_EOS_ID);
+
+    vocab.word_prefix = get_str(ctx_gguf, KEY_WORD_PREFIX);
     vocab.subword_prefix = get_str(ctx_gguf, KEY_SUBWORD_PREFIX);
-    uint32_t prefix_len = vocab.subword_prefix.size();
+    uint32_t word_prefix_len = vocab.word_prefix.size();
+    uint32_t subword_prefix_len = vocab.subword_prefix.size();
 
     const int token_idx = gguf_find_key(ctx_gguf, KEY_TOKEN_LIST);
     const int n_vocab = gguf_get_arr_n(ctx_gguf, token_idx);
@@ -472,20 +470,25 @@ struct bert_ctx * bert_load_from_file(const char *fname, bool use_cpu) {
         std::string word = gguf_get_arr_str(ctx_gguf, token_idx, i);
         vocab.tokens.push_back(word);
 
-        if (word.find(vocab.subword_prefix) == 0) {
-            vocab.subword_token_to_id[word.substr(prefix_len)] = i;
-            vocab._id_to_subword_token[i] = word;
-        }
+        bool subword = (
+            (subword_prefix_len > 0 && word.find(vocab.subword_prefix) == 0) ||
+            (word_prefix_len > 0 && word.find(vocab.word_prefix) != 0)
+        );
 
-        if (vocab.token_to_id.count(word) == 0) {
-            vocab.token_to_id[word] = i;
+        if (subword) {
+            vocab.subword_token_to_id[word.substr(subword_prefix_len)] = i;
+            vocab._id_to_subword_token[i] = word;
+        } else {
+            vocab.token_to_id[word.substr(word_prefix_len)] = i;
             vocab._id_to_token[i] = word;
         }
     }
 
     if (verbosity >= 1) {
         fprintf(stderr, "%s: TOKENIZER\n", __func__);
         fprintf(stderr, "%s: vocab size: %d\n", __func__, n_vocab);
+        fprintf(stderr, "%s: word_prefix: %s\n", __func__, vocab.word_prefix.c_str());
+        fprintf(stderr, "%s: subword_prefix: %s\n", __func__, vocab.subword_prefix.c_str());
         fprintf(stderr, "%s: pad_id = %d\n", __func__, vocab.pad_id);
         fprintf(stderr, "%s: unk_id = %d\n", __func__, vocab.unk_id);
         fprintf(stderr, "%s: bos_id = %d\n", __func__, vocab.bos_id);
```
```diff
@@ -627,13 +630,6 @@ struct bert_ctx * bert_load_from_file(const char *fname, bool use_cpu) {
         bert_layer & layer = model.layers[i];
         std::string pre = "encoder.layer." + std::to_string(i) + ".";
 
-        // normalization
-        layer.ln_att_w = get_tensor(new_bert->ctx_data, pre + "attention.output.LayerNorm.weight");
-        layer.ln_att_b = get_tensor(new_bert->ctx_data, pre + "attention.output.LayerNorm.bias");
-
-        layer.ln_out_w = get_tensor(new_bert->ctx_data, pre + "output.LayerNorm.weight");
-        layer.ln_out_b = get_tensor(new_bert->ctx_data, pre + "output.LayerNorm.bias");
-
         // attention
         layer.q_w = get_tensor(new_bert->ctx_data, pre + "attention.self.query.weight");
         layer.q_b = get_tensor(new_bert->ctx_data, pre + "attention.self.query.bias");
@@ -645,12 +641,18 @@ struct bert_ctx * bert_load_from_file(const char *fname, bool use_cpu) {
         layer.o_w = get_tensor(new_bert->ctx_data, pre + "attention.output.dense.weight");
         layer.o_b = get_tensor(new_bert->ctx_data, pre + "attention.output.dense.bias");
 
+        layer.ln_att_w = get_tensor(new_bert->ctx_data, pre + "attention.output.LayerNorm.weight");
+        layer.ln_att_b = get_tensor(new_bert->ctx_data, pre + "attention.output.LayerNorm.bias");
+
         // ff
         layer.ff_i_w = get_tensor(new_bert->ctx_data, pre + "intermediate.dense.weight");
         layer.ff_i_b = get_tensor(new_bert->ctx_data, pre + "intermediate.dense.bias");
 
         layer.ff_o_w = get_tensor(new_bert->ctx_data, pre + "output.dense.weight");
         layer.ff_o_b = get_tensor(new_bert->ctx_data, pre + "output.dense.bias");
+
+        layer.ln_out_w = get_tensor(new_bert->ctx_data, pre + "output.LayerNorm.weight");
+        layer.ln_out_b = get_tensor(new_bert->ctx_data, pre + "output.LayerNorm.bias");
     }
 }
 
```
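The last two hunks only reorder the per-layer loads: each LayerNorm is now fetched immediately after the sublayer it follows, mirroring BERT's post-LN dataflow (normalize after the attention residual, then after the feed-forward residual) instead of grouping all normalization tensors up front. A toy scalar sketch of that ordering; the helpers here are stand-ins for the real tensor ops, not code from the repository:

```cpp
#include <cstdio>

// Toy stand-ins for the real tensor operations (illustrative only).
static float layer_norm(float x, float w, float b) { return x * w + b; }
static float attention(float x)    { return 0.5f * x; }
static float feed_forward(float x) { return 2.0f * x; }

int main() {
    float x = 1.0f;
    // post-LN order: sublayer, residual add, then its LayerNorm
    // (the same order the tensors are now loaded in)
    x = layer_norm(x + attention(x),    /*ln_att_w*/ 1.0f, /*ln_att_b*/ 0.0f);
    x = layer_norm(x + feed_forward(x), /*ln_out_w*/ 1.0f, /*ln_out_b*/ 0.0f);
    printf("%f\n", x); // 4.500000
    return 0;
}
```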