diff --git a/python/xgrammar/tokenizer_info.py b/python/xgrammar/tokenizer_info.py index 8ec1e8d..854fa7f 100644 --- a/python/xgrammar/tokenizer_info.py +++ b/python/xgrammar/tokenizer_info.py @@ -4,6 +4,8 @@ from typing import List, Optional, Union from transformers import PreTrainedTokenizerBase, PreTrainedTokenizerFast +import tiktoken +import sentencepiece from .base import XGRObject, _core from .support import logging @@ -39,6 +41,35 @@ class VocabType(Enum): BYTE_FALLBACK = "BYTE_FALLBACK" BYTE_LEVEL = "BYTE_LEVEL" +def is_tiktoken_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool: + # helper to check if tokenizer is a tiktoken tokenizer + has_tiktoken_encoding = ( + hasattr(tokenizer, 'tokenizer') and + isinstance(tokenizer.tokenizer, tiktoken.Encoding) + ) + + filename_pattern = ( + "vocab_file" in tokenizer.vocab_files_names and + "tiktoken" in tokenizer.vocab_files_names["vocab_file"] + ) + + return has_tiktoken_encoding or filename_pattern + +def is_sentencepiece_tokenizer(tokenizer: PreTrainedTokenizerBase) -> bool: + # helper to check if tokenizer is a sentence piece tokenizer + has_sp_model_attr = ( + hasattr(tokenizer, 'sp_model') and + isinstance(tokenizer.sp_model, sentencepiece.SentencePieceProcessor) + ) + + has_nested_sp_model_attr = ( + hasattr(tokenizer, 'tokenizer') and + hasattr(tokenizer.tokenizer, 'sp_model') and + isinstance(tokenizer.tokenizer.sp_model, sentencepiece.SentencePieceProcessor) + ) + + return has_sp_model_attr or has_nested_sp_model_attr + class TokenizerInfo(XGRObject): """The tokenizer info contains the vocabulary, the type of the vocabulary, and necessary @@ -174,10 +205,7 @@ def from_huggingface( encoded_vocab, backend_str, vocab_size, stop_token_ids ) ) - elif ( - "vocab_file" in tokenizer.vocab_files_names - and "tiktoken" in tokenizer.vocab_files_names["vocab_file"] - ): + elif is_tiktoken_tokenizer(tokenizer): # tiktoken tokenizer # e.g. Phi-3-small-8k-instruct, Qwen-7B-Chat, stablelm-2-12b-chat (previously) if stop_token_ids is None: @@ -196,8 +224,65 @@ def from_huggingface( stop_token_ids=stop_token_ids, prepend_space_in_tokenization=False, ) + elif is_sentencepiece_tokenizer(tokenizer): + # sentencepiece tokenizer + # e.g. Chatglm3-6b + if hasattr(tokenizer, 'sp_model'): + sp_model = tokenizer.sp_model + elif hasattr(tokenizer, 'tokenizer') and hasattr(tokenizer.tokenizer, 'sp_model'): + sp_model = tokenizer.tokenizer.sp_model + + if stop_token_ids is None: + if hasattr(tokenizer, "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + else: + eos_id = sp_model.eos_id() + if eos_id != -1: + stop_token_ids = [eos_id] + else: + logger.warning( + "When constructing TokenizerInfo from a huggingface tokenizer, " + "stop_token_ids is neither provided by user nor found from the tokenizer. " + "It will be automatically detected." + ) + + if vocab_size is None: + vocab_size = len(tokenizer.get_vocab()) + + encoded_vocab = [''] * vocab_size + + # fill in base vocabulary from sp_model + for i in range(sp_model.get_piece_size()): + piece = sp_model.id_to_piece(i) + if i < vocab_size: + encoded_vocab[i] = piece + + vocab_dict = tokenizer.get_vocab() + + # fill in any special tokens from vocab_dict + for token, idx in vocab_dict.items(): + if idx < vocab_size: + encoded_vocab[idx] = token + + # detect vocab_type of tokenizer + byte_level_tokens = ["Ġ", "Ċ", "ĉ", "Ā"] + if any("<0x" in token for token in vocab_dict): + vocab_type = VocabType.BYTE_FALLBACK + elif any(stoken in token for stoken in byte_level_tokens + for token in vocab_dict): + vocab_type = VocabType.BYTE_LEVEL + else: + vocab_type = VocabType.RAW + + return TokenizerInfo( + encoded_vocab, + vocab_type=vocab_type, + vocab_size=vocab_size, + stop_token_ids=stop_token_ids, + prepend_space_in_tokenization=True, + ) else: - # TODO(yixin): sentencepiece tokenizer + # TODO(yixin): unsupported tokenizer raise ValueError(f"Unsupported tokenizer type: {type(tokenizer)}") @property diff --git a/tests/python/test_tokenizer_info.py b/tests/python/test_tokenizer_info.py index 142d7d6..9927f8e 100644 --- a/tests/python/test_tokenizer_info.py +++ b/tests/python/test_tokenizer_info.py @@ -41,6 +41,8 @@ def tokenizer_info_storage() -> Dict[str, Tuple[PreTrainedTokenizerBase, xgr.Tok ("Qwen/Qwen2.5-1.5B", xgr.VocabType.BYTE_LEVEL, False), ("internlm/internlm2_5-7b-chat", xgr.VocabType.BYTE_FALLBACK, False), ("mistralai/Mixtral-8x22B-Instruct-v0.1", xgr.VocabType.BYTE_FALLBACK, True), + ("THUDM/LongWriter-glm4-9b", xgr.VocabType.RAW, False), + ("THUDM/chatglm3-6b", xgr.VocabType.BYTE_FALLBACK, True), ] tokenizer_paths = [path for path, *_ in tokenizer_paths_metadata]