[RFC][vllm-API] Support tokenizer registry for customized tokenizer in vLLM #12518

Open
wants to merge 9 commits into base: main
5 changes: 3 additions & 2 deletions benchmarks/benchmark_serving.py
@@ -1224,11 +1224,12 @@ def main(args: argparse.Namespace):
'--tokenizer-mode',
type=str,
default="auto",
choices=['auto', 'slow', 'mistral'],
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
'"mistral" will always use the `mistral_common` tokenizer. \n*'
'"custom" will use --tokenizer to select the preregistered tokenizer.')

parser.add_argument("--served-model-name",
type=str,
123 changes: 123 additions & 0 deletions tests/tokenization/test_tokenizer_registry.py
@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
TokenizerRegistry)

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


class TestTokenizer(TokenizerBase):

@classmethod
def from_pretrained(cls, *args, **kwargs) -> "TestTokenizer":
return TestTokenizer()

@property
def all_special_tokens_extended(self) -> List[str]:
raise NotImplementedError()

@property
def all_special_tokens(self) -> List[str]:
raise NotImplementedError()

@property
def all_special_ids(self) -> List[int]:
raise NotImplementedError()

@property
def bos_token_id(self) -> int:
return 0

@property
def eos_token_id(self) -> int:
return 1

@property
def sep_token(self) -> str:
raise NotImplementedError()

@property
def pad_token(self) -> str:
raise NotImplementedError()

@property
def is_fast(self) -> bool:
raise NotImplementedError()

@property
def vocab_size(self) -> int:
raise NotImplementedError()

@property
def max_token_id(self) -> int:
raise NotImplementedError()

def __call__(
self,
text: Union[str, List[str], List[int]],
text_pair: Optional[str] = None,
add_special_tokens: bool = False,
truncation: bool = False,
max_length: Optional[int] = None,
):
raise NotImplementedError()

def get_vocab(self) -> Dict[str, int]:
raise NotImplementedError()

def get_added_vocab(self) -> Dict[str, int]:
raise NotImplementedError()

def encode_one(
self,
text: str,
truncation: bool = False,
max_length: Optional[int] = None,
) -> List[int]:
raise NotImplementedError()

def encode(self,
text: str,
add_special_tokens: Optional[bool] = None) -> List[int]:
raise NotImplementedError()

def apply_chat_template(self,
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
raise NotImplementedError()

def convert_tokens_to_string(self, tokens: List[str]) -> str:
raise NotImplementedError()

def decode(self,
ids: Union[List[int], int],
skip_special_tokens: bool = True) -> str:
raise NotImplementedError()

def convert_ids_to_tokens(
self,
ids: List[int],
skip_special_tokens: bool = True,
) -> List[str]:
raise NotImplementedError()


def test_customized_tokenizer():
TokenizerRegistry.register("test_tokenizer",
"tests.tokenization.test_tokenizer_registry",
"TestTokenizer")

tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1

tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
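
For context, here is a minimal usage sketch (not part of this diff) of how a downstream project might plug in its own tokenizer once this lands. MyTokenizer, my_project.tokenization, and the model name are hypothetical placeholders; the flow simply mirrors the register/custom-mode path exercised by the test above.

from vllm import LLM
from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

# Register under a name; the module and class are given as import strings.
TokenizerRegistry.register("my_tokenizer",
                           "my_project.tokenization",  # hypothetical module
                           "MyTokenizer")              # a TokenizerBase subclass

# Select the registered tokenizer by name via the new "custom" mode.
llm = LLM(model="my-org/my-model",
          tokenizer="my_tokenizer",
          tokenizer_mode="custom")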
9 changes: 5 additions & 4 deletions vllm/config.py
@@ -96,8 +96,9 @@ class ModelConfig:
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
"mistral" will always use the tokenizer from `mistral_common`.
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
@@ -453,10 +454,10 @@ def _init_has_inner_state(self) -> bool:

def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral"]:
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow' or 'mistral'.")
"either 'auto', 'slow', 'mistral' or 'custom'.")
self.tokenizer_mode = tokenizer_mode

def _get_preferred_task(
6 changes: 4 additions & 2 deletions vllm/engine/arg_utils.py
@@ -280,11 +280,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral'],
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
'"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
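
For illustration, a minimal sketch (not part of this diff) of the programmatic equivalent of passing --tokenizer-mode custom --tokenizer <name> on the command line. The model and tokenizer names are placeholders, and the named tokenizer must already have been registered with TokenizerRegistry in the same process.

from vllm.engine.arg_utils import EngineArgs

# Equivalent to the CLI flags: --tokenizer my_tokenizer --tokenizer-mode custom
engine_args = EngineArgs(model="my-org/my-model",
                         tokenizer="my_tokenizer",
                         tokenizer_mode="custom")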
21 changes: 11 additions & 10 deletions vllm/entrypoints/llm.py
@@ -1051,9 +1051,9 @@ def _embedding_score(

def _cross_encoding_score(
self,
tokenizer: Union[AnyTokenizer],
text_1: List[Union[str, TextPrompt, TokensPrompt]],
text_2: List[Union[str, TextPrompt, TokensPrompt]],
tokenizer: AnyTokenizer,
text_1: List[str],
text_2: List[str],
truncate_prompt_tokens: Optional[int] = None,
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
@@ -1176,27 +1176,28 @@ def ensure_str(prompt: SingletonPrompt):
if isinstance(text_1, (str, dict)):
# Convert a single prompt to a list.
text_1 = [text_1]
text_1 = [ensure_str(t) for t in text_1]
input_text_1: List[str] = [ensure_str(t) for t in text_1]

if isinstance(text_2, (str, dict)):
# Convert a single prompt to a list.
text_2 = [text_2]
text_2 = [ensure_str(t) for t in text_2]
input_text_2: List[str] = [ensure_str(t) for t in text_2]

if len(text_1) > 1 and len(text_1) != len(text_2):
if len(input_text_1) > 1 and len(input_text_1) != len(input_text_2):
raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
if len(text_1) == 0:
if len(input_text_1) == 0:
raise ValueError("At least one text element must be given")
if len(text_2) == 0:
if len(input_text_2) == 0:
raise ValueError("At least one text_pair element must be given")

if self.llm_engine.model_config.is_cross_encoder:
return self._cross_encoding_score(tokenizer, text_1, text_2,
return self._cross_encoding_score(tokenizer, input_text_1,
input_text_2,
truncate_prompt_tokens, use_tqdm,
lora_request,
prompt_adapter_request)
else:
return self._embedding_score(tokenizer, text_1, text_2,
return self._embedding_score(tokenizer, input_text_1, input_text_2,
truncate_prompt_tokens, use_tqdm,
lora_request, prompt_adapter_request)

3 changes: 1 addition & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -400,8 +400,7 @@ async def _preprocess_chat(
_chat_template_kwargs.update(chat_template_kwargs or {})

request_prompt: Union[str, List[int]]
is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)
if is_mistral_tokenizer:
if isinstance(tokenizer, MistralTokenizer):
request_prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_score.py
@@ -121,7 +121,7 @@ async def create_score(

tokenize_async = make_async(tokenizer.__call__,
executor=self._tokenizer_executor)
prompt_inputs = await tokenize_async(text=q,
prompt_inputs = await tokenize_async(q,
text_pair=t,
**tokenization_kwargs)

2 changes: 1 addition & 1 deletion vllm/logits_process.py
@@ -31,7 +31,7 @@ def get_bad_words_logits_processors(

if isinstance(tokenizer, MistralTokenizer):
# Mistral tokenizers should not add special tokens
prompt_token_ids = tokenizer.encode(prompt=prompt)
prompt_token_ids = tokenizer.encode(text=prompt)
else:
prompt_token_ids = tokenizer.encode(text=prompt,
add_special_tokens=False)
18 changes: 12 additions & 6 deletions vllm/transformers_utils/tokenizer.py
@@ -14,14 +14,16 @@
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_base import (TokenizerBase,
TokenizerRegistry)
from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import make_async

logger = init_logger(__name__)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast,
MistralTokenizer]
TokenizerBase]


def decode_tokens(
@@ -47,11 +49,7 @@ def encode_tokens(
Backend-agnostic equivalent of HF's
:code:`tokenizer.encode(text, add_special_tokens=...)`.
"""
if isinstance(tokenizer, MistralTokenizer):
return tokenizer.tokenizer.encode(text,
bos=add_special_tokens,
eos=add_special_tokens)
elif add_special_tokens is not None:
if add_special_tokens is not None:
return tokenizer.encode(text, add_special_tokens=add_special_tokens)
return tokenizer.encode(text)

@@ -183,9 +181,17 @@ def get_tokenizer(
'encoding and decoding.',
FutureWarning,
stacklevel=2)

tokenizer: AnyTokenizer
if tokenizer_mode == "mistral":
tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
revision=revision)
elif tokenizer_mode == "custom":
tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
*args,
revision=revision,
download_dir=download_dir,
**kwargs)
else:
try:
tokenizer = AutoTokenizer.from_pretrained(
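
The new vllm/transformers_utils/tokenizer_base.py, which defines TokenizerBase and TokenizerRegistry, is not shown in this excerpt. Below is a minimal sketch of how the registry plausibly behaves, inferred from how it is called in the test and in get_tokenizer above; the dict layout and importlib mechanics are assumptions, not necessarily the PR's actual implementation.

from importlib import import_module
from typing import Any, Dict, Tuple


class TokenizerRegistry:
    # Maps a registered name to the (module path, class name) implementing it.
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        TokenizerRegistry.REGISTRY[name] = (module, class_name)

    @staticmethod
    def get_tokenizer(tokenizer_name: str, *args: Any, **kwargs: Any):
        if tokenizer_name not in TokenizerRegistry.REGISTRY:
            raise ValueError(f"Tokenizer {tokenizer_name} is not registered.")
        module_name, class_name = TokenizerRegistry.REGISTRY[tokenizer_name]
        # Import lazily so registration only needs strings, not the class itself.
        tokenizer_cls = getattr(import_module(module_name), class_name)
        return tokenizer_cls.from_pretrained(*args, **kwargs)

With a registry of this shape, the tokenizer_mode == "custom" branch above simply forwards the --tokenizer name (plus revision and download_dir) to the registered class's from_pretrained.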