
Commit baaedfd

DarkLight1337, WoosukKwon, and FeiDeng authored
[mypy] Enable following imports for entrypoints (vllm-project#7248)
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Fei <[email protected]>
1 parent 4506641 commit baaedfd

26 files changed: +480 -320 lines changed

.github/workflows/mypy.yaml (-1)

@@ -38,7 +38,6 @@ jobs:
 mypy vllm/core --follow-imports skip
 mypy vllm/distributed --follow-imports skip
 mypy vllm/engine --follow-imports skip
-mypy vllm/entrypoints --follow-imports skip
 mypy vllm/executor --follow-imports skip
 mypy vllm/lora --follow-imports skip
 mypy vllm/model_executor --follow-imports skip

docs/requirements-docs.txt (+1 -1)

@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
 msgspec
 
 # packages to install to build the documentation
-pydantic
+pydantic >= 2.8
 -f https://download.pytorch.org/whl/cpu
 torch
 py-cpuinfo

format.sh (-1)

@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
 mypy vllm/core --follow-imports skip
 mypy vllm/distributed --follow-imports skip
 mypy vllm/engine --follow-imports skip
-mypy vllm/entrypoints --follow-imports skip
 mypy vllm/executor --follow-imports skip
 mypy vllm/lora --follow-imports skip
 mypy vllm/model_executor --follow-imports skip

pyproject.toml (+1)

@@ -56,6 +56,7 @@ files = [
     "vllm/*.py",
     "vllm/adapter_commons",
     "vllm/assets",
+    "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging",
     "vllm/multimodal",

requirements-common.txt (+1 -1)

@@ -11,7 +11,7 @@ fastapi
 aiohttp
 openai >= 1.0  # Ensure modern openai package (ensure types module present)
 uvicorn[standard]
-pydantic >= 2.0  # Required for OpenAI server.
+pydantic >= 2.8  # Required for OpenAI server.
 pillow  # Required for image processing
 prometheus_client >= 0.18.0
 prometheus-fastapi-instrumentator >= 7.0.0

tests/entrypoints/openai/test_chat.py (+83 -1)

@@ -1,7 +1,7 @@
 # imports for guided decoding tests
 import json
 import re
-from typing import List
+from typing import Dict, List, Optional
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name, prompt_logprobs",
+    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
+)
+async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                    model_name: str,
+                                    prompt_logprobs: Optional[int]):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name
+    }
+
+    if prompt_logprobs is not None:
+        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
+
+    if prompt_logprobs is not None and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError):
+            await client.chat.completions.create(**params)
+    else:
+        completion = await client.chat.completions.create(**params)
+        if prompt_logprobs is not None:
+            assert completion.prompt_logprobs is not None
+            assert len(completion.prompt_logprobs) > 0
+        else:
+            assert completion.prompt_logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
+                                                  model_name: str):
+    params: Dict = {
+        "messages": [{
+            "role": "system",
+            "content": "You are a helpful assistant."
+        }, {
+            "role": "user",
+            "content": "Who won the world series in 2020?"
+        }, {
+            "role":
+            "assistant",
+            "content":
+            "The Los Angeles Dodgers won the World Series in 2020."
+        }, {
+            "role": "user",
+            "content": "Where was it played?"
+        }],
+        "model":
+        model_name,
+        "extra_body": {
+            "prompt_logprobs": 1
+        }
+    }
+
+    completion_1 = await client.chat.completions.create(**params)
+
+    params["extra_body"] = {"prompt_logprobs": 2}
+    completion_2 = await client.chat.completions.create(**params)
+
+    assert len(completion_1.prompt_logprobs[3]) == 1
+    assert len(completion_2.prompt_logprobs[3]) == 2
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",

tests/entrypoints/openai/test_completion.py (+5 -96)

@@ -3,7 +3,7 @@
 import re
 import shutil
 from tempfile import TemporaryDirectory
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import jsonschema
 import openai  # use the official client for correctness check
@@ -268,118 +268,27 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
     assert len(completion.choices[0].text) >= 0
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name, prompt_logprobs",
-    [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
-)
-async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                    model_name: str, prompt_logprobs: int):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name
-    }
-
-    if prompt_logprobs is not None:
-        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
-
-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
-            await client.chat.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
-    else:
-        completion = await client.chat.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
-            assert completion.prompt_logprobs is not None
-            assert len(completion.prompt_logprobs) > 0
-        else:
-            assert completion.prompt_logprobs is None
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model_name",
-    [MODEL_NAME],
-)
-async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
-                                                  model_name: str):
-    params: Dict = {
-        "messages": [{
-            "role": "system",
-            "content": "You are a helpful assistant."
-        }, {
-            "role": "user",
-            "content": "Who won the world series in 2020?"
-        }, {
-            "role":
-            "assistant",
-            "content":
-            "The Los Angeles Dodgers won the World Series in 2020."
-        }, {
-            "role": "user",
-            "content": "Where was it played?"
-        }],
-        "model":
-        model_name,
-        "extra_body": {
-            "prompt_logprobs": 1
-        }
-    }
-
-    completion_1 = await client.chat.completions.create(**params)
-
-    params["extra_body"] = {"prompt_logprobs": 2}
-    completion_2 = await client.chat.completions.create(**params)
-
-    assert len(completion_1.prompt_logprobs[3]) == 1
-    assert len(completion_2.prompt_logprobs[3]) == 2
-
-
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
                                                          (MODEL_NAME, 0),
                                                          (MODEL_NAME, 1),
                                                          (MODEL_NAME, None)])
 async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
                                           model_name: str,
-                                          prompt_logprobs: int):
+                                          prompt_logprobs: Optional[int]):
     params: Dict = {
         "prompt": ["A robot may not injure another robot", "My name is"],
         "model": model_name,
     }
     if prompt_logprobs is not None:
         params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
 
-    if prompt_logprobs and prompt_logprobs < 0:
-        with pytest.raises(BadRequestError) as err_info:
+    if prompt_logprobs is not None and prompt_logprobs < 0:
+        with pytest.raises(BadRequestError):
             await client.completions.create(**params)
-        expected_err_string = (
-            "Error code: 400 - {'object': 'error', 'message': "
-            "'Prompt_logprobs set to invalid negative value: -1',"
-            " 'type': 'BadRequestError', 'param': None, 'code': 400}")
-        assert str(err_info.value) == expected_err_string
     else:
         completion = await client.completions.create(**params)
-        if prompt_logprobs and prompt_logprobs > 0:
+        if prompt_logprobs is not None:
            assert completion.choices[0].prompt_logprobs is not None
            assert len(completion.choices[0].prompt_logprobs) > 0
 

vllm/engine/async_llm_engine.py (+4 -4)

@@ -6,7 +6,6 @@
                     Optional, Set, Tuple, Type, Union)
 
 import torch
-from transformers import PreTrainedTokenizer
 from typing_extensions import assert_never
 
 import vllm.envs as envs
@@ -31,6 +30,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                            SequenceGroupMetadata)
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import print_warning_once
 
@@ -427,8 +427,8 @@ async def _tokenize_prompt_async(
         lora_request: Optional[LoRARequest],
     ) -> List[int]:
         """Async version of :meth:`_tokenize_prompt`."""
-        tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                             "skip_tokenizer_init is True")
+        tokenizer = self.get_tokenizer_group(
+            missing_msg="prompts must be None if skip_tokenizer_init is True")
 
         return await tokenizer.encode_async(request_id=request_id,
                                             prompt=prompt,
@@ -771,7 +771,7 @@ def _error_callback(self, exc: Exception) -> None:
     async def get_tokenizer(
         self,
         lora_request: Optional[LoRARequest] = None,
-    ) -> "PreTrainedTokenizer":
+    ) -> AnyTokenizer:
         if self.engine_use_ray:
             return await self.engine.get_tokenizer.remote(  # type: ignore
                 lora_request)

vllm/engine/llm_engine.py (+21 -10)

@@ -3,9 +3,9 @@
 from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
                     Mapping, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Tuple, Type, TypeVar, Union
+from typing import Set, Tuple, Type, Union
 
-from typing_extensions import assert_never
+from typing_extensions import TypeVar, assert_never
 
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -43,8 +43,9 @@
                           init_tracer)
 from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import (
-    AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
+    BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import Counter, Device
@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
     return config.to_diff_dict()
 
 
+_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
 
 PromptComponents = Tuple[Optional[str], List[int],
@@ -493,12 +495,21 @@ def __del__(self):
                                    "skip_tokenizer_init is True")
 
     def get_tokenizer_group(
-            self,
-            fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
-        if self.tokenizer is None:
-            raise ValueError(fail_msg)
+        self,
+        group_type: Type[_G] = BaseTokenizerGroup,
+        *,
+        missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
+    ) -> _G:
+        tokenizer_group = self.tokenizer
+
+        if tokenizer_group is None:
+            raise ValueError(missing_msg)
+        if not isinstance(tokenizer_group, group_type):
+            raise TypeError("Invalid type of tokenizer group. "
+                            f"Expected type: {group_type}, but "
+                            f"found type: {type(tokenizer_group)}")
 
-        return self.tokenizer
+        return tokenizer_group
 
     def get_tokenizer(
         self,
@@ -693,8 +704,8 @@ def _tokenize_prompt(
             * prompt token ids
         '''
 
-        tokenizer = self.get_tokenizer_group("prompts must be None if "
-                                             "skip_tokenizer_init is True")
+        tokenizer = self.get_tokenizer_group(
+            missing_msg="prompts must be None if skip_tokenizer_init is True")
 
         return tokenizer.encode(request_id=request_id,
                                 prompt=prompt,
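
Note: the get_tokenizer_group rewrite above relies on typing_extensions.TypeVar with a default, so callers either get the base type back or narrow the return type by passing a concrete group class. A stripped-down sketch of that pattern, using stand-in class names rather than vLLM's real tokenizer groups:

# Minimal sketch of the defaulted-TypeVar pattern used by get_tokenizer_group
# above. The classes here are stand-ins, not vLLM's actual implementations.
from typing import Optional, Type

from typing_extensions import TypeVar


class BaseTokenizerGroup:
    pass


class RayTokenizerGroupStandIn(BaseTokenizerGroup):
    pass


# Without an explicit group_type argument, the return type falls back to the
# TypeVar's default (BaseTokenizerGroup); with one, a type checker narrows it.
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)


class EngineStandIn:

    def __init__(self, tokenizer: Optional[BaseTokenizerGroup]) -> None:
        self.tokenizer = tokenizer

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
        *,
        missing_msg: str = "tokenizer group is not initialized",
    ) -> _G:
        tokenizer_group = self.tokenizer
        if tokenizer_group is None:
            raise ValueError(missing_msg)
        if not isinstance(tokenizer_group, group_type):
            raise TypeError(f"Expected {group_type}, "
                            f"found {type(tokenizer_group)}")
        return tokenizer_group


engine = EngineStandIn(RayTokenizerGroupStandIn())
base_group = engine.get_tokenizer_group()  # typed as BaseTokenizerGroup
ray_group = engine.get_tokenizer_group(RayTokenizerGroupStandIn)  # narrowed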
