[V1][Frontend] Add Testing For V1 Runtime Parameters #14159
The first file is a new end-to-end test module (+155 lines) exercising runtime sampling parameters on V1:

```python
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from vllm import LLM, SamplingParams

if os.getenv("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "meta-llama/Llama-3.2-1B"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
    return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(model):
    """Parallel sampling is supported."""

    params = SamplingParams(n=3)
    outputs = model.generate(PROMPT, params)
    assert len(outputs[0].outputs) == 3


def test_best_of(model):
    """Raise a ValueError since best_of is deprecated."""

    params = SamplingParams(n=2, best_of=3)
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, params)


def test_penalties(model):
    """Check that we do not get errors if applied."""

    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    _ = model.generate(PROMPT, params)


def test_stop(model):
    """Check that we respect the stop words."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(temperature=0,
                            stop=split_text[STOP_IDX],
                            include_stop_str_in_output=True)
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(model):
    """Check that we respect the stop token ids."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))

    stop_token_id_0 = output[0].outputs[0].token_ids[5]
    stop_token_id_1 = output[0].outputs[0].token_ids[6]

    stop_token_ids = [stop_token_id_1, stop_token_id_0]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    # Order of stop_token_ids should not matter; generation still stops
    # at the earlier-occurring token.
    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_bad_words(model):
    """Check that bad_words are rejected (not yet supported on V1)."""

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))


def test_logits_processor(model):
    """Check that we reject logits processors."""

    # This sample logits processor gives an infinite score to the i-th token,
    # where i is the length of the input sequence.
    # V1 does not support per-request logits processors, so generate()
    # should raise a ValueError rather than produce output.
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(model):
    """Check that allowed_token_ids are rejected (not yet supported on V1)."""

    # This does not currently work due to incorrect implementation.
    # TOKEN_ID = 10
    # allowed_token_ids = [TOKEN_ID]
    # output = model.generate(PROMPT, SamplingParams(
    #     allowed_token_ids=allowed_token_ids))
    # assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

    # Reject all allowed token ids
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[10]))
```

Review thread on the `# Reject all allowed token ids` check:

> Do you mean an empty list here? or do you mean a full list?

> once your PR lands, I will revert this

> Oh, I see the point, you just want to disable this feature completely, and if there is anything, just raise ValueError.
```python
    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out-of-vocabulary token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(allowed_token_ids=[10000000]))


def test_priority(model):
    """Check that we reject requests with priority."""

    # Reject request with priority.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, priority=[1])


def test_seed(model):
    """Check that seed impacts randomness."""

    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
    out_3 = model.generate(PROMPT, SamplingParams(seed=43))

    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
```
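For context, here is a minimal sketch (not part of the diff) of how these constraints surface through the offline `LLM` entry point. It assumes the same model and `VLLM_USE_V1` flag used by the test module above; the exact error messages come from the validation changes in the next file.

```python
# Sketch only: demonstrates that supported parameters run normally while
# unsupported ones are rejected up front with a ValueError on V1.
import os

os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

llm = LLM("meta-llama/Llama-3.2-1B", enforce_eager=True)

# Supported parameters go through as usual.
out = llm.generate("Hello my name is Robert and I",
                   SamplingParams(n=2, temperature=0.8))
print(out[0].outputs[0].text)

# Unsupported parameters fail fast before any generation work happens.
try:
    llm.generate("Hello my name is Robert and I",
                 SamplingParams(allowed_token_ids=[10]))
except ValueError as err:
    print(f"Rejected: {err}")
```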
The second file modifies the existing V1 request validation so that unsupported sampling parameters are rejected with a `ValueError`:

```diff
@@ -55,11 +55,8 @@ def __init__(

     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,24 +76,57 @@ def _validate_logprobs(
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")

-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-        if params.allowed_token_ids is None:
-            return
-        if not all(0 <= tid < self.model_config.vocab_size
-                   for tid in params.allowed_token_ids):
+        # Allowed token ids.
+        if (params.allowed_token_ids is not None
+                and not all(0 <= tid < self.model_config.get_vocab_size()
+                            for tid in params.allowed_token_ids)):
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id")

+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not yet support per request "
+                             "logits processors.")
```

Review comment on the logits-processors error message:

> perhaps remove the "yet" from this one?

```diff
+        # Allowed token ids is not supported.
+        if params.allowed_token_ids:
+            raise ValueError("VLLM V1 does not yet support allowed token ids.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
@@ -112,14 +142,15 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.

-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
```

Review comment on the new checks:

> I think we need this here too?
>
> ```python
> if prompt_adapter_request is not None:
>     raise ValueError("V1 does not support prompt_adapter_request.")
> ```

```diff
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."

         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
```
Review comment:

> should be fixed by #14169.
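The net effect of the refactor is a single validation entry point that every request passes through before any tokenization or scheduling work. The standalone sketch below illustrates that fail-fast pattern with simplified, hypothetical names (`FakeSamplingParams`, `validate_params`, `process_request`); it is not the PR's actual processor code, which operates on vLLM's real `SamplingParams`.

```python
# Standalone illustration of the "validate before admitting" pattern used in
# this PR. Names and structure are simplified and hypothetical.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeSamplingParams:
    # Only the fields relevant to the validation pattern.
    n: int = 1
    best_of: Optional[int] = None
    bad_words: list = field(default_factory=list)
    logits_processors: list = field(default_factory=list)
    allowed_token_ids: Optional[list] = None


def validate_params(params: FakeSamplingParams) -> None:
    """Raise ValueError for any parameter the V1 engine does not support."""
    if params.best_of:
        raise ValueError("V1 does not yet support best_of.")
    if params.bad_words:
        raise ValueError("V1 does not yet support bad_words.")
    if params.logits_processors:
        raise ValueError("V1 does not support per request logits processors.")
    if params.allowed_token_ids:
        raise ValueError("V1 does not yet support allowed token ids.")


def process_request(prompt: str, params: FakeSamplingParams) -> str:
    # Fail fast: reject unsupported parameters before doing any real work.
    validate_params(params)
    return f"scheduled: {prompt!r} with n={params.n}"


if __name__ == "__main__":
    print(process_request("Hello", FakeSamplingParams(n=2)))
    try:
        process_request("Hello", FakeSamplingParams(best_of=3))
    except ValueError as err:
        print(f"Rejected: {err}")
```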