
[V1][Frontend] Add Testing For V1 Runtime Parameters #14159

Merged: 10 commits, Mar 5, 2025
150 changes: 150 additions & 0 deletions tests/v1/sample/test_sampling_params_e2e.py
@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
import os

import pytest

from vllm import LLM, SamplingParams

if os.getenv("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "meta-llama/Llama-3.2-1B"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
    return LLM(MODEL, enforce_eager=True)


def test_n_gt_1(model):
    """Parallel sampling is supported."""

    params = SamplingParams(n=3)
    outputs = model.generate(PROMPT, params)
    assert len(outputs[0].outputs) == 3


def test_best_of(model):
    """Raise a ValueError since best_of is deprecated."""

    params = SamplingParams(n=2, best_of=3)
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, params)


def test_penalties(model):
    """Check that applying penalties does not cause errors."""

    params = SamplingParams(
        temperature=1.2,
        presence_penalty=1.2,
        frequency_penalty=1.2,
        repetition_penalty=1.2,
        min_p=0.5,
        top_p=0.5,
        top_k=3,
    )
    _ = model.generate(PROMPT, params)


def test_stop(model):
    """Check that we respect the stop words."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))
    split_text = output[0].outputs[0].text.split()

    STOP_IDX = 5
    params = SamplingParams(temperature=0, stop=split_text[STOP_IDX])
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should not contain the stop word.
    assert len(new_split_text) == STOP_IDX

    params = SamplingParams(temperature=0,
                            stop=split_text[STOP_IDX],
                            include_stop_str_in_output=True)
    output = model.generate(PROMPT, params)
    new_split_text = output[0].outputs[0].text.split()

    # Output should contain the stop word.
    assert len(new_split_text) == STOP_IDX + 1


def test_stop_token_ids(model):
    """Check that we respect the stop token ids."""

    output = model.generate(PROMPT, SamplingParams(temperature=0))

    stop_token_id_0 = output[0].outputs[0].token_ids[5]
    stop_token_id_1 = output[0].outputs[0].token_ids[6]

    stop_token_ids = [stop_token_id_1, stop_token_id_0]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0

    stop_token_ids = [stop_token_id_0, stop_token_id_1]
    params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids)
    output = model.generate(PROMPT, params)
    assert output[0].outputs[0].token_ids[-1] == stop_token_id_0


def test_bad_words(model):
    """Check that bad_words is rejected (not yet supported on V1)."""

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))


def test_logits_processor(model):
    """Check that we reject logits processors."""

    # If it were applied, this sample logits processor would give an infinite
    # score to the i-th token, where i is the length of the input sequence,
    # producing the output token sequence [0, 1, 2, ...]. V1 rejects it instead.
    def pick_ith(token_ids, logits):
        logits[len(token_ids)] = float("inf")
        return logits

    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(logits_processors=[pick_ith]))


def test_allowed_token_ids(model):
    """Check that we can use allowed_token_ids."""

    TOKEN_ID = 10
    allowed_token_ids = [TOKEN_ID]
    output = model.generate(
        PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids))
    assert output[0].outputs[0].token_ids[-1] == TOKEN_ID

    # Reject negative token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1]))

    # Reject out-of-vocabulary token id.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT,
                           SamplingParams(allowed_token_ids=[10000000]))


def test_priority(model):
    """Check that we reject requests with priority."""

    # Reject a request with non-zero priority.
    with pytest.raises(ValueError):
        _ = model.generate(PROMPT, priority=[1])


def test_seed(model):
    """Check that seed impacts randomness."""

    out_1 = model.generate(PROMPT, SamplingParams(seed=42))
    out_2 = model.generate(PROMPT, SamplingParams(seed=42))
    out_3 = model.generate(PROMPT, SamplingParams(seed=43))

    assert out_1[0].outputs[0].text == out_2[0].outputs[0].text
    assert out_1[0].outputs[0].text != out_3[0].outputs[0].text
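
These end-to-end tests are gated on the V1 engine by the `VLLM_USE_V1` check at the top of the module, so a typical local invocation would be something like `VLLM_USE_V1=1 pytest tests/v1/sample/test_sampling_params_e2e.py` (exact CI wiring may differ).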
63 changes: 46 additions & 17 deletions vllm/v1/engine/processor.py
@@ -55,11 +55,8 @@ def __init__(

     def _validate_logprobs(
         self,
-        params: Union[SamplingParams, PoolingParams],
+        params: SamplingParams,
     ) -> None:
-        if not isinstance(params, SamplingParams):
-            return
-
         max_logprobs = self.model_config.max_logprobs
         # Validate sample logprobs.
         if params.logprobs and params.logprobs > max_logprobs:
@@ -79,17 +76,10 @@
raise ValueError("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")

def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")

def _validate_allowed_token_ids(
def _validate_sampling_params(
self,
params: Union[SamplingParams, PoolingParams],
params: SamplingParams,
) -> None:
if not isinstance(params, SamplingParams):
return
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
@@ -99,6 +89,42 @@ def _validate_allowed_token_ids(
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_supported_sampling_params(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        # Best of not yet supported.
+        if params.best_of:
+            raise ValueError("VLLM V1 does not yet support best_of.")
+        # Bad words not yet supported.
+        if params.bad_words:
+            raise ValueError("VLLM V1 does not yet support bad_words.")
+        # Logits processors not supported.
+        if params.logits_processors:
+            raise ValueError("VLLM V1 does not support per request "
+                             "user provided logits processors.")
+
+    def _validate_params(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ):
+        """
+        Validate supported SamplingParam.
+        Should raise ValueError if unsupported for API Server.
+        """
+
+        if not isinstance(params, SamplingParams):
+            raise ValueError("V1 does not yet support Pooling models.")
+
+        self._validate_logprobs(params)
+        self._validate_sampling_params(params)
+        self._validate_supported_sampling_params(params)
+
+    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+
     def process_inputs(
         self,
         request_id: str,
@@ -114,14 +140,17 @@ def process_inputs(
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Support encoder-decoder models.
 
-        self._validate_logprobs(params)
         self._validate_lora(lora_request)
-        self._validate_allowed_token_ids(params)
+        self._validate_params(params)
+        if priority != 0:
+            raise ValueError("V1 does not support priority yet.")
+        if trace_headers is not None:
+            raise ValueError("V1 does not support tracing yet.")
+        if prompt_adapter_request is not None:
+            raise ValueError("V1 does not support prompt_adapter_request.")
 
Review comment (Member):

I think we need this here too?

        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")
         if arrival_time is None:
             arrival_time = time.time()
-        assert priority == 0, "vLLM V1 does not support priority at the moment."
-        assert trace_headers is None, "vLLM V1 does not support tracing yet."
 
         # Process inputs, which includes:
         # 1. Tokenize text prompt, with LoRA request if one exists.
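As a rough usage sketch (not part of the PR, and mirroring the new tests above), the consolidated `_validate_params` entry point means unsupported sampling options surface to callers as a `ValueError` raised before tokenization; the model name and setup below are illustrative assumptions:

```python
import pytest

from vllm import LLM, SamplingParams

# Assumes the V1 engine is selected, e.g. VLLM_USE_V1=1 in the environment.
llm = LLM("meta-llama/Llama-3.2-1B", enforce_eager=True)

# bad_words is one of the options _validate_supported_sampling_params rejects,
# so the request fails fast with a ValueError instead of reaching the engine.
with pytest.raises(ValueError):
    llm.generate("Hello my name is Robert and I",
                 SamplingParams(bad_words=["Hello"]))
```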
5 changes: 5 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -298,6 +298,11 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias
 
+        # FIXME: this implementation is incorrect. We create this mask
+        # then apply -inf to these specific tokens, which means we never
+        # select the allowed tokens! We cannot do the reverse, since
+        # this will impact the requests that do not have allowed_token_ids.
+        # This feature is currently disabled on V1 (we reject in Processor).
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
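For context on the FIXME above, here is a minimal sketch of the masking semantics it describes, using plain torch tensors; the tensor names are illustrative and this is not the vLLM implementation:

```python
import torch

vocab_size = 8
logits = torch.randn(vocab_size)
allowed_token_ids = [2, 5]

# Intended semantics: suppress every token that is NOT in allowed_token_ids.
keep_out = torch.ones(vocab_size, dtype=torch.bool)
keep_out[allowed_token_ids] = False
correct = logits.masked_fill(keep_out, float("-inf"))

# The bug the FIXME describes: the mask marks the allowed tokens themselves,
# so filling the masked positions with -inf suppresses exactly the tokens the
# request asked for. Inverting the mask globally is not an option either,
# since that would clobber requests that set no allowed_token_ids at all.
buggy_mask = torch.zeros(vocab_size, dtype=torch.bool)
buggy_mask[allowed_token_ids] = True
buggy = logits.masked_fill(buggy_mask, float("-inf"))

assert not torch.isinf(correct[allowed_token_ids]).any()  # allowed ids survive
assert torch.isinf(buggy[allowed_token_ids]).all()  # allowed ids are killed
```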