diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 0c892384236bc..90eb052399bf0 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -1275,11 +1275,12 @@ def main(args: argparse.Namespace):
         '--tokenizer-mode',
         type=str,
         default="auto",
-        choices=['auto', 'slow', 'mistral'],
+        choices=['auto', 'slow', 'mistral', 'custom'],
         help='The tokenizer mode.\n\n* "auto" will use the '
         'fast tokenizer if available.\n* "slow" will '
         'always use the slow tokenizer. \n* '
-        '"mistral" will always use the `mistral_common` tokenizer.')
+        '"mistral" will always use the `mistral_common` tokenizer. \n*'
+        '"custom" will use --tokenizer to select the preregistered tokenizer.')
 
     parser.add_argument("--served-model-name",
                         type=str,
diff --git a/tests/tokenization/test_tokenizer_registry.py b/tests/tokenization/test_tokenizer_registry.py
new file mode 100644
index 0000000000000..793d38f9c3666
--- /dev/null
+++ b/tests/tokenization/test_tokenizer_registry.py
@@ -0,0 +1,123 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TestTokenizer(TokenizerBase):
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
+        return TestTokenizer()
+
+    @property
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    def bos_token_id(self) -> int:
+        return 0
+
+    @property
+    def eos_token_id(self) -> int:
+        return 1
+
+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+def test_customized_tokenizer():
+    TokenizerRegistry.register("test_tokenizer",
+                               "tests.tokenization.test_tokenizer_registry",
+                               "TestTokenizer")
+
+    tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
+
+    tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
+    assert isinstance(tokenizer, TestTokenizer)
+    assert tokenizer.bos_token_id == 0
+    assert tokenizer.eos_token_id == 1
diff --git a/vllm/config.py b/vllm/config.py
index 426ba38080270..cb246bd2744de 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -102,8 +102,9 @@ class ModelConfig:
             it; otherwise, you must specify explicitly which task to use.
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, "slow" will always use the slow tokenizer, and
-            "mistral" will always use the tokenizer from `mistral_common`.
+            available, "slow" will always use the slow tokenizer,
+            "mistral" will always use the tokenizer from `mistral_common`, and
+            "custom" will use --tokenizer to select the preregistered tokenizer.
         trust_remote_code: Trust remote code (e.g., from HuggingFace) when
             downloading the model and tokenizer.
         allowed_local_media_path: Allowing API requests to read local images or
@@ -467,10 +468,10 @@ def _init_has_inner_state(self) -> bool:
 
     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode
 
     def _get_preferred_task(
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 40c6fb4567993..9b6564398f119 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -281,11 +281,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             '--tokenizer-mode',
             type=str,
             default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral'],
+            choices=['auto', 'slow', 'mistral', 'custom'],
             help='The tokenizer mode.\n\n* "auto" will use the '
             'fast tokenizer if available.\n* "slow" will '
             'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer.')
+            '"mistral" will always use the `mistral_common` tokenizer. \n* '
+            '"custom" will use --tokenizer to select the '
+            'preregistered tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
                             help='Trust remote code from huggingface.')
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index d071a0b3cfc5d..73593f0c6f0a5 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1051,9 +1051,9 @@ def _embedding_score(
 
     def _cross_encoding_score(
         self,
-        tokenizer: Union[AnyTokenizer],
-        text_1: List[Union[str, TextPrompt, TokensPrompt]],
-        text_2: List[Union[str, TextPrompt, TokensPrompt]],
+        tokenizer: AnyTokenizer,
+        text_1: List[str],
+        text_2: List[str],
         truncate_prompt_tokens: Optional[int] = None,
         use_tqdm: bool = True,
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,29 +1176,36 @@ def ensure_str(prompt: SingletonPrompt):
         if isinstance(text_1, (str, dict)):
             # Convert a single prompt to a list.
             text_1 = [text_1]
-        text_1 = [ensure_str(t) for t in text_1]
+        input_text_1: List[str] = [ensure_str(t) for t in text_1]
 
         if isinstance(text_2, (str, dict)):
             # Convert a single prompt to a list.
             text_2 = [text_2]
-        text_2 = [ensure_str(t) for t in text_2]
+        input_text_2: List[str] = [ensure_str(t) for t in text_2]
 
-        if len(text_1) > 1 and len(text_1) != len(text_2):
+        if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
             raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
-        if len(text_1) == 0:
+        if len(input_text_1) == 0:
             raise ValueError("At least one text element must be given")
-        if len(text_2) == 0:
+        if len(input_text_2) == 0:
             raise ValueError("At least one text_pair element must be given")
 
         if self.llm_engine.model_config.is_cross_encoder:
-            return self._cross_encoding_score(tokenizer, text_1, text_2,
+            return self._cross_encoding_score(tokenizer, input_text_1,
+                                              input_text_2,
                                               truncate_prompt_tokens, use_tqdm,
                                               lora_request,
                                               prompt_adapter_request)
         else:
-            return self._embedding_score(tokenizer, text_1, text_2,
-                                         truncate_prompt_tokens, use_tqdm,
-                                         lora_request, prompt_adapter_request)
+
+            return self._embedding_score(
+                tokenizer,
+                input_text_1,  # type: ignore[arg-type]
+                input_text_2,  # type: ignore[arg-type]
+                truncate_prompt_tokens,
+                use_tqdm,
+                lora_request,
+                prompt_adapter_request)
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8d39fdcb74833..9efb5e6fa3987 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -400,8 +400,7 @@ async def _preprocess_chat(
         _chat_template_kwargs.update(chat_template_kwargs or {})
 
         request_prompt: Union[str, List[int]]
-        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
-        if is_mistral_tokenizer:
+        if isinstance(tokenizer, MistralTokenizer):
             request_prompt = apply_mistral_chat_template(
                 tokenizer,
                 messages=messages,
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 832aa8516cc35..c7597808f7fe3 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -121,7 +121,7 @@ async def create_score(
             tokenize_async = make_async(tokenizer.__call__,
                                         executor=self._tokenizer_executor)
 
-            prompt_inputs = await tokenize_async(text=q,
+            prompt_inputs = await tokenize_async(q,
                                                  text_pair=t,
                                                  **tokenization_kwargs)
 
diff --git a/vllm/logits_process.py b/vllm/logits_process.py
index d02072e8f8189..a810be7bc7a85 100644
--- a/vllm/logits_process.py
+++ b/vllm/logits_process.py
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(
 
         if isinstance(tokenizer, MistralTokenizer):
             # Mistral tokenizers should not add special tokens
-            prompt_token_ids = tokenizer.encode(prompt=prompt)
+            prompt_token_ids = tokenizer.encode(text=prompt)
         else:
             prompt_token_ids = tokenizer.encode(text=prompt,
                                                 add_special_tokens=False)
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index 520870b563c9e..0c0f68ac123e2 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -14,6 +14,8 @@
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
+                                                    TokenizerRegistry)
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
@@ -21,7 +23,7 @@
 logger = init_logger(__name__)
 
 AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
-                     MistralTokenizer]
+                     TokenizerBase]
 
 
 def decode_tokens(
@@ -47,11 +49,7 @@ def encode_tokens(
    Backend-agnostic equivalent of HF's :code:`tokenizer.encode(text,
    add_special_tokens=...)`.
    """
-    if isinstance(tokenizer, MistralTokenizer):
-        return tokenizer.tokenizer.encode(text,
-                                          bos=add_special_tokens,
-                                          eos=add_special_tokens)
-    elif add_special_tokens is not None:
+    if add_special_tokens is not None:
         return tokenizer.encode(text, add_special_tokens=add_special_tokens)
     return tokenizer.encode(text)
 
@@ -183,9 +181,17 @@
             'encoding and decoding.',
             FutureWarning,
             stacklevel=2)
+
+    tokenizer: AnyTokenizer
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(
diff --git a/vllm/transformers_utils/tokenizer_base.py b/vllm/transformers_utils/tokenizer_base.py
new file mode 100644
index 0000000000000..bb5ddaf88b219
--- /dev/null
+++ b/vllm/transformers_utils/tokenizer_base.py
@@ -0,0 +1,146 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+
+
+class TokenizerBase(ABC):
+
+    @property
+    @abstractmethod
+    def all_special_tokens_extended(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_tokens(self) -> List[str]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def all_special_ids(self) -> List[int]:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def bos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def eos_token_id(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def is_fast(self) -> bool:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def vocab_size(self) -> int:
+        raise NotImplementedError()
+
+    @property
+    @abstractmethod
+    def max_token_id(self) -> int:
+        raise NotImplementedError()
+
+    def __len__(self) -> int:
+        return self.vocab_size
+
+    @abstractmethod
+    def __call__(
+        self,
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_added_vocab(self) -> Dict[str, int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode_one(
+        self,
+        text: str,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+    ) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply_chat_template(self,
+                            messages: List["ChatCompletionMessageParam"],
+                            tools: Optional[List[Dict[str, Any]]] = None,
+                            **kwargs) -> List[int]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_tokens_to_string(self, tokens: List[str]) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def decode(self,
+               ids: Union[List[int], int],
+               skip_special_tokens: bool = True) -> str:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def convert_ids_to_tokens(
+        self,
+        ids: List[int],
+        skip_special_tokens: bool = True,
+    ) -> List[str]:
+        raise NotImplementedError()
+
+
+class TokenizerRegistry:
+    # Tokenizer name -> (tokenizer module, tokenizer class)
+    REGISTRY: Dict[str, Tuple[str, str]] = {}
+
+    @staticmethod
+    def register(name: str, module: str, class_name: str) -> None:
+        TokenizerRegistry.REGISTRY[name] = (module, class_name)
+
+    @staticmethod
+    def get_tokenizer(
+        tokenizer_name: str,
+        *args,
+        **kwargs,
+    ) -> TokenizerBase:
+        tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
+        if tokenizer_cls is None:
+            raise ValueError(f"Tokenizer {tokenizer_name} not found.")
+
+        tokenizer_module = importlib.import_module(tokenizer_cls[0])
+        class_ = getattr(tokenizer_module, tokenizer_cls[1])
+        return class_.from_pretrained(*args, **kwargs)
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index f08923e7401f3..59131a9d7bfdd 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -10,6 +10,7 @@
 from huggingface_hub import HfApi, hf_hub_download
 
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer_base import TokenizerBase
 from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
@@ -140,7 +141,7 @@ def make_mistral_chat_completion_request(
         tools=tools)  # type: ignore[type-var]
 
 
-class MistralTokenizer:
+class MistralTokenizer(TokenizerBase):
 
     def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
         self.mistral = tokenizer
@@ -251,6 +252,14 @@ def bos_token_id(self) -> int:
     def eos_token_id(self) -> int:
         return self.tokenizer.eos_id
 
+    @property
+    def sep_token(self) -> str:
+        raise NotImplementedError()
+
+    @property
+    def pad_token(self) -> str:
+        raise NotImplementedError()
+
     @property
     def is_fast(self) -> bool:
         return True
@@ -268,25 +277,26 @@ def __len__(self) -> int:
 
     def __call__(
         self,
-        prompt: Union[str, List[str], List[int]],
+        text: Union[str, List[str], List[int]],
+        text_pair: Optional[str] = None,
         add_special_tokens: bool = False,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ):
         input_ids: Union[List[int], List[List[int]]]
         # For List[str], original prompt text
-        if is_list_of(prompt, str):
+        if is_list_of(text, str):
             input_ids_: List[List[int]] = []
-            for p in prompt:
+            for p in text:
                 each_input_ids = self.encode_one(p, truncation, max_length)
                 input_ids_.append(each_input_ids)
             input_ids = input_ids_
         # For List[int], apply chat template output, already tokens.
-        elif is_list_of(prompt, int):
-            input_ids = prompt
+        elif is_list_of(text, int):
+            input_ids = text
         # For str, single prompt text
         else:
-            input_ids = self.encode_one(prompt, truncation, max_length)
+            input_ids = self.encode_one(text, truncation, max_length)
         return Encoding(input_ids=input_ids)
 
     def get_vocab(self) -> Dict[str, int]:
@@ -300,22 +310,29 @@ def get_added_vocab(self) -> Dict[str, int]:
 
     def encode_one(
         self,
-        prompt: str,
+        text: str,
         truncation: bool = False,
         max_length: Optional[int] = None,
     ) -> List[int]:
         # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(prompt)
+        input_ids = self.encode(text)
 
         if truncation:
             input_ids = input_ids[:max_length]
         return input_ids
 
-    def encode(self, prompt: str) -> List[int]:
+    def encode(self,
+               text: str,
+               add_special_tokens: Optional[bool] = None) -> List[int]:
         # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
-        return self.tokenizer.encode(prompt, bos=True, eos=False)
+        if add_special_tokens is not None:
+            return self.tokenizer.encode(text,
+                                         bos=add_special_tokens,
+                                         eos=add_special_tokens)
+        else:
+            return self.tokenizer.encode(text, bos=True, eos=False)
 
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],