From 576bc6ab0437e8a7b0241dad90f9479acd441c81 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 3 Feb 2025 21:44:06 +0000
Subject: [PATCH 1/3] adding support for skip tokenizer benchmarking

Signed-off-by: root
---
 benchmarks/benchmark_throughput.py | 53 ++++++++++++++++++------------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 658eab6a278c8..afde99f8435f7 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -18,7 +18,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
-from vllm.inputs import TextPrompt
+from vllm.inputs import TextPrompt, TokensPrompt
 from vllm.lora.request import LoRARequest
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.multimodal import MultiModalDataDict
@@ -171,10 +171,13 @@ def run_vllm(
     llm = LLM(**dataclasses.asdict(engine_args))

     # Add the requests to the engine.
-    prompts: List[TextPrompt] = []
+    prompts: List[TextPrompt | TokensPrompt] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
         prompts.append(
+            TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                         multi_modal_data=request.multi_modal_data)
+            if "prompt_token_ids" in request.prompt else \
             TextPrompt(prompt=request.prompt,
                        multi_modal_data=request.multi_modal_data))
         sampling_params.append(
@@ -229,11 +232,14 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:

         # Add the requests to the engine.
-        prompts: List[TextPrompt] = []
+        prompts: List[TextPrompt | TokensPrompt] = []
         sampling_params: List[SamplingParams] = []
         lora_requests: List[Optional[LoRARequest]] = []
         for request in requests:
             prompts.append(
+                TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"],
+                             multi_modal_data=request.multi_modal_data)
+                if "prompt_token_ids" in request.prompt else \
                 TextPrompt(prompt=request.prompt,
                            multi_modal_data=request.multi_modal_data))
             sampling_params.append(
@@ -362,24 +368,29 @@ def main(args: argparse.Namespace):
                 random.randint(0, vocab_size - 1)
                 for _ in range(args.input_len)
             ]
-            # As tokenizer may add additional tokens like BOS, we need to try
-            # different lengths to get the desired input length.
-            for _ in range(5):  # Max attempts to correct
-                candidate_prompt = request_tokenizer.decode(candidate_ids)
-                tokenized_len = len(request_tokenizer.encode(candidate_prompt))
-
-                if tokenized_len == args.input_len:
-                    break
-
-                # Adjust length based on difference
-                diff = args.input_len - tokenized_len
-                if diff > 0:
-                    candidate_ids.extend([
-                        random.randint(100, vocab_size - 100)
-                        for _ in range(diff)
-                    ])
-                else:
-                    candidate_ids = candidate_ids[:diff]
+
+            candidate_prompt = {"prompt_token_ids": candidate_ids}
+
+            if not args.skip_tokenizer_init:
+                # As tokenizer may add additional tokens like BOS, we need
+                # to try different lengths to get the desired input length.
+                for _ in range(5):  # Max attempts to correct
+                    candidate_prompt = request_tokenizer.decode(candidate_ids)
+                    tokenized_len = len(
+                        request_tokenizer.encode(candidate_prompt))
+
+                    if tokenized_len == args.input_len:
+                        break
+
+                    # Adjust length based on difference
+                    diff = args.input_len - tokenized_len
+                    if diff > 0:
+                        candidate_ids.extend([
+                            random.randint(100, vocab_size - 100)
+                            for _ in range(diff)
+                        ])
+                    else:
+                        candidate_ids = candidate_ids[:diff]
             requests.append(
                 SampleRequest(prompt=candidate_prompt,
                               prompt_len=args.input_len,

From 18c8ef66651e4fa886d2f8c64a215921f1f52be3 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Tue, 4 Mar 2025 22:49:42 +0000
Subject: [PATCH 2/3] acted upon comments in PR

Signed-off-by: Aleksandr Malyshev
---
 benchmarks/benchmark_throughput.py | 6 +++---
 vllm/engine/arg_utils.py           | 4 +++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index afde99f8435f7..f34636b083e70 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -6,7 +6,7 @@
 import random
 import time
 from functools import cache
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import uvloop
@@ -171,7 +171,7 @@ def run_vllm(
     llm = LLM(**dataclasses.asdict(engine_args))

     # Add the requests to the engine.
-    prompts: List[TextPrompt | TokensPrompt] = []
+    prompts: List[Union[TextPrompt, TokensPrompt]] = []
     sampling_params: List[SamplingParams] = []
     for request in requests:
         prompts.append(
@@ -232,7 +232,7 @@ async def run_vllm_async(
             engine_args, disable_frontend_multiprocessing) as llm:

         # Add the requests to the engine.
-        prompts: List[TextPrompt | TokensPrompt] = []
+        prompts: List[Union[TextPrompt, TokensPrompt]] = []
         sampling_params: List[SamplingParams] = []
         lora_requests: List[Optional[LoRARequest]] = []
         for request in requests:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 40c6fb4567993..1cf6448708346 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -255,7 +255,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip initialization of tokenizer and detokenizer.')
+            help='Skip tokenizer\detokenizer use during model inference. '
+            'Use of tokenizer\detokenizer increases host overhead and might '
+            'end up of GPU underutilization')
         parser.add_argument(
             '--revision',
             type=nullable_str,

From 5e16d39fd293a280a9058d9f4eea7354d93337a0 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Thu, 6 Mar 2025 17:20:45 +0000
Subject: [PATCH 3/3] description change

Signed-off-by: Aleksandr Malyshev
---
 vllm/engine/arg_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 69a86c9d7e78d..d9e094485e2fa 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -274,9 +274,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parser.add_argument(
             '--skip-tokenizer-init',
             action='store_true',
-            help='Skip tokenizer\detokenizer use during model inference. '
-            'Use of tokenizer\detokenizer increases host overhead and might '
-            'end up of GPU underutilization')
+            help='Skip initialization of tokenizer and detokenizer. '
+            'Expects valid prompt_token_ids and None for prompt from '
+            'the input. The generated output will contain token ids.')
         parser.add_argument(
             '--revision',
             type=nullable_str,
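
A minimal usage sketch (not part of the patches above) of how the TokensPrompt path added in PATCH 1/3 is exercised once the tokenizer is skipped; the model name and token ids below are placeholders rather than values taken from the patches:

    # Sketch only: pass pre-tokenized prompts and keep the output as token ids,
    # since no detokenizer is initialized when skip_tokenizer_init=True.
    from vllm import LLM, SamplingParams
    from vllm.inputs import TokensPrompt

    llm = LLM(model="facebook/opt-125m",  # placeholder model
              skip_tokenizer_init=True)
    prompts = [TokensPrompt(prompt_token_ids=[1, 15, 27, 99])]  # placeholder ids
    params = SamplingParams(max_tokens=16, detokenize=False)
    for out in llm.generate(prompts, params):
        print(out.outputs[0].token_ids)

In the benchmark this is what --skip-tokenizer-init enables: encode/decode work stays off the host path during the measured run, which is the host-overhead and GPU-underutilization concern the flag's help text describes.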