update tokenizer training script
Signed-off-by: Max Fisher <[email protected]>
maxfisher-g committed Aug 14, 2023
1 parent 334d025 commit 48603a3
Showing 1 changed file with 21 additions and 11 deletions.
32 changes: 21 additions & 11 deletions tokenizer_training.py
@@ -13,46 +13,56 @@ class LogLevel (IntEnum):
     INFO = 0


+whitespace_symbols = [ "\n", " ", "\t", "\r", "\v", "\f" ]
+
+
 def train_model(input_filenames: Iterable[str], verbose: bool) -> io.BytesIO:
+    # max input size (filesize) is 10k bytes
+    # TODO consider splitting up long files into multiple sections
+    max_sentence_size = 10000
+
+    vocab_size = 48000
+
+    # aim to tokenize 99% of input characters
+    character_coverage = 0.99
+
+    min_log_level = LogLevel.INFO if verbose else LogLevel.WARNING
+
     def filename_to_sentence(filename: str) -> bytes:
-        return Path(filename).read_bytes()
+        return Path(filename).read_text(errors="ignore")

-    min_log_level = LogLevel.INFO if verbose else LogLevel.WARNING
     model = io.BytesIO()

     spm.SentencePieceTrainer.Train(
         sentence_iterator=map(filename_to_sentence, input_filenames),
         model_writer=model,
-        vocab_size=16000,
+        vocab_size=vocab_size,
         model_type="bpe",
+        max_sentence_length=max_sentence_size,
+        character_coverage=character_coverage,
         minloglevel=min_log_level,
         user_defined_symbols=["\n"], # include newline in vocabulary
         normalization_rule_name="identity", # don't replace unicode chars with equivalent ones
-        remove_extra_whitespaces=0,
-        allow_whitespace_only_pieces=1,
-        split_by_whitespace=0, # allow whitespace within tokens, e.g. ") {"
+        remove_extra_whitespaces=False,
+        allow_whitespace_only_pieces=True,
+        split_by_whitespace=False, # allow whitespace within tokens, e.g. ") {"
         byte_fallback=True,
     )

     return model


-def test_model(model, test_stringe):
+def test_model(model_filename, test_string):
     sp = spm.SentencePieceProcessor()

-    #with open(model_filename, "rb") as modelfile:
-    #    sp.LoadFromSerializedProto(modelfile.read())
+    with open(model_filename, "rb") as modelfile:
+        sp.LoadFromSerializedProto(modelfile.read())

-    sp.LoadFromSerializedProto(model.getvalue())
+    #sp.LoadFromSerializedProto(model.getvalue())

     print("\n== Encoding test ==\n")
     print(test_string)
     print("\nencodes to\n")
     print(sp.Encode(test_string, out_type=int))
     print(sp.Encode(test_string, out_type=str))

@@ -101,7 +111,7 @@ def main():
         model_file.write(model.getvalue())

     if test_string:
-        test_model(model, test_string)
+        test_model(model_output_path, test_string)


 if __name__ == "__main__":
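For context, here is a minimal sketch of how the updated functions could be driven end to end, mirroring what main() appears to do in the parts of the file not shown in this diff. The input file names, output path, and test string below are hypothetical placeholders, not values from this commit:

# Hypothetical driver: train on some source files, write the serialized
# model to disk, then run the encoding test against the saved model file.
model = train_model(["sample1.py", "sample2.py"], verbose=True)

model_output_path = "code_tokenizer.model"  # placeholder output path
with open(model_output_path, "wb") as model_file:
    model_file.write(model.getvalue())

# test_model now loads the model from the file on disk rather than taking
# the in-memory BytesIO object (the change in the second hunk above).
test_model(model_output_path, 'def main():\n    print("hi")\n')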
