forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Keyun Tong <[email protected]>
- Loading branch information
Showing
4 changed files
with
142 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import List, Union, Dict, Optional, TYPE_CHECKING, Any, Tuple | ||
import importlib | ||
if TYPE_CHECKING: | ||
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam | ||
|
||
|
||
class TokenizerBase(ABC):
    """Abstract interface that all tokenizers used by the engine implement.

    Concrete subclasses adapt third-party tokenizer libraries to this common
    surface so callers can treat every tokenizer uniformly. All members are
    abstract; subclasses must override each one.
    """

    @property
    @abstractmethod
    def all_special_tokens_extended(self) -> List[str]:
        # All special tokens, including extended/added ones.
        # NOTE(review): exact difference from `all_special_tokens` is not
        # visible here — presumably mirrors the HF tokenizer property of the
        # same name; confirm against implementations.
        raise NotImplementedError()

    @property
    @abstractmethod
    def all_special_tokens(self) -> List[str]:
        # Special tokens as strings (e.g. BOS/EOS markers).
        raise NotImplementedError()

    @property
    @abstractmethod
    def all_special_ids(self) -> List[int]:
        # Token ids corresponding to the special tokens.
        raise NotImplementedError()

    @property
    @abstractmethod
    def bos_token_id(self) -> int:
        # Id of the beginning-of-sequence token.
        raise NotImplementedError()

    @property
    @abstractmethod
    def eos_token_id(self) -> int:
        # Id of the end-of-sequence token.
        raise NotImplementedError()

    @property
    @abstractmethod
    def is_fast(self) -> bool:
        # Whether this is a "fast" tokenizer implementation
        # (HF fast-tokenizer convention).
        raise NotImplementedError()

    @property
    @abstractmethod
    def vocab_size(self) -> int:
        # Number of tokens in the vocabulary.
        raise NotImplementedError()

    @property
    @abstractmethod
    def max_token_id(self) -> int:
        # Largest valid token id (may differ from vocab_size - 1 when
        # ids are sparse — cannot tell from here; implementation-defined).
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        # Conventionally the vocabulary size; see implementations.
        raise NotImplementedError()

    @abstractmethod
    def __call__(
        self,
        prompt: Union[str, List[str], List[int]],
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
        """Tokenize `prompt` (a string, batch of strings, or token ids).

        Return type is implementation-defined (HF tokenizers return a
        BatchEncoding-like object), hence the missing annotation.
        """
        raise NotImplementedError()

    @abstractmethod
    def get_vocab(self) -> Dict[str, int]:
        """Return the full token-string -> token-id mapping."""
        raise NotImplementedError()

    @abstractmethod
    def get_added_vocab(self) -> Dict[str, int]:
        """Return only the tokens added on top of the base vocabulary."""
        raise NotImplementedError()

    @abstractmethod
    def encode_one(
        self,
        prompt: str,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ) -> List[int]:
        """Encode a single prompt string into token ids.

        When `truncation` is true, output is limited to `max_length` tokens.
        """
        raise NotImplementedError()

    @abstractmethod
    def encode(self, prompt: str) -> List[int]:
        """Encode a prompt string into token ids (no truncation options)."""
        raise NotImplementedError()

    @abstractmethod
    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[Dict[str, Any]] = None,
                            **kwargs) -> List[int]:
        """Render a chat conversation (and optional tool specs) into
        token ids using the tokenizer's chat template."""
        raise NotImplementedError()

    @abstractmethod
    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Join token strings back into a single decoded string."""
        raise NotImplementedError()

    @abstractmethod
    def decode(self,
               ids: Union[List[int], int],
               skip_special_tokens: bool = True) -> str:
        """Decode token id(s) into text, optionally dropping special
        tokens from the output."""
        raise NotImplementedError()

    @abstractmethod
    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
        """Map token ids to their token strings, optionally dropping
        special tokens."""
        raise NotImplementedError()
|
||
|
||
class TokenizerRegistry:
    """Registry mapping tokenizer names to lazily-imported classes.

    Registration stores only (module path, class name) strings; the module
    is imported on first use in `get_tokenizer`, so registering a tokenizer
    is cheap and does not require its module to be importable at that time.
    """

    # Tokenizer name -> (tokenizer module, tokenizer class)
    REGISTRY: Dict[str, Tuple[str, str]] = {}

    @staticmethod
    def register(name: str, module: str, class_name: str) -> None:
        """Register `class_name` from `module` under `name`.

        BUG FIX: was annotated `-> TokenizerBase` although nothing is
        returned; corrected to `-> None`.
        """
        TokenizerRegistry.REGISTRY[name] = (module, class_name)

    @staticmethod
    def get_tokenizer(
        tokenizer_name: str,
        *args,
        **kwargs,
    ) -> "TokenizerBase":
        """Instantiate a registered tokenizer via its `from_pretrained`.

        Extra positional/keyword arguments are forwarded unchanged to
        `from_pretrained`.

        Raises:
            ValueError: if `tokenizer_name` was never registered.
        """
        entry = TokenizerRegistry.REGISTRY.get(tokenizer_name)
        if entry is None:
            raise ValueError(f"Tokenizer {tokenizer_name} not found.")

        module_name, class_name = entry
        tokenizer_module = importlib.import_module(module_name)
        tokenizer_class = getattr(tokenizer_module, class_name)
        return tokenizer_class.from_pretrained(*args, **kwargs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters