From 52c9397495688b610bbf7afd419f35b2b248ea69 Mon Sep 17 00:00:00 2001
From: 9173860
Date: Wed, 14 Sep 2022 21:10:54 +0800
Subject: [PATCH 1/7] WIP: add special token

---
 tests/unit_tests/test_manual.py | 23 +++++++++++++++++++++++
 youtokentome/cpp/utils.cpp      | 20 ++++++++++++++++----
 youtokentome/cpp/utils.h        |  3 +++
 youtokentome/cpp/yttm.pyx       |  3 +++
 youtokentome/youtokentome.py    |  2 ++
 5 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/tests/unit_tests/test_manual.py b/tests/unit_tests/test_manual.py
index c4f7c9d..607fbd2 100644
--- a/tests/unit_tests/test_manual.py
+++ b/tests/unit_tests/test_manual.py
@@ -73,3 +73,26 @@ def test_japanese():
     assert tokenized_text == expected_result
     print(tokenized_text)
     os.remove(TRAIN_DATA_PATH)
+
+def test_special_token():
+    train_text = """
+    [CLS] Lorem ipsum dolor sit amet, consectetur adipiscing elit,
+    sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
+    Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris
+    nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in
+    reprehenderit in voluptate velit [MASK] esse cillum dolore eu fugiat nulla
+    pariatur. Excepteur sint occaecat cupidatat non proident, sunt in
+    culpa qui officia deserunt mollit anim id est laborum.
+    """
+    test_text = "[CLS] Lorem ipsum [TOKEN] dolor sit [MASK] amet"
+    TRAIN_DATA_PATH = "train_data.txt"
+    MODEL_PATH = "model.yttm"
+    with open(TRAIN_DATA_PATH, "w") as fin:
+        fin.write(train_text)
+    model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100, custom_tokens=[b'[CLS]',b'[MASK]',b''])
+    tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD)
+    expected_result = ['▁','[CLS]', '▁', 'L', 'or', 'e', 'm', '▁', 'ip', 's', 'um', '▁', '[TOKEN]', '▁dolor', '▁', '', '▁s', 'it', '▁[', 'M', 'A', 'S', 'K', ']', '▁a', 'm', 'e', 't']
+    print(tokenized_text)
+    assert tokenized_text == expected_result
+    print(tokenized_text)
+    os.remove(TRAIN_DATA_PATH)
diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp
index 768a817..a74af24 100644
--- a/youtokentome/cpp/utils.cpp
+++ b/youtokentome/cpp/utils.cpp
@@ -10,15 +10,20 @@ using std::string;
 using std::vector;
 
 void SpecialTokens::dump(std::ofstream &fout) {
-  fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id
-       << std::endl;
+  fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << " ";
+  for (auto token: custom_tokens) fout << token << " ";
+  fout << std::endl;
+
 }
 
 void SpecialTokens::load(std::ifstream &fin) {
   fin >> unk_id >> pad_id >> bos_id >> eos_id;
+  std::string token;
+  while (fin >> token)
+    custom_tokens.push_back(token);
 }
 
-uint32_t SpecialTokens::max_id() const {
+uint32_t SpecialTokens::max_predefined_id() const {
   int ret = 0;
   ret = std::max(ret, unk_id);
   ret = std::max(ret, pad_id);
@@ -27,8 +32,14 @@ uint32_t SpecialTokens::max_id() const {
   return ret;
 }
 
+uint32_t SpecialTokens::max_id() const {
+  int ret = max_predefined_id();
+  ret += custom_tokens.size();
+  return ret;
+}
+
 bool SpecialTokens::taken_id(int id) const {
-  return id == unk_id || id == pad_id || id == bos_id || id == eos_id;
+  return id == unk_id || id == pad_id || id == bos_id || id == eos_id || (id > max_predefined_id() && id <= max_id());
 }
 
@@ -37,6 +48,7 @@ uint64_t SpecialTokens::n_special_tokens() const {
   cnt += (pad_id != -1);
   cnt += (bos_id != -1);
   cnt += (eos_id != -1);
+  cnt += custom_tokens.size();
   return cnt;
 }
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index ce802d5..9dd30c9 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -26,6 +26,7 @@ struct SpecialTokens {
   int unk_id = -1;
   int bos_id = -1;
   int eos_id = -1;
+  std::vector<std::string> custom_tokens;
 
   SpecialTokens() = default;
 
@@ -40,6 +41,8 @@ struct SpecialTokens {
   bool taken_id(int id) const;
 
   uint64_t n_special_tokens() const;
+ private:
+  uint32_t max_predefined_id() const;
 };
 
 struct BpeConfig {
diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx
index 1d7774d..571f13a 100644
--- a/youtokentome/cpp/yttm.pyx
+++ b/youtokentome/cpp/yttm.pyx
@@ -14,6 +14,7 @@ cdef extern from "bpe.h" namespace "vkcom":
         int unk_id
         int bos_id
         int eos_id
+        vector[string] custom_tokens
 
     cdef cppclass BpeConfig:
         double character_coverage
@@ -67,6 +68,7 @@ cdef class BPE:
             vocab_size,
             coverage=1.0,
             n_threads=-1,
+            custom_tokens=[],
            pad_id=0,
             unk_id=1,
             bos_id=2,
@@ -79,6 +81,7 @@ cdef class BPE:
         bpe_config.special_tokens.unk_id = unk_id
         bpe_config.special_tokens.bos_id = bos_id
         bpe_config.special_tokens.eos_id = eos_id
+        bpe_config.special_tokens.custom_tokens = custom_tokens
 
         cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config)
         if status.code != 0:
diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py
index 593febf..8cd1eb9 100644
--- a/youtokentome/youtokentome.py
+++ b/youtokentome/youtokentome.py
@@ -22,6 +22,7 @@ def train(
         data: str,
         model: str,
         vocab_size: int,
+        custom_tokens: List[bytes] = [],
         coverage: float = 1.0,
         n_threads: int = -1,
         pad_id: int = 0,
@@ -35,6 +36,7 @@ def train(
             vocab_size=vocab_size,
             n_threads=n_threads,
             coverage=coverage,
+            custom_tokens=custom_tokens,
             pad_id=pad_id,
             unk_id=unk_id,
             bos_id=bos_id,
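Patch 1's id layout: `max_predefined_id()` is the largest of the four reserved ids, and custom tokens are numbered consecutively right after it (the mapping itself is built in `fill_from_state` in patch 2). A minimal Python sketch of that scheme, assuming the default reserved ids 0-3; the function name here is illustrative, not library API:

```python
# Hypothetical sketch of the id layout implied by SpecialTokens::max_id():
# custom tokens occupy the ids immediately after the predefined special ids.
def custom_token_ids(custom_tokens, pad_id=0, unk_id=1, bos_id=2, eos_id=3):
    max_predefined_id = max(0, pad_id, unk_id, bos_id, eos_id)
    return {tok: max_predefined_id + 1 + i for i, tok in enumerate(custom_tokens)}

print(custom_token_ids([b"[CLS]", b"[MASK]"]))  # {b'[CLS]': 4, b'[MASK]': 5}
```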
From 7601884e57a15f5f84173c9f3d98543245587c0f Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 08:52:43 +0800
Subject: [PATCH 2/7] finished encode bpe with custom token

---
 tests/unit_tests/test_manual.py |  4 ++--
 youtokentome/cpp/bpe.cpp        | 18 ++++++++++++++++++
 youtokentome/cpp/bpe.h          |  1 +
 youtokentome/cpp/utils.h        |  1 -
 4 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/tests/unit_tests/test_manual.py b/tests/unit_tests/test_manual.py
index 607fbd2..29a2851 100644
--- a/tests/unit_tests/test_manual.py
+++ b/tests/unit_tests/test_manual.py
@@ -89,9 +89,9 @@ def test_special_token():
     MODEL_PATH = "model.yttm"
     with open(TRAIN_DATA_PATH, "w") as fin:
         fin.write(train_text)
-    model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100, custom_tokens=[b'[CLS]',b'[MASK]',b''])
+    model = yttm.BPE.train(TRAIN_DATA_PATH, MODEL_PATH, 100, custom_tokens=[b'[CLS]',b'[TOKEN]',b''])
     tokenized_text = model.encode([test_text], output_type=yttm.OutputType.SUBWORD)
-    expected_result = ['▁','[CLS]', '▁', 'L', 'or', 'e', 'm', '▁', 'ip', 's', 'um', '▁', '[TOKEN]', '▁dolor', '▁', '', '▁s', 'it', '▁[', 'M', 'A', 'S', 'K', ']', '▁a', 'm', 'e', 't']
+    expected_result = [['▁','[CLS]', '▁', 'L', 'or', 'e', 'm', '▁', 'ip', 's', 'um', '▁', '[TOKEN]', '▁dolor', '▁', '', '▁s', 'it', '▁', '[', 'M', 'A', 'S', 'K', ']', '▁a', 'm', 'e', 't']]
     print(tokenized_text)
     assert tokenized_text == expected_result
     print(tokenized_text)
diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index c28ee8a..ecd7b91 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1663,6 +1663,19 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
       auto end_of_word = std::find_if(begin_of_word, text.end(), is_space);
       it_text = end_of_word;
 
+      auto word = encode_utf8({begin_of_word, end_of_word});
+
+      if (custom_token2id.count(word)) {
+        if (output_type == ID) {
+          output_ids.push_back(bpe_state.char2id.at(SPACE_TOKEN));
+          output_ids.push_back(custom_token2id.find(word) -> second);
+        } else {
+          output_pieces.push_back(encode_utf8({SPACE_TOKEN}));
+          output_pieces.push_back(word);
+        }
+        continue;
+      }
+
       uint32_t new_token_cur = new_tokens_start;
       list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);
 
@@ -1840,6 +1853,11 @@ void BaseEncoder::fill_from_state() {
   }
   reversed_recipe[BOS_TOKEN] = bpe_state.special_tokens.bos_id;
   reversed_recipe[EOS_TOKEN] = bpe_state.special_tokens.eos_id;
+  uint32_t custom_id = bpe_state.special_tokens.max_predefined_id();
+  for (auto token : bpe_state.special_tokens.custom_tokens) {
+    ++custom_id;
+    custom_token2id[token] = custom_id;
+  }
 }
 
 int BaseEncoder::vocab_size() const {
diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h
index 99464a2..cac5063 100644
--- a/youtokentome/cpp/bpe.h
+++ b/youtokentome/cpp/bpe.h
@@ -27,6 +27,7 @@ class BaseEncoder {
   flat_hash_map<uint32_t, uint32_t> id2char;
   flat_hash_map<uint32_t, std::vector<uint32_t>> recipe;
   flat_hash_map<std::string, uint32_t> reversed_recipe;
+  flat_hash_map<std::string, uint32_t> custom_token2id;
   flat_hash_map<uint64_t, int> rule2id;
   int n_threads;
diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h
index 9dd30c9..4a5102c 100644
--- a/youtokentome/cpp/utils.h
+++ b/youtokentome/cpp/utils.h
@@ -41,7 +41,6 @@ struct SpecialTokens {
   bool taken_id(int id) const;
 
   uint64_t n_special_tokens() const;
- private:
   uint32_t max_predefined_id() const;
 };
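Patch 2's encode-side change amounts to a whole-word early exit: each space-delimited word is checked against `custom_token2id` before the BPE merge loop runs, and on a hit the word is emitted as a single piece. A simplified Python sketch of that control flow (`bpe_encode_word` is a stand-in for the merge loop, not real API):

```python
def encode_words(words, custom_token2id, bpe_encode_word):
    pieces = []
    for word in words:
        pieces.append("\u2581")              # SPACE_TOKEN marker starts every word
        if word in custom_token2id:
            pieces.append(word)              # custom token kept whole, merges skipped
            continue
        pieces.extend(bpe_encode_word(word)) # ordinary BPE path
    return pieces
```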
From b53046e11feefa01c105f01f357b98cfb3f9b63a Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 18:19:49 +0800
Subject: [PATCH 3/7] support custom special token for encode

---
 youtokentome/cpp/bpe.cpp | 30 ++++++++++++++++--------------
 youtokentome/cpp/utf8.h  |  2 ++
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index ecd7b91..faf8155 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1663,21 +1663,9 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
       auto end_of_word = std::find_if(begin_of_word, text.end(), is_space);
       it_text = end_of_word;
 
-      auto word = encode_utf8({begin_of_word, end_of_word});
-
-      if (custom_token2id.count(word)) {
-        if (output_type == ID) {
-          output_ids.push_back(bpe_state.char2id.at(SPACE_TOKEN));
-          output_ids.push_back(custom_token2id.find(word) -> second);
-        } else {
-          output_pieces.push_back(encode_utf8({SPACE_TOKEN}));
-          output_pieces.push_back(word);
-        }
-        continue;
-      }
-
       uint32_t new_token_cur = new_tokens_start;
       list.emplace_back(bpe_state.char2id.at(SPACE_TOKEN), 0);
+      string utf8_text;
 
       for (auto it_char_in_word = begin_of_word; it_char_in_word < end_of_word;) {
         if (bpe_state.char2id.count(*it_char_in_word) == 0) {
@@ -1685,17 +1673,31 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
               it_char_in_word, end_of_word,
               [&](uint32_t ch) { return bpe_state.char2id.count(ch); });
 
-          unrecognized_tokens[new_token_cur] =
+          unrecognized_tokens[new_token_cur] = utf8_text =
               encode_utf8({it_char_in_word, it_unrecognized_word});
           it_char_in_word = it_unrecognized_word;
 
           list.emplace_back(new_token_cur, list.size());
           new_token_cur++;
         } else {
+          if (custom_token2id.size())
+            utf8_to_chars(*it_char_in_word, std::back_inserter(utf8_text));
           list.emplace_back(bpe_state.char2id.at(*it_char_in_word), list.size());
           ++it_char_in_word;
         }
       }
+
+      if (custom_token2id.size() && custom_token2id.count(utf8_text)) {
+        if (output_type == ID) {
+          output_ids.push_back(bpe_state.char2id.at(SPACE_TOKEN));
+          output_ids.push_back(custom_token2id.find(utf8_text) -> second);
+        } else {
+          output_pieces.push_back(encode_utf8({SPACE_TOKEN}));
+          output_pieces.push_back(utf8_text);
+        }
+        continue;
+      }
+
       list.back().next = -1;
diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h
index ec34831..d51cb66 100644
--- a/youtokentome/cpp/utf8.h
+++ b/youtokentome/cpp/utf8.h
@@ -8,6 +8,8 @@ constexpr static uint32_t INVALID_UNICODE = 0x0fffffff;
 
 uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len);
 
+void utf8_to_chars(const uint32_t x, const std::back_insert_iterator<std::string> it);
+
 std::string encode_utf8(const std::vector<uint32_t> &utext);
 
 std::vector<uint32_t> decode_utf8(const char *begin, const char *end);
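After patch 3 the lookup no longer happens up front: the word's UTF-8 form is accumulated in `utf8_text` during the per-character scan, covering both known characters (via `utf8_to_chars`) and unrecognized spans, and the whole-word check runs once the scan is done. Roughly, with the same simplifications as the previous sketch (`merge_bpe` is a stand-in, and `None` plays the role of an unrecognized span):

```python
def encode_word(word_chars, char2id, custom_token2id, merge_bpe):
    utf8_text, ids = "", []
    for ch in word_chars:
        utf8_text += ch               # mirrors utf8_to_chars / the append in patch 5
        ids.append(char2id.get(ch))   # unknown characters yield None here
    if utf8_text in custom_token2id:  # whole-word check after the scan
        return [custom_token2id[utf8_text]]
    return merge_bpe(ids)             # otherwise the usual merge loop runs
```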
From 2197896bdd761d16cee722a57b9911d81320a41b Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 18:45:34 +0800
Subject: [PATCH 4/7] fix vocab cli

---
 youtokentome/cpp/bpe.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index faf8155..021d5c1 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1967,6 +1967,10 @@ Status BaseEncoder::id_to_subword(int id, string *subword, bool replace_space) c
     *subword = EOS_TOKEN;
     return Status();
   }
+  if (id <= bpe_state.special_tokens.max_id() && id > bpe_state.special_tokens.max_predefined_id()) {
+    *subword = bpe_state.special_tokens.custom_tokens[id - bpe_state.special_tokens.max_predefined_id() - 1];
+    return Status();
+  }
 
   assert(recipe.count(id));
   if (replace_space) {

From 20c848b48ed99ed7479d648172b82a339a4b04eb Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 19:21:42 +0800
Subject: [PATCH 5/7] fix missed encoding custom token in certain circumstances

---
 youtokentome/cpp/bpe.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index 021d5c1..6326aef 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1673,8 +1673,10 @@ DecodeResult BaseEncoder::encode_sentence(const std::string &sentence_utf8,
               it_char_in_word, end_of_word,
               [&](uint32_t ch) { return bpe_state.char2id.count(ch); });
 
-          unrecognized_tokens[new_token_cur] = utf8_text =
+          unrecognized_tokens[new_token_cur] =
               encode_utf8({it_char_in_word, it_unrecognized_word});
+          if (custom_token2id.size())
+            utf8_text.append(unrecognized_tokens[new_token_cur]);
           it_char_in_word = it_unrecognized_word;
 
           list.emplace_back(new_token_cur, list.size());
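Patch 4 adds the inverse mapping for `id_to_subword` (used by the vocab CLI): ids in the range just above the predefined ids index into `custom_tokens`. Patch 5 then makes `utf8_text` append unrecognized spans instead of being overwritten by them. The indexing arithmetic, as a standalone Python sketch (not library API):

```python
def id_to_custom_token(token_id, custom_tokens, max_predefined_id=3):
    # Ids in (max_predefined_id, max_predefined_id + len(custom_tokens)] are custom.
    if max_predefined_id < token_id <= max_predefined_id + len(custom_tokens):
        return custom_tokens[token_id - max_predefined_id - 1]
    return None  # caller falls through to the ordinary recipe lookup

assert id_to_custom_token(4, [b"[CLS]", b"[MASK]"]) == b"[CLS]"
```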
From bd3ae7d5220685c9288b59e9d77d4a70c4f6cbee Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 20:11:29 +0800
Subject: [PATCH 6/7] add cli configs

---
 youtokentome/cpp/bpe.cpp |  6 ++++++
 youtokentome/yttm_cli.py | 10 +++++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp
index 6326aef..f0ed155 100644
--- a/youtokentome/cpp/bpe.cpp
+++ b/youtokentome/cpp/bpe.cpp
@@ -1515,6 +1515,12 @@ void print_config(const string &input_path, const string &model_path,
   std::cerr << "  unk: " << bpe_config.special_tokens.unk_id << std::endl;
   std::cerr << "  bos: " << bpe_config.special_tokens.bos_id << std::endl;
   std::cerr << "  eos: " << bpe_config.special_tokens.eos_id << std::endl;
+  if (bpe_config.special_tokens.custom_tokens.size()) {
+    std::cerr << "  custom_tokens: ";
+    for (auto token:bpe_config.special_tokens.custom_tokens)
+      std::cerr << token << " ";
+    std::cerr << std::endl;
+  }
   std::cerr << std::endl;
 }

diff --git a/youtokentome/yttm_cli.py b/youtokentome/yttm_cli.py
index 7e66879..318aea2 100644
--- a/youtokentome/yttm_cli.py
+++ b/youtokentome/yttm_cli.py
@@ -57,7 +57,14 @@ def main():
     default=3,
     show_default=True,
 )
-def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id):
+@click.option(
+    "--custom_tokens",
+    type=click.STRING,
+    help="Tokens which will not be split into subwords; multiple tokens should be comma-separated.",
+    default="",
+    show_default=True,
+)
+def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id, custom_tokens):
     """Train BPE model."""
     yttmc.BPE.train(
         data=data,
@@ -69,6 +76,7 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo
         unk_id=unk_id,
         bos_id=bos_id,
         eos_id=eos_id,
+        custom_tokens=map(lambda t: t.encode("utf8"), custom_tokens.split(','))
     )

From f4ac18d100e889a8e039ad44d8450b240393aa2c Mon Sep 17 00:00:00 2001
From: 9173860
Date: Thu, 15 Sep 2022 21:07:01 +0800
Subject: [PATCH 7/7] update README.md with custom tokens feature

---
 README.md | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 6b707c2..a7f3584 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,9 @@ test_text = "".join([random.choice("abcde ") for _ in range(100)])
 # Training model
 yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path)
 
+# Training model with custom tokens
+yttm.BPE.train(data=train_data_path, vocab_size=5000, model=model_path, custom_tokens=[b"[CLS]", b"[MASK]"])
+
 # Loading model
 bpe = yttm.BPE(model=model_path)
 
@@ -71,7 +74,7 @@ print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD))
 
 ### Training model
 ```python
-youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3)
+youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3, custom_tokens=[])
 ```
 Trains BPE model and saves to file.
@@ -86,6 +89,7 @@ Trains BPE model and saves to file.
 * `unk_id`: int, reserved id for unknown symbols
 * `bos_id`: int, reserved id for begin of sentence token
 * `eos_id`: int, reserved id for end of sentence token
+* `custom_tokens`: List[bytes], tokens which will not be split into subwords
 
 **Returns**: Class `youtokentome.BPE` with the loaded model.
@@ -191,7 +195,7 @@ Convert each id to subword and concatenate with space symbol.
 ### Example
 
 ```bash
-$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000
+$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 --custom_tokens "[CLS],[MASK]"
 $ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA
 ```
@@ -234,6 +238,9 @@ Options:
   --unk_id INTEGER     Unknown token id.  [default: 1]
   --bos_id INTEGER     'Begin of sentence' token id.  [default: 2]
   --eos_id INTEGER     'End of sentence' token id.  [default: 3]
+  --custom_tokens TEXT  Tokens which will not be split into
+                        subwords; multiple tokens should be
+                        comma-separated.
   --help               Show this message and exit.
 ```
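End-to-end, the feature added by this series is exercised roughly as follows (a usage sketch built from the README additions; the file names are illustrative):

```python
import youtokentome as yttm

# Train a model that reserves [CLS] and [MASK] as unsplittable tokens.
yttm.BPE.train(data="train_data.txt", model="model.yttm", vocab_size=5000,
               custom_tokens=[b"[CLS]", b"[MASK]"])
bpe = yttm.BPE(model="model.yttm")

# Custom tokens come back as single pieces instead of being split into subwords.
print(bpe.encode(["[CLS] hello [MASK]"], output_type=yttm.OutputType.SUBWORD))
```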