Commit

Merge remote-tracking branch 'upstream/main'
arunvariyath committed Mar 28, 2024
2 parents 9a07022 + a68aebc commit c8dd25f
Showing 9 changed files with 464 additions and 345 deletions.
160 changes: 0 additions & 160 deletions .gitignore

This file was deleted.

2 changes: 1 addition & 1 deletion LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]
Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
1 change: 1 addition & 0 deletions unsloth/__init__.py
@@ -113,3 +113,4 @@
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
114 changes: 39 additions & 75 deletions unsloth/chat_templates.py
@@ -15,15 +15,19 @@
__all__ = [
"get_chat_template",
"test_chat_templates",
"fix_sentencepiece_tokenizer",
]

from transformers import StoppingCriteria, StoppingCriteriaList
from torch import LongTensor, FloatTensor
from transformers.models.llama.modeling_llama import logger
from .models._utils import patch_tokenizer
from .save import patch_saving_functions
import os
import shutil
from .tokenizer_utils import (
load_correct_tokenizer,
fix_sentencepiece_tokenizer,
)
from .models._utils import patch_tokenizer

CHAT_TEMPLATES = {}

@@ -251,84 +255,23 @@
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token,)


def fix_sentencepiece_tokenizer(
old_tokenizer,
new_tokenizer,
token_mapping,
temporary_location = "_unsloth_sentencepiece_temp",
):
# From https://github.com/google/sentencepiece/issues/121
# We need to manually edit the sentencepiece tokenizer!
try:
import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
except ImportError:
if not os.path.exists(temporary_location):
os.system(f"git clone https://github.com/google/sentencepiece.git {temporary_location}")
os.system(f"cd {temporary_location}/src && protoc --python_out=. sentencepiece_model.proto")
shutil.rmtree(temporary_location)
pass
import sentencepiece.sentencepiece_model_pb2 as sentencepiece_model_pb2
pass

if not os.path.exists(temporary_location):
os.makedirs(temporary_location)
pass

# First save the old tokenizer
old_tokenizer.save_pretrained(temporary_location)

from sentencepiece import SentencePieceProcessor
tokenizer_file = sentencepiece_model_pb2.ModelProto()
tokenizer_file.ParseFromString(open(f"{temporary_location}/tokenizer.model", "rb").read())

# Now save the new tokenizer
new_tokenizer.save_pretrained(temporary_location)

# Now correct the old tokenizer's .model file
for old_token, new_token in token_mapping.items():
ids = old_tokenizer([old_token], add_special_tokens = False).input_ids
ids = ids[0]
if (len(ids) != 1):
# Skip this token!
print(f"Skip mapping {old_token} to {new_token} since {new_token} is already in the tokenizer!")
continue
pass
ids = ids[0]
tokenizer_piece = tokenizer_file.pieces[ids]
assert(tokenizer_piece.piece == old_token)
tokenizer_piece.piece = new_token
pass

# And now write it
with open(f"{temporary_location}/tokenizer.model", "wb") as file:
file.write(tokenizer_file.SerializeToString())
pass

# And load it!
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(temporary_location, eos_token = new_tokenizer.eos_token)
return tokenizer
pass
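
[Annotation, not part of this commit] The function above edits token pieces directly in the serialized SentencePiece ModelProto, so a SentencePiece-based slow tokenizer's tokenizer.model stays in sync after tokens are renamed. Below is a minimal usage sketch, not a definitive recipe: the model name is purely illustrative, and while get_chat_template passes the already-edited fast tokenizer as the second argument, this sketch passes the same tokenizer twice just to exercise the renaming path.

from transformers import AutoTokenizer
from unsloth.chat_templates import fix_sentencepiece_tokenizer

# Assumption: any SentencePiece-based model with a slow tokenizer works;
# "google/gemma-7b" is illustrative and requires access to the gated repo.
old_tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", use_fast = False)
new_tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b", use_fast = False)

# Rename Gemma's turn markers to their ChatML equivalents (the same mapping
# used by the Gemma branch of get_chat_template below).
token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
tokenizer = fix_sentencepiece_tokenizer(old_tokenizer, new_tokenizer, token_mapping)
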


def get_chat_template(
tokenizer,
chat_template = "chatml",
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
map_eos_token = True,
):
assert(type(map_eos_token) is bool)
old_tokenizer = tokenizer

if map_eos_token is False:
assert("Unsloth: Can only map new tokens to EOS for now. Adding new tokens is not yet supported.")
pass

IS_GEMMA = False
if tokenizer.__class__.__name__.startswith("Gemma"):
if chat_template == "chatml": chat_template = "gemma_chatml"
IS_GEMMA = True
pass

# We first check if the tokenizer is a fast one. If not, we cannot convert this!
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
old_padding_side = tokenizer.padding_side

if type(chat_template) in (list, tuple,):
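
[Annotation, not part of this commit] A typical call into get_chat_template, based on the signature above. This is a hedged sketch: the model name is an assumption, and as the type check above shows, chat_template may also be passed as a (template, eos_token) tuple instead of a registered name.

from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

# Assumption: "unsloth/llama-2-7b" is illustrative; any fast tokenizer should do.
tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b")
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",  # a key of CHAT_TEMPLATES, e.g. "chatml" or "gemma_chatml"
    mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
    map_eos_token = True,      # remap the template's stop word onto the EOS token
)

messages = [{"role" : "user", "content" : "Hello!"}]
text = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
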
@@ -348,9 +291,17 @@ def get_chat_template(

assert(type(stop_word) is str)

# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)
if token_mapping is not None:
# Check fast tokenizer
if not is_fast_tokenizer:
logger.warning_once(
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
"Please log a Github issue if you want this as a new feature!\n"\
"Your chat template will still work, but it won't add or edit tokens."
)

elif token_mapping is not None:
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
# For Gemma :)

string_vocab = tokenizer._tokenizer.to_str()

@@ -368,22 +319,27 @@ def get_chat_template(
pass
pass

if not stop_word in token_mapping.values():
if map_eos_token and (not stop_word in token_mapping.values()):
# Do not map both 107 (<end_of_turn>) and 1 (EOS) to <|im_end|>, since that would reduce the vocab size by 1
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
pass

if skipped != len(token_mapping):
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)

if map_eos_token:
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
else:
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer)
pass

# Must also fix the SentencePiece tokenizer, since the rebuilt fast tokenizer has no tokenizer.model file!
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
else:
pass

elif stop_word != "eos_token":
elif map_eos_token and (stop_word != "eos_token"):
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")

# Replaces the old EOS token with a new one.
@@ -393,9 +349,14 @@ def get_chat_template(
# This is a HACK!
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
string_vocab = tokenizer._tokenizer.to_str()
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
old_eos_token = tokenizer.eos_token
string_vocab = string_vocab.replace(old_eos_token, stop_word)
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)
new_tokenizer = tokenizer.__class__(tokenizer_object = new_tokenizer, eos_token = stop_word)

# Must also fix the SentencePiece tokenizer, since the rebuilt fast tokenizer has no tokenizer.model file!
token_mapping = { old_eos_token : stop_word, }
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
pass

else:
@@ -433,7 +394,10 @@ def get_chat_template(
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token

#stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)

# Patch saving functions
tokenizer = patch_saving_functions(tokenizer)

return tokenizer#, stopping_criteria
pass
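
[Annotation, not part of this commit] The EOS remap above works by serializing the fast tokenizer's backend to its JSON string, doing a plain substring replacement, and rebuilding the backend from the edited string. A minimal standalone sketch of that round trip, with an illustrative model name:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b")  # assumption: illustrative model

stop_word = "<|im_end|>"
old_eos_token = tokenizer.eos_token  # e.g. "</s>"

# Serialize the backing `tokenizers` object, swap the token string, and
# rebuild: the same to_str()/from_str() round trip used in get_chat_template.
string_vocab = tokenizer._tokenizer.to_str()
string_vocab = string_vocab.replace(old_eos_token, stop_word)
new_backend = tokenizer._tokenizer.from_str(string_vocab)
tokenizer = tokenizer.__class__(tokenizer_object = new_backend, eos_token = stop_word)

assert tokenizer.eos_token == "<|im_end|>"

As the in-code comment says, this is a hack: the plain replace can touch any occurrence of the old EOS string in the serialized JSON, which is why the code above pins eos_token explicitly and then re-fixes the SentencePiece model file.
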