Commit: Support tokenizer registry

Signed-off-by: Keyun Tong <[email protected]>
youngkent committed Jan 28, 2025
1 parent 6116ca8 commit d7e4688
Showing 4 changed files with 142 additions and 3 deletions.
4 changes: 2 additions & 2 deletions vllm/config.py
@@ -446,10 +446,10 @@ def _init_has_inner_state(self) -> bool:

     def _verify_tokenizer_mode(self) -> None:
         tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral"]:
+        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow' or 'mistral'.")
+                "either 'auto', 'slow', 'mistral' or 'custom'.")
         self.tokenizer_mode = tokenizer_mode

     def _get_preferred_task(
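With this change, engine construction accepts the new mode. A minimal usage sketch, assuming a tokenizer named "my-tokenizer" was registered beforehand via the TokenizerRegistry introduced below (the tokenizer and model names here are illustrative, not part of this commit):

    from vllm import LLM

    # "custom" now passes _verify_tokenizer_mode; the tokenizer name is
    # resolved through TokenizerRegistry instead of the Hugging Face Hub.
    llm = LLM(model="facebook/opt-125m",
              tokenizer="my-tokenizer",
              tokenizer_mode="custom")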
7 changes: 7 additions & 0 deletions vllm/transformers_utils/tokenizer.py
@@ -15,6 +15,7 @@
 from vllm.transformers_utils.tokenizers import MistralTokenizer
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.utils import make_async
+from vllm.transformers_utils.tokenizer_base import TokenizerBase, TokenizerRegistry

Check failure on line 18 (GitHub Actions / pre-commit), Ruff (F401):
vllm/transformers_utils/tokenizer.py:18:52: `vllm.transformers_utils.tokenizer_base.TokenizerBase` imported but unused

Check failure on line 18 (GitHub Actions / pre-commit), Ruff (E501):
vllm/transformers_utils/tokenizer.py:18:81: Line too long (83 > 80)

logger = init_logger(__name__)

@@ -184,6 +185,12 @@ def get_tokenizer(
     if tokenizer_mode == "mistral":
         tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name),
                                                      revision=revision)
+    elif tokenizer_mode == "custom":
+        tokenizer = TokenizerRegistry.get_tokenizer(str(tokenizer_name),
+                                                    *args,
+                                                    revision=revision,
+                                                    download_dir=download_dir,
+                                                    **kwargs)
     else:
         try:
             tokenizer = AutoTokenizer.from_pretrained(
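For illustration, a minimal sketch of how the new branch is exercised (the tokenizer name is hypothetical and must already be registered; the name is consumed by the registry lookup, while the remaining arguments, including revision and download_dir, are forwarded to the registered class's from_pretrained):

    from vllm.transformers_utils.tokenizer import get_tokenizer

    # Dispatches to TokenizerRegistry.get_tokenizer("my-tokenizer", ...),
    # which imports the registered module and calls from_pretrained().
    tokenizer = get_tokenizer("my-tokenizer", tokenizer_mode="custom")
    ids = tokenizer.encode("Hello, world!")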
131 changes: 131 additions & 0 deletions vllm/transformers_utils/tokenizer_base.py
@@ -0,0 +1,131 @@
from abc import ABC, abstractmethod
from typing import List, Union, Dict, Optional, TYPE_CHECKING, Any, Tuple
import importlib

if TYPE_CHECKING:
    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


class TokenizerBase(ABC):

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        raise NotImplementedError()

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        raise NotImplementedError()

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        raise NotImplementedError()

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        raise NotImplementedError()

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        raise NotImplementedError()

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        raise NotImplementedError()

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        raise NotImplementedError()

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        raise NotImplementedError()

    @abstractmethod
    def __call__(
        self,
        prompt: Union[str, List[str], List[int]],
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        raise NotImplementedError()

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        raise NotImplementedError()

    @abstractmethod
    def encode_one(
        self,
        prompt: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        raise NotImplementedError()

    @abstractmethod
    def encode(self, prompt: str) -> List[int]:
        raise NotImplementedError()

    @abstractmethod
    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        raise NotImplementedError()

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        raise NotImplementedError()

    @abstractmethod
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        raise NotImplementedError()

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        raise NotImplementedError()


class TokenizerRegistry:
    # Tokenizer name -> (tokenizer module, tokenizer class)
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        TokenizerRegistry.REGISTRY[name] = (module, class_name)

    @staticmethod
    def get_tokenizer(
        tokenizer_name: str,
        *args,
        **kwargs,
    ) -> TokenizerBase:
        tokenizer_cls = TokenizerRegistry.REGISTRY.get(tokenizer_name)
        if tokenizer_cls is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        # Import lazily so a custom tokenizer package is only loaded on use.
        tokenizer_module = importlib.import_module(tokenizer_cls[0])
        class_ = getattr(tokenizer_module, tokenizer_cls[1])
        return class_.from_pretrained(*args, **kwargs)
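To make a name resolvable, a package registers its tokenizer class ahead of time. A minimal sketch, assuming a hypothetical package my_pkg.tokenizers whose MyTokenizer subclasses TokenizerBase, implements its abstract methods, and exposes a from_pretrained() classmethod:

    from vllm.transformers_utils.tokenizer_base import TokenizerRegistry

    # Map the name to (module, class); the module is imported lazily,
    # only when the tokenizer is first requested.
    TokenizerRegistry.register("my-tokenizer",
                               "my_pkg.tokenizers",
                               "MyTokenizer")

    # Imports my_pkg.tokenizers and returns the result of
    # MyTokenizer.from_pretrained() with any forwarded arguments.
    tok = TokenizerRegistry.get_tokenizer("my-tokenizer")

Note that get_tokenizer() consumes the name for the lookup and does not pass it on, so from_pretrained() must be able to locate its own files from the forwarded arguments alone.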
3 changes: 2 additions & 1 deletion vllm/transformers_utils/tokenizers/mistral.py
@@ -16,6 +16,7 @@
     SentencePieceTokenizer)
 from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
                                                      Tekkenizer)
+from vllm.transformers_utils.tokenizer_base import TokenizerBase

 from vllm.logger import init_logger
 from vllm.utils import is_list_of
@@ -104,7 +105,7 @@ def find_tokenizer_file(files: List[str]):
     return matched_files[0]


-class MistralTokenizer:
+class MistralTokenizer(TokenizerBase):

     def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
         self.mistral = tokenizer
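Because MistralTokenizer now subclasses TokenizerBase, built-in and registered tokenizers share one interface. A small sketch (the model name is illustrative, and loading it requires network access):

    from vllm.transformers_utils.tokenizer_base import TokenizerBase
    from vllm.transformers_utils.tokenizers import MistralTokenizer

    tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
    assert isinstance(tok, TokenizerBase)  # shared abstract interface
    text = tok.decode(tok.encode("Hello!"))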
