[V1][Frontend] Add Testing For V1 Runtime Parameters (vllm-project#14159)

Signed-off-by: [email protected] <[email protected]>
Signed-off-by: Johnny <[email protected]>
robertgshaw2-redhat authored and johnnynunez committed Mar 6, 2025
1 parent d40520b commit cf9e135
Showing 3 changed files with 201 additions and 17 deletions.
150 changes: 150 additions & 0 deletions tests/v1/sample/test_sampling_params_e2e.py
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from vllm import LLM, SamplingParams

if os.getenv("VLLM_USE_V1", "0") != "1":
pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "meta-llama/Llama-3.2-1B"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(model):
"""ParallelSampling is supported."""

params = SamplingParams(n=3)
outputs = model.generate(PROMPT, params)
assert len(outputs[0].outputs) == 3


def test_best_of(model):
"""Raise a ValueError since best_of is deprecated."""

params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError):
_ = model.generate(PROMPT, params)


def test_penalties(model):
"""Check that we do not get errors if applied."""

params = SamplingParams(
temperature=1.2,
presence_penalty=1.2,
frequency_penalty=1.2,
repetition_penalty=1.2,
min_p=0.5,
top_p=0.5,
top_k=3,
)
_ = model.generate(PROMPT, params)


def test_stop(model):
"""Check that we respect the stop words."""

output = model.generate(PROMPT, SamplingParams(temperature=0))
split_text = output[0].outputs[0].text.split()

STOP_IDX = 5
params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
output = model.generate(PROMPT, params)
new_split_text = output[0].outputs[0].text.split()

# Output should not contain the stop word.
assert len(new_split_text) == STOP_IDX

params = SamplingParams(temperature=0,
stop=split_text[STOP_IDX],
include_stop_str_in_output=True)
output = model.generate(PROMPT, params)
new_split_text = output[0].outputs[0].text.split()

# Output should contain the stop word.
assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(model):
"""Check that we respect the stop token ids."""

output = model.generate(PROMPT, SamplingParams(temperature=0))

stop_token_id_0 = output[0].outputs[0].token_ids[5]
stop_token_id_1 = output[0].outputs[0].token_ids[6]

stop_token_ids = [stop_token_id_1, stop_token_id_0]
params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
output = model.generate(PROMPT, params)
assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_bad_words(model):
"""Check that we respect bad words."""

with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))


def test_logits_processor(model):
"""Check that we reject logits processor."""

# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
return logits

with pytest.raises(ValueError):
_ = model.generate(PROMPT,
SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(model):
"""Check that we can use allowed_token_ids."""

TOKEN_ID = 10
allowed_token_ids = [TOKEN_ID]
output = model.generate(
PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

# Reject negative token id.
with pytest.raises(ValueError):
_ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

# Reject out of vocabulary.
with pytest.raises(ValueError):
_ = model.generate(PROMPT,
SamplingParams(allowed_token_ids=[10000000]))


def test_priority(model):
"""Check that we reject requests with priority."""

    # Reject requests that set a non-zero priority.
with pytest.raises(ValueError):
_ = model.generate(PROMPT, priority=[1])


def test_seed(model):
"""Check that seed impacts randomness."""

out_1 = model.generate(PROMPT, SamplingParams(seed=42))
out_2 = model.generate(PROMPT, SamplingParams(seed=42))
out_3 = model.generate(PROMPT, SamplingParams(seed=43))

assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
63 changes: 46 additions & 17 deletions vllm/v1/engine/processor.py
@@ -55,11 +55,8 @@ def __init__(

     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,17 +76,10 @@ def _validate_logprobs(
             raise ValueError("Prefix caching with prompt logprobs not yet "
                              "supported on VLLM V1.")

-    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
-        if lora_request is not None and not self.lora_config:
-            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
-                             "not enabled!")
-
-    def _validate_allowed_token_ids(
+    def _validate_sampling_params(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -99,6 +89,42 @@ def _validate_allowed_token_ids(
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")

+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
@@ -114,14 +140,17 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.

-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")

         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."

         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
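A rough illustration (not part of this commit's diff) of the consolidated validation path: the sketch below calls the new _validate_params entry point directly and assumes `processor` is an already-constructed vllm.v1.engine.processor.Processor.

from vllm import SamplingParams
from vllm.pooling_params import PoolingParams

# `processor` is assumed to be an initialized V1 Processor instance.
unsupported = [
    PoolingParams(),                      # pooling models rejected on V1
    SamplingParams(n=2, best_of=3),       # best_of rejected
    SamplingParams(bad_words=["Hello"]),  # bad_words rejected
]

for params in unsupported:
    try:
        processor._validate_params(params)
    except ValueError as err:
        print(f"rejected: {err}")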
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -298,6 +298,11 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias

+        # FIXME: this implementation is incorrect. We create this mask
+        # then apply -inf to these specific tokens, which means we never
+        # select the allowed tokens! We cannot do the reverse, since
+        # this will impact the requests that do not have allowed_token_ids.
+        # This feature is currently disabled on V1 (we reject in Processor).
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
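An aside on the FIXME above (not from this commit): one possible masking direction, sketched below with hypothetical names, is to suppress every token except the allowed ids, and to do so only for rows that actually set allowed_token_ids, so unrestricted requests are untouched.

import torch

vocab_size = 8
# row index -> allowed token ids (None means the request has no restriction)
allowed_by_row = {0: [2, 5], 1: None}

# True = suppress this token for this row.
mask = torch.zeros(len(allowed_by_row), vocab_size, dtype=torch.bool)
for row, allowed in allowed_by_row.items():
    if allowed:
        mask[row] = True            # start by suppressing everything...
        mask[row, allowed] = False  # ...then re-enable the allowed ids

logits = torch.randn(len(allowed_by_row), vocab_size)
logits = logits.masked_fill(mask, float("-inf"))
# Row 0 can now only sample token 2 or 5; row 1 is unaffected.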
