From 9c7dbba19626cd85d690f4887143f7e2587e809f Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 22:19:14 +0000 Subject: [PATCH 1/9] updated Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/processor.py | 60 +++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 19 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 3a3fc69e53e44..9ce41a1968bc4 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -55,11 +55,8 @@ def __init__( def _validate_logprobs( self, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, ) -> None: - if not isinstance(params, SamplingParams): - return - max_logprobs = self.model_config.max_logprobs # Validate sample logprobs. if params.logprobs and params.logprobs > max_logprobs: @@ -79,24 +76,50 @@ def _validate_logprobs( raise ValueError("Prefix caching with prompt logprobs not yet " "supported on VLLM V1.") - def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: - if lora_request is not None and not self.lora_config: - raise ValueError(f"Got lora_request {lora_request} but LoRA is " - "not enabled!") - - def _validate_allowed_token_ids( + def _validate_sampling_params( self, - params: Union[SamplingParams, PoolingParams], + params: SamplingParams, ) -> None: - if not isinstance(params, SamplingParams): - return - if params.allowed_token_ids is None: - return - if not all(0 <= tid < self.model_config.vocab_size - for tid in params.allowed_token_ids): + + # Allowed token ids. + if (params.allowed_token_ids is not None + and not all(0 <= tid < self.model_config.vocab_size + for tid in params.allowed_token_ids)): raise ValueError( "allowed_token_ids contains out-of-vocab token id") + def _validate_supported_sampling_params( + self, + params: SamplingParams, + ) -> None: + # Best of. + if params.best_of is not None: + raise ValueError("VLLM V1 does not support best_of.") + # Bad words. + if params.bad_words is not None: + raise ValueError("VLLM V1 does not support bad_words.") + + def _validate_params( + self, + params: Union[SamplingParams, PoolingParams], + ): + """ + Validate supported SamplingParam. + Should raise ValueError if unsupported for API Server. + """ + + if not isinstance(params, SamplingParams): + raise ValueError("V1 does not yet support Pooling models.") + + self._validate_logprobs(params) + self._validate_sampling_params(params) + self._validate_supported_sampling_params(params) + + def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + def process_inputs( self, request_id: str, @@ -112,9 +135,8 @@ def process_inputs( # TODO(woosuk): Support pooling models. # TODO(woosuk): Support encoder-decoder models. 
- self._validate_logprobs(params) self._validate_lora(lora_request) - self._validate_allowed_token_ids(params) + self._validate_params(params) if arrival_time is None: arrival_time = time.time() From b09a2a306873add5b6e3567f98d96b33125a546c Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 22:41:02 +0000 Subject: [PATCH 2/9] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 97 +++++++++++++++++++++ vllm/v1/engine/processor.py | 4 +- 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 tests/v1/sample/test_sampling_params_e2e.py diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py new file mode 100644 index 0000000000000..8c042d3a2b899 --- /dev/null +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest + +from vllm import LLM, SamplingParams + +if os.getenv("VLLM_USE_V1", "0") != "1": + pytest.skip("Test package requires V1", allow_module_level=True) + +MODEL = "meta-llama/Llama-3.2-1B" +PROMPT = "Hello my name is Robert and I" + + +@pytest.fixture(scope="module") +def model() -> LLM: + return LLM(MODEL, enforce_eager=True) + + +def test_n_gt_1(model): + """ParallelSampling is supported.""" + + params = SamplingParams(n=3) + outputs = model.generate(PROMPT, params) + assert len(outputs[0].outputs) == 3 + + +def test_best_of(model): + """Raise a ValueError since best_of is deprecated.""" + + params = SamplingParams(n=2, best_of=3) + with pytest.raises(ValueError): + _ = model.generate(PROMPT, params) + + +def test_penalties(model): + """Check that we do not get errors if applied.""" + + params = SamplingParams( + temperature=1.2, + presence_penalty=1.2, + frequency_penalty=1.2, + repetition_penalty=1.2, + min_p=0.5, + top_p=0.5, + top_k=3, + ) + _ = model.generate(PROMPT, params) + + +def test_stop(model): + """Check that we respect the stop words.""" + + output = model.generate(PROMPT, SamplingParams(temperature=0)) + split_text = output[0].outputs[0].text.split() + + STOP_IDX = 5 + params = SamplingParams(temperature=0, stop=split_text[STOP_IDX]) + output = model.generate(PROMPT, params) + new_split_text = output[0].outputs[0].text.split() + + # Output should not contain the stop word. + assert len(new_split_text) == STOP_IDX + + params = SamplingParams(temperature=0, + stop=split_text[STOP_IDX], + include_stop_str_in_output=True) + output = model.generate(PROMPT, params) + new_split_text = output[0].outputs[0].text.split() + + # Output should contain the stop word. 
+ assert len(new_split_text) == STOP_IDX + 1 + + +def test_stop_token_ids(model): + """Check that we respect the stop token ids.""" + + output = model.generate(PROMPT, SamplingParams(temperature=0)) + + stop_token_id_0 = output[0].outputs[0].token_ids[5] + stop_token_id_1 = output[0].outputs[0].token_ids[6] + + stop_token_ids = [stop_token_id_1, stop_token_id_0] + params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) + output = model.generate(PROMPT, params) + assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 + + stop_token_ids = [stop_token_id_0, stop_token_id_1] + params = SamplingParams(temperature=0, stop_token_ids=stop_token_ids) + assert output[0].outputs[0].token_ids[-1] == stop_token_id_0 + + +def test_bad_words(model): + """Check that we respect bad words.""" + + with pytest.raises(ValueError): + _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"])) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 9ce41a1968bc4..e88e43f5244d1 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -93,10 +93,10 @@ def _validate_supported_sampling_params( params: SamplingParams, ) -> None: # Best of. - if params.best_of is not None: + if params.best_of: raise ValueError("VLLM V1 does not support best_of.") # Bad words. - if params.bad_words is not None: + if params.bad_words: raise ValueError("VLLM V1 does not support bad_words.") def _validate_params( From 612b6bf529c8e56f2f6d25a19c860937f756db54 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 23:20:57 +0000 Subject: [PATCH 3/9] added tests Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 46 +++++++++++++++++++++ vllm/v1/engine/processor.py | 20 ++++++--- vllm/v1/worker/gpu_input_batch.py | 5 +++ 3 files changed, 66 insertions(+), 5 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index 8c042d3a2b899..a026a1fe4fe84 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -95,3 +95,49 @@ def test_bad_words(model): with pytest.raises(ValueError): _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"])) + + +def test_detokenize(model): + """Check that we reject detokenize=False.""" + + with pytest.raises(ValueError): + _ = model.generate(PROMPT, SamplingParams(detokenize=False)) + + +def test_logits_processor(model): + """Check that we reject logits processor.""" + + # This sample logits processor gives infinite score to the i-th token, + # where i is the length of the input sequence. + # We therefore expect the output token sequence to be [0, 1, 2, ...] + def pick_ith(token_ids, logits): + logits[len(token_ids)] = float("inf") + return logits + + with pytest.raises(ValueError): + _ = model.generate(PROMPT, + SamplingParams(logits_processors=[pick_ith])) + + +def test_allowed_token_ids(model): + """Check that we can use allowed_token_ids.""" + + # This does not currently work due to incorrect implementation. + # TOKEN_ID = 10 + # allowed_token_ids = [TOKEN_ID] + # output = model.generate(PROMPT, SamplingParams( + # allowed_token_ids=allowed_token_ids)) + # assert output[0].outputs[0].token_ids[-1] == TOKEN_ID + + # Reject all allowed token ids + with pytest.raises(ValueError): + _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[10])) + + # Reject negative token id. 
+ with pytest.raises(ValueError): + _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[-1])) + + # Reject out of vocabulary. + with pytest.raises(ValueError): + _ = model.generate(PROMPT, + SamplingParams(allowed_token_ids=[10000000])) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index e88e43f5244d1..bee18e8b20aba 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -83,7 +83,7 @@ def _validate_sampling_params( # Allowed token ids. if (params.allowed_token_ids is not None - and not all(0 <= tid < self.model_config.vocab_size + and not all(0 <= tid < self.model_config.get_vocab_size() for tid in params.allowed_token_ids)): raise ValueError( "allowed_token_ids contains out-of-vocab token id") @@ -92,12 +92,22 @@ def _validate_supported_sampling_params( self, params: SamplingParams, ) -> None: - # Best of. + # Best of not yet supported. if params.best_of: - raise ValueError("VLLM V1 does not support best_of.") - # Bad words. + raise ValueError("VLLM V1 does not yet support best_of.") + # Bad words not yet supported. if params.bad_words: - raise ValueError("VLLM V1 does not support bad_words.") + raise ValueError("VLLM V1 does not yet support bad_words.") + # Skip detokenization not yet supported. + if not params.detokenize: + raise ValueError("VLLM V1 does not yet support detokenize=False.") + # Logits processors not supported. + if params.logits_processors: + raise ValueError("VLLM V1 does not yet support per request " + "logits processors.") + # Allowed token ids is not supported. + if params.allowed_token_ids: + raise ValueError("VLLM V1 does not yet support allowed token ids.") def _validate_params( self, diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b0b218d92b927..59a8f22a40907 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -296,6 +296,11 @@ def add_request( if sampling_params.logit_bias is not None: self.logit_bias[req_index] = sampling_params.logit_bias + # FIXME: this implementation is incorrect. We create this mask + # then apply -inf to these specific tokens, which means we never + # select the allowed tokens! We cannot do the reverse, since + # this will impact the requests that do not have allowed_token_ids. + # This feature is currently disabled on V1 (we reject in Processor). if sampling_params.allowed_token_ids: self.has_allowed_token_ids.add(req_id) if self.allowed_token_ids_mask_cpu_tensor is None: From ef6537cf8ab9551299ba8ceb1acd852aaaf0a167 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 23:24:21 +0000 Subject: [PATCH 4/9] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 15 +++++++++++++++ vllm/v1/engine/processor.py | 6 ++++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index a026a1fe4fe84..a6fc3be8b1677 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -141,3 +141,18 @@ def test_allowed_token_ids(model): with pytest.raises(ValueError): _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[10000000])) + + +def test_priority(model): + """Check that we can use allowed_token_ids.""" + + # This does not currently work due to incorrect implementation. 
+ # TOKEN_ID = 10 + # allowed_token_ids = [TOKEN_ID] + # output = model.generate(PROMPT, SamplingParams( + # allowed_token_ids=allowed_token_ids)) + # assert output[0].outputs[0].token_ids[-1] == TOKEN_ID + + # Reject all allowed token ids + with pytest.raises(ValueError): + _ = model.generate(PROMPT, priority=[1]) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index bee18e8b20aba..472d3d1d8c611 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -147,11 +147,13 @@ def process_inputs( self._validate_lora(lora_request) self._validate_params(params) + if priority != 0: + raise ValueError("V1 does not support priority yet.") + if trace_headers is not None: + raise ValueError("V1 does not support tracing yet.") if arrival_time is None: arrival_time = time.time() - assert priority == 0, "vLLM V1 does not support priority at the moment." - assert trace_headers is None, "vLLM V1 does not support tracing yet." # Process inputs, which includes: # 1. Tokenize text prompt, with LoRA request if one exists. From 1d029404faa4a5c1501cef82505adb6b16eb2751 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 23:25:03 +0000 Subject: [PATCH 5/9] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index a6fc3be8b1677..2c18ae1ed073b 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -144,14 +144,7 @@ def test_allowed_token_ids(model): def test_priority(model): - """Check that we can use allowed_token_ids.""" - - # This does not currently work due to incorrect implementation. - # TOKEN_ID = 10 - # allowed_token_ids = [TOKEN_ID] - # output = model.generate(PROMPT, SamplingParams( - # allowed_token_ids=allowed_token_ids)) - # assert output[0].outputs[0].token_ids[-1] == TOKEN_ID + """Check that we reject requests with priority.""" # Reject all allowed token ids with pytest.raises(ValueError): From 10033b973b07f9ae7fbe54e30437fabdef28c1d8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Mon, 3 Mar 2025 23:28:44 +0000 Subject: [PATCH 6/9] updated! 
Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index 2c18ae1ed073b..e445b1c0524ca 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -149,3 +149,14 @@ def test_priority(model): # Reject all allowed token ids with pytest.raises(ValueError): _ = model.generate(PROMPT, priority=[1]) + + +def test_seed(model): + """Check that seed impacts randomness.""" + + out_1 = model.generate(PROMPT, SamplingParams(seed=42)) + out_2 = model.generate(PROMPT, SamplingParams(seed=42)) + out_3 = model.generate(PROMPT, SamplingParams(seed=43)) + + assert out_1[0].outputs[0].text == out_2[0].outputs[0].text + assert out_1[0].outputs[0].text != out_3[0].outputs[0].text From f6eccb00d9fdd61f5923ec0489c726c783f662a8 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 5 Mar 2025 01:02:37 +0000 Subject: [PATCH 7/9] remove detokenize, since it is used by lm eval Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 7 ------- vllm/v1/engine/processor.py | 3 --- 2 files changed, 10 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index e445b1c0524ca..dcd412d4ca517 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -97,13 +97,6 @@ def test_bad_words(model): _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"])) -def test_detokenize(model): - """Check that we reject detokenize=False.""" - - with pytest.raises(ValueError): - _ = model.generate(PROMPT, SamplingParams(detokenize=False)) - - def test_logits_processor(model): """Check that we reject logits processor.""" diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 472d3d1d8c611..fe4df745dbe5b 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -98,9 +98,6 @@ def _validate_supported_sampling_params( # Bad words not yet supported. if params.bad_words: raise ValueError("VLLM V1 does not yet support bad_words.") - # Skip detokenization not yet supported. - if not params.detokenize: - raise ValueError("VLLM V1 does not yet support detokenize=False.") # Logits processors not supported. if params.logits_processors: raise ValueError("VLLM V1 does not yet support per request " From 8b6452f8cefc3f875b64d3d1d1470e03818f0b88 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 5 Mar 2025 08:28:48 +0000 Subject: [PATCH 8/9] address comments Signed-off-by: rshaw@neuralmagic.com --- vllm/v1/engine/processor.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index fe4df745dbe5b..fe4ada6f0bd5d 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -100,8 +100,8 @@ def _validate_supported_sampling_params( raise ValueError("VLLM V1 does not yet support bad_words.") # Logits processors not supported. if params.logits_processors: - raise ValueError("VLLM V1 does not yet support per request " - "logits processors.") + raise ValueError("VLLM V1 does not support per request " + "user provided logits processors.") # Allowed token ids is not supported. 
if params.allowed_token_ids: raise ValueError("VLLM V1 does not yet support allowed token ids.") @@ -148,6 +148,8 @@ def process_inputs( raise ValueError("V1 does not support priority yet.") if trace_headers is not None: raise ValueError("V1 does not support tracing yet.") + if prompt_adapter_request is not None: + raise ValueError("V1 does not support prompt_adapter_request.") if arrival_time is None: arrival_time = time.time() From 51341b99339ff020b560e64bdf96d68befbc5f18 Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Wed, 5 Mar 2025 12:10:01 +0000 Subject: [PATCH 9/9] updated Signed-off-by: rshaw@neuralmagic.com --- tests/v1/sample/test_sampling_params_e2e.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py index dcd412d4ca517..e47f13f053160 100644 --- a/tests/v1/sample/test_sampling_params_e2e.py +++ b/tests/v1/sample/test_sampling_params_e2e.py @@ -115,16 +115,11 @@ def pick_ith(token_ids, logits): def test_allowed_token_ids(model): """Check that we can use allowed_token_ids.""" - # This does not currently work due to incorrect implementation. - # TOKEN_ID = 10 - # allowed_token_ids = [TOKEN_ID] - # output = model.generate(PROMPT, SamplingParams( - # allowed_token_ids=allowed_token_ids)) - # assert output[0].outputs[0].token_ids[-1] == TOKEN_ID - - # Reject all allowed token ids - with pytest.raises(ValueError): - _ = model.generate(PROMPT, SamplingParams(allowed_token_ids=[10])) + TOKEN_ID = 10 + allowed_token_ids = [TOKEN_ID] + output = model.generate( + PROMPT, SamplingParams(allowed_token_ids=allowed_token_ids)) + assert output[0].outputs[0].token_ids[-1] == TOKEN_ID # Reject negative token id. with pytest.raises(ValueError):
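# A minimal, self-contained sketch of the request-validation flow the
# patches above converge on. `FakeSamplingParams`, `FakeModelConfig` and
# `validate_params` are illustrative stand-ins, not the real vLLM classes:
# the fields mirror only the SamplingParams attributes touched by this
# series, and the error messages follow the final patch state. Note the
# new e2e test file is skipped unless VLLM_USE_V1=1 is set.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeModelConfig:
    """Stand-in for the model config fields used by the validators."""
    max_logprobs: int = 20
    vocab_size: int = 128_256

    def get_vocab_size(self) -> int:
        return self.vocab_size


@dataclass
class FakeSamplingParams:
    """Stand-in covering the SamplingParams fields validated above."""
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None
    best_of: Optional[int] = None
    bad_words: List[str] = field(default_factory=list)
    logits_processors: List = field(default_factory=list)
    allowed_token_ids: Optional[List[int]] = None


def validate_params(params: FakeSamplingParams,
                    config: FakeModelConfig) -> None:
    """Raise ValueError for anything V1 rejects at the Processor."""
    # Sample and prompt logprobs are capped by the model config.
    if params.logprobs and params.logprobs > config.max_logprobs:
        raise ValueError(
            f"Requested sample logprobs of {params.logprobs}, "
            f"which is greater than max allowed: {config.max_logprobs}")
    if (params.prompt_logprobs
            and params.prompt_logprobs > config.max_logprobs):
        raise ValueError("Requested prompt logprobs exceed max_logprobs.")
    # Features V1 does not support yet are rejected up front.
    if params.best_of:
        raise ValueError("VLLM V1 does not yet support best_of.")
    if params.bad_words:
        raise ValueError("VLLM V1 does not yet support bad_words.")
    if params.logits_processors:
        raise ValueError("VLLM V1 does not support per request "
                         "user provided logits processors.")
    # allowed_token_ids must stay inside the vocabulary.
    if (params.allowed_token_ids is not None
            and not all(0 <= tid < config.get_vocab_size()
                        for tid in params.allowed_token_ids)):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")


# Example: an out-of-vocab allowed token id is rejected, mirroring
# test_allowed_token_ids in the new e2e test file.
try:
    validate_params(FakeSamplingParams(allowed_token_ids=[10_000_000]),
                    FakeModelConfig())
except ValueError as err:
    print(err)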