SentencePiece tokenizer support for TokenizerInfo
zanderjiang committed Dec 5, 2024
1 parent 678f4c8 commit b1d60f9
Showing 2 changed files with 92 additions and 5 deletions.
95 changes: 90 additions & 5 deletions python/xgrammar/tokenizer_info.py
@@ -4,6 +4,8 @@
from typing import List, Optional, Union

from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast
import tiktoken
import sentencepiece

from .base import XGRObject, _core
from .support import logging
@@ -39,6 +41,35 @@ class VocabType(Enum):
    BYTE_FALLBACK = "BYTE_FALLBACK"
    BYTE_LEVEL = "BYTE_LEVEL"

def is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
    # helper to check if tokenizer is a tiktoken tokenizer
    has_tiktoken_encoding = (
        hasattr(tokenizer, 'tokenizer') and
        isinstance(tokenizer.tokenizer, tiktoken.Encoding)
    )

    filename_pattern = (
        "vocab_file" in tokenizer.vocab_files_names and
        "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
    )

    return has_tiktoken_encoding or filename_pattern

def is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool:
    # helper to check if tokenizer is a SentencePiece tokenizer
    has_sp_model_attr = (
        hasattr(tokenizer, 'sp_model') and
        isinstance(tokenizer.sp_model, sentencepiece.SentencePieceProcessor)
    )

    has_nested_sp_model_attr = (
        hasattr(tokenizer, 'tokenizer') and
        hasattr(tokenizer.tokenizer, 'sp_model') and
        isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor)
    )

    return has_sp_model_attr or has_nested_sp_model_attr


class TokenizerInfo(XGRObject):
"""The tokenizer info contains the vocabulary, the type of the vocabulary, and necessary
@@ -174,10 +205,7 @@ def from_huggingface(
                    encoded_vocab, backend_str, vocab_size, stop_token_ids
                )
            )
        elif (
            "vocab_file" in tokenizer.vocab_files_names
            and "tiktoken" in tokenizer.vocab_files_names["vocab_file"]
        ):
        elif is_tiktoken_tokenizer(tokenizer):
            # tiktoken tokenizer
            # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously)
            if stop_token_ids is None:
@@ -196,8 +224,65 @@
                stop_token_ids=stop_token_ids,
                prepend_space_in_tokenization=False,
            )
        elif is_sentencepiece_tokenizer(tokenizer):
            # sentencepiece tokenizer
            # e.g. Chatglm3-6b
            if hasattr(tokenizer, 'sp_model'):
                sp_model = tokenizer.sp_model
            elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'):
                sp_model = tokenizer.tokenizer.sp_model

            if stop_token_ids is None:
                if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None:
                    stop_token_ids = [tokenizer.eos_token_id]
                else:
                    eos_id = sp_model.eos_id()
                    if eos_id != -1:
                        stop_token_ids = [eos_id]
                    else:
                        logger.warning(
                            "When constructing TokenizerInfo from a huggingface tokenizer, "
                            "stop_token_ids is neither provided by user nor found from the tokenizer. "
                            "It will be automatically detected."
                        )

            if vocab_size is None:
                vocab_size = len(tokenizer.get_vocab())

            encoded_vocab = [''] * vocab_size

            # fill in base vocabulary from sp_model
            for i in range(sp_model.get_piece_size()):
                piece = sp_model.id_to_piece(i)
                if i < vocab_size:
                    encoded_vocab[i] = piece

            vocab_dict = tokenizer.get_vocab()

            # fill in any special tokens from vocab_dict
            for token, idx in vocab_dict.items():
                if idx < vocab_size:
                    encoded_vocab[idx] = token

            # detect vocab_type of tokenizer
            byte_level_tokens = ["Ġ", "Ċ", "ĉ", "Ā"]
            if any("<0x" in token for token in vocab_dict):
                vocab_type = VocabType.BYTE_FALLBACK
            elif any(stoken in token for stoken in byte_level_tokens
                     for token in vocab_dict):
                vocab_type = VocabType.BYTE_LEVEL
            else:
                vocab_type = VocabType.RAW

            return TokenizerInfo(
                encoded_vocab,
                vocab_type=vocab_type,
                vocab_size=vocab_size,
                stop_token_ids=stop_token_ids,
                prepend_space_in_tokenization=True,
            )
        else:
            # TODO(yixin): sentencepiece tokenizer
            # TODO(yixin): unsupported tokenizer
            raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}")

    @property
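For context, the following is a minimal usage sketch of the new SentencePiece path; it is not part of the commit. It assumes network access, trust_remote_code=True for loading the ChatGLM3 tokenizer (the model id is taken from the test file below), and that vocab_size and stop_token_ids are exposed as properties on the returned TokenizerInfo (the property block is truncated in this view).

from transformers import AutoTokenizer

import xgrammar as xgr

# ChatGLM3 ships a slow, SentencePiece-based tokenizer (no fast variant), so
# from_huggingface should skip the PreTrainedTokenizerFast branch and reach the
# new elif is_sentencepiece_tokenizer(tokenizer) branch added above.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)

tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
print(tokenizer_info.vocab_size)      # padded vocabulary size derived from tokenizer.get_vocab()
print(tokenizer_info.stop_token_ids)  # eos id, taken from the tokenizer or sp_model.eos_id()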
2 changes: 2 additions & 0 deletions tests/python/test_tokenizer_info.py
@@ -41,6 +41,8 @@ def tokenizer_info_storage() -> Dict[str, Tuple[PreTrainedTokenizerBase, xgr.TokenizerInfo]]:
("Qwen/Qwen2.5-1.5B", xgr.VocabType.BYTE_LEVEL, False),
("internlm/internlm2_5-7b-chat", xgr.VocabType.BYTE_FALLBACK, False),
("mistralai/Mixtral-8x22B-Instruct-v0.1", xgr.VocabType.BYTE_FALLBACK, True),
("THUDM/LongWriter-glm4-9b", xgr.VocabType.RAW, False),
("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True),
]

tokenizer_paths = [path for path, *_ in tokenizer_paths_metadata]
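The two new entries extend the parametrized tokenizer list used by the tests. The sketch below is a hypothetical, self-contained version of the check they imply, not the repository's actual fixture-based test. It assumes the third tuple element is prepend_space_in_tokenization, that both models load via AutoTokenizer with trust_remote_code=True, and that vocab_type and prepend_space_in_tokenization are properties of TokenizerInfo.

import pytest
from transformers import AutoTokenizer

import xgrammar as xgr


@pytest.mark.parametrize(
    "path, expected_vocab_type, expected_prepend_space",
    [
        ("THUDM/LongWriter-glm4-9b", xgr.VocabType.RAW, False),
        ("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True),
    ],
)
def test_new_tokenizers(path, expected_vocab_type, expected_prepend_space):
    # Build TokenizerInfo from the Hugging Face tokenizer and compare against
    # the expected metadata for the newly added models.
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    info = xgr.TokenizerInfo.from_huggingface(tokenizer)
    assert info.vocab_type == expected_vocab_type
    assert info.prepend_space_in_tokenization == expected_prepend_space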
