diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index aabce64ff776e..d8353cf1f7137 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -7,7 +7,7 @@
 import random
 import time
 from functools import cache
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 import uvloop
@@ -20,7 +20,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -178,10 +178,13 @@ def run_vllm(
             "Please ensure that max_model_len is greater than the sum of"
             " prompt_len and expected_output_len for all requests.")
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -242,11 +245,14 @@ async def run_vllm_async(
             " prompt_len and expected_output_len for all requests.")
 
     # Add the requests to the engine.
-    prompts: list[TextPrompt] = []
+    prompts: list[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: list[SamplingParams] = []
     lora_requests: list[Optional[LoRARequest]] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -393,24 +399,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 989eb4dbfd148..d9e094485e2fa 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -274,7 +274,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,
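Reviewer note: the prompt-selection expression added to run_vllm and
run_vllm_async can be read as a small standalone helper. Below is a minimal
sketch of the same dispatch, assuming only that TextPrompt and TokensPrompt
are importable from vllm.inputs as in this patch; to_prompt and the sample
inputs are hypothetical, for illustration.

    from typing import Any, Optional, Union

    from vllm.inputs import TextPrompt, TokensPrompt
    from vllm.multimodal import MultiModalDataDict

    def to_prompt(
        prompt: Any,
        multi_modal_data: Optional[MultiModalDataDict] = None,
    ) -> Union[TextPrompt, TokensPrompt]:
        # A dict carrying "prompt_token_ids" (the --skip-tokenizer-init
        # path) becomes a TokensPrompt; any other prompt is treated as
        # text. Note that `in` is a substring test on str prompts, which
        # is False for ordinary text, so the TextPrompt branch is taken.
        if "prompt_token_ids" in prompt:
            return TokensPrompt(prompt_token_ids=prompt["prompt_token_ids"],
                                multi_modal_data=multi_modal_data)
        return TextPrompt(prompt=prompt, multi_modal_data=multi_modal_data)

    # Token-id input, as built by main() when the tokenizer is skipped:
    to_prompt({"prompt_token_ids": [2, 100, 205, 7]})
    # Plain-text input, the pre-existing path:
    to_prompt("The quick brown fox")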
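The expanded --skip-tokenizer-init help text matches the engine-level
behavior exposed through the offline API. A minimal sketch of what the help
text describes, assuming the LLM constructor's skip_tokenizer_init keyword
and SamplingParams(detokenize=False); the model name and token ids are
placeholders.

    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt

    # No tokenizer or detokenizer is constructed: inputs must already be
    # token ids, and outputs carry token ids rather than text.
    llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

    sampling = SamplingParams(temperature=0.0, max_tokens=16,
                              detokenize=False)
    outputs = llm.generate(
        TokensPrompt(prompt_token_ids=[2, 100, 205, 7]), sampling)
    print(outputs[0].outputs[0].token_ids)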