From 5f460ecefa59c38b8d1abb7891f0767630e30507 Mon Sep 17 00:00:00 2001
From: Nick Hill <nickhill@us.ibm.com>
Date: Wed, 22 Jan 2025 14:22:12 -0800
Subject: [PATCH] [Frontend][V1] Online serving performance improvements
 (#12287)

Signed-off-by: Bowen Wang <abmfy@icloud.com>
---
 vllm/entrypoints/openai/api_server.py |  6 +++
 vllm/entrypoints/openai/protocol.py   | 30 +++++++++------
 vllm/envs.py                          | 11 ++++++
 vllm/v1/engine/async_llm.py           | 54 ++++++++++++++++++---------
 vllm/v1/engine/core_client.py         | 21 +++++++++--
 vllm/v1/engine/output_processor.py    |  6 ++-
 vllm/v1/request.py                    | 18 +++------
 7 files changed, 101 insertions(+), 45 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 9bb11907f7402..f510c41503011 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -1,5 +1,6 @@
 import asyncio
 import atexit
+import gc
 import importlib
 import inspect
 import multiprocessing
@@ -104,6 +105,11 @@ async def _force_log():
             task.add_done_callback(_running_tasks.remove)
         else:
             task = None
+
+        # Mark the startup heap as static so that it's ignored by GC.
+        # Reduces pause times of oldest generation collections.
+        gc.collect()
+        gc.freeze()
         try:
             yield
         finally:
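
For reference, the gc.collect() + gc.freeze() pair added above is a stock CPython
technique: run one full collection at the end of startup, then move every object
that survives into the permanent generation so later cyclic-GC passes skip it.
A minimal, standalone sketch (illustrative only, not vLLM code):

    import gc

    def freeze_startup_heap() -> None:
        # Collect once so that whatever survives is genuinely long-lived ...
        gc.collect()
        # ... then freeze it: frozen objects are never traversed by the
        # collector again, which shrinks oldest-generation pause times.
        gc.freeze()
        print(f"objects frozen: {gc.get_freeze_count()}")

    if __name__ == "__main__":
        freeze_startup_heap()
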
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 14e41346df775..80403f77d5375 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -3,7 +3,7 @@
 import re
 import time
 from argparse import Namespace
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
 
 import torch
 from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -42,23 +42,31 @@ class OpenAIBaseModel(BaseModel):
     # OpenAI API does allow extra fields
     model_config = ConfigDict(extra="allow")
 
+    # Cache class field names
+    field_names: ClassVar[Optional[Set[str]]] = None
+
     @model_validator(mode="before")
     @classmethod
     def __log_extra_fields__(cls, data):
-        if isinstance(data, dict):
+
+        field_names = cls.field_names
+        if field_names is None:
+            if not isinstance(data, dict):
+                return data
             # Get all class field names and their potential aliases
             field_names = set()
             for field_name, field in cls.model_fields.items():
                 field_names.add(field_name)
-                if hasattr(field, 'alias') and field.alias:
-                    field_names.add(field.alias)
-
-            # Compare against both field names and aliases
-            extra_fields = data.keys() - field_names
-            if extra_fields:
-                logger.warning(
-                    "The following fields were present in the request "
-                    "but ignored: %s", extra_fields)
+                if alias := getattr(field, 'alias', None):
+                    field_names.add(alias)
+            cls.field_names = field_names
+
+        # Compare against both field names and aliases
+        if any(k not in field_names for k in data):
+            logger.warning(
+                "The following fields were present in the request "
+                "but ignored: %s",
+                data.keys() - field_names)
         return data
 
 
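
The validator above now builds the set of known field names (plus aliases) once
per model class and caches it in the field_names ClassVar, instead of rebuilding
it on every request. The same idea in isolation, using a toy pydantic model
(LoggingModel and Req are illustrative names, not vLLM classes):

    from typing import ClassVar, Optional, Set

    from pydantic import BaseModel, ConfigDict, model_validator

    class LoggingModel(BaseModel):
        model_config = ConfigDict(extra="allow")

        # Cache of known field names, shared by all instances of the class.
        field_names: ClassVar[Optional[Set[str]]] = None

        @model_validator(mode="before")
        @classmethod
        def _warn_extra_fields(cls, data):
            field_names = cls.field_names
            if field_names is None:
                if not isinstance(data, dict):
                    return data
                field_names = set()
                for name, field in cls.model_fields.items():
                    field_names.add(name)
                    if alias := getattr(field, "alias", None):
                        field_names.add(alias)
                cls.field_names = field_names  # computed only on first request
            if any(k not in field_names for k in data):
                print("ignored fields:", data.keys() - field_names)
            return data

    class Req(LoggingModel):
        prompt: str

    Req.model_validate({"prompt": "hi", "temperture": 0.2})  # reports 'temperture'
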
diff --git a/vllm/envs.py b/vllm/envs.py
index 1e68326b2d908..3a15e00e7b50a 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -73,6 +73,7 @@
     VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
     VLLM_DISABLE_COMPILE_CACHE: bool = False
     VLLM_SERVER_DEV_MODE: bool = False
+    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
 
 
 def get_default_cache_root():
@@ -474,6 +475,16 @@ def get_default_config_root():
     # e.g. `/reset_prefix_cache`
     "VLLM_SERVER_DEV_MODE":
     lambda: bool(int(os.getenv("VLLM_SERVER_DEV_MODE", "0"))),
+
+    # Controls the maximum number of requests to handle in a
+    # single asyncio task when processing per-token outputs in the
+    # V1 AsyncLLM interface. It is applicable when handling a high
+    # concurrency of streaming requests.
+    # Setting this too high can result in a higher variance of
+    # inter-message latencies. Setting it too low can negatively impact
+    # TTFT and overall throughput.
+    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
+    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
 }
 
 # end-env-vars-definition
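
The new variable only bounds how many per-token outputs are handled before the
event loop is yielded; the actual chunking lives in the async_llm.py hunk below,
which combines cdiv with np.array_split. A quick illustration of that math
(the local cdiv mirrors vllm.utils.cdiv; the numbers are just an example):

    import os

    import numpy as np

    def cdiv(a: int, b: int) -> int:
        # Ceiling division without floats.
        return -(a // -b)

    chunk_size = int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128"))
    outputs = list(range(300))  # pretend we received 300 per-token outputs

    slices = np.array_split(outputs, cdiv(len(outputs), chunk_size))
    print([len(s) for s in slices])  # [100, 100, 100]: near-equal chunks <= 128
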
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index b4d3e441173df..1505b62504a2f 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -2,9 +2,12 @@
 import os
 from typing import AsyncGenerator, List, Mapping, Optional, Type, Union
 
+import numpy as np
+
 from vllm.config import ModelConfig, VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import EngineClient
+from vllm.envs import VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
@@ -16,7 +19,7 @@
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import kill_process_tree
+from vllm.utils import cdiv, kill_process_tree
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.processor import Processor
@@ -205,17 +208,15 @@ async def generate(
 
             # The output_handler task pushes items into the queue.
             # This task pulls from the queue and yields to caller.
-            while True:
+            finished = False
+            while not finished:
                 # Note: drain queue without await if possible (avoids
                 # task switching under load which helps performance).
-                out = q.get_nowait() if q.qsize() > 0 else await q.get()
+                out = q.get_nowait() if not q.empty() else await q.get()
 
                 # Note: both OutputProcessor and EngineCore handle their
                 # own request cleanup based on finished.
-                if out.finished:
-                    yield out
-                    break
-
+                finished = out.finished
                 yield out
 
         # If the request is disconnected by the client, the
@@ -233,22 +234,41 @@ async def _run_output_handler(self):
                 # 1) Pull EngineCoreOutputs from the EngineCore.
                 outputs = await self.engine_core.get_output_async()
 
-                # 2) Process EngineCoreOutputs.
-                processed_outputs = self.output_processor.process_outputs(
-                    outputs.outputs)
-                # NOTE: RequestOutputs are pushed to their queues.
-                assert len(processed_outputs.request_outputs) == 0
-
-                # 3) Abort any reqs that finished due to stop strings.
-                await self.engine_core.abort_requests_async(
-                    processed_outputs.reqs_to_abort)
+                # Split outputs into chunks of at most
+                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
+                # event loop for too long.
+                num_outputs = len(outputs.outputs)
+                if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
+                    slices = (outputs.outputs, )
+                else:
+                    slices = np.array_split(
+                        outputs.outputs,
+                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))
+
+                iteration_stats = None
+                for i, outputs_slice in enumerate(slices):
+                    # 2) Process EngineCoreOutputs.
+                    processed_outputs = self.output_processor.process_outputs(
+                        outputs_slice, iteration_stats)
+                    # NOTE: RequestOutputs are pushed to their queues.
+                    assert not processed_outputs.request_outputs
+                    iteration_stats = processed_outputs.iteration_stats
+
+                    # Allow other asyncio tasks to run between chunks
+                    if i + 1 < len(slices):
+                        await asyncio.sleep(0)
+
+                    # 3) Abort any reqs that finished due to stop strings.
+                    await self.engine_core.abort_requests_async(
+                        processed_outputs.reqs_to_abort)
 
                 # 4) Logging.
                 # TODO(rob): make into a coroutine and launch it in
                 # background thread once we add Prometheus.
+                assert iteration_stats is not None
                 self._log_stats(
                     scheduler_stats=outputs.scheduler_stats,
-                    iteration_stats=processed_outputs.iteration_stats,
+                    iteration_stats=iteration_stats,
                 )
 
         except Exception as e:
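
The await asyncio.sleep(0) between chunks is what keeps the handler from
monopolizing the event loop: it suspends the coroutine for one scheduling round
so queued tasks (for example the per-request streaming generators) get to run.
A self-contained demonstration of the pattern, with vLLM's processing replaced
by a dummy workload (process_in_chunks and heartbeat are illustrative names):

    import asyncio

    async def process_in_chunks(items, chunk_size=128):
        for start in range(0, len(items), chunk_size):
            chunk = items[start:start + chunk_size]
            _ = sum(x * x for x in chunk)  # stand-in for per-token processing
            if start + chunk_size < len(items):
                # Yield once between chunks so other coroutines are scheduled.
                await asyncio.sleep(0)

    async def heartbeat():
        # A task that should not be starved while chunks are being processed.
        for i in range(5):
            print("heartbeat", i)
            await asyncio.sleep(0.01)

    async def main():
        await asyncio.gather(
            process_in_chunks(list(range(1_000_000))),
            heartbeat(),
        )

    asyncio.run(main())
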
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 19b89003cc69d..f3b992d6873e7 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -1,8 +1,9 @@
+import asyncio
 import os
 import signal
 import weakref
 from abc import ABC, abstractmethod
-from typing import List, Type
+from typing import List, Optional, Type
 
 import msgspec
 import zmq
@@ -255,10 +256,24 @@ def __init__(self, vllm_config: VllmConfig,
             log_stats=True,
         )
 
+        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.queue_task: Optional[asyncio.Task] = None
+
     async def get_output_async(self) -> EngineCoreOutputs:
+        if self.outputs_queue is None:
+            # Perform IO in a separate task to parallelize as much as possible
+            self.outputs_queue = asyncio.Queue()
+
+            async def process_outputs_socket():
+                assert self.outputs_queue is not None
+                while True:
+                    (frame, ) = await self.output_socket.recv_multipart(
+                        copy=False)
+                    self.outputs_queue.put_nowait(frame.buffer)
+
+            self.queue_task = asyncio.create_task(process_outputs_socket())
 
-        frames = await self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frames[0].buffer)
+        return self.decoder.decode(await self.outputs_queue.get())
 
     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: EngineCoreRequestUnion) -> None:
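
Reading the output socket in a dedicated task lets ZMQ receives overlap with
msgspec decoding and downstream processing: the reader keeps draining frames
into an asyncio.Queue while get_output_async() consumes them. A stripped-down
sketch of that producer/consumer split, with the ZMQ socket replaced by a fake
async source (QueueingReader and FakeSource are illustrative, not vLLM classes):

    import asyncio
    from typing import Optional

    class QueueingReader:

        def __init__(self, source):
            self.source = source  # anything with an async recv() -> bytes
            self.outputs_queue: Optional[asyncio.Queue] = None
            self.queue_task: Optional[asyncio.Task] = None

        async def get_output_async(self) -> bytes:
            if self.outputs_queue is None:
                # Started lazily on first use so the task is created on the
                # running event loop.
                self.outputs_queue = asyncio.Queue()

                async def read_socket():
                    while True:
                        self.outputs_queue.put_nowait(await self.source.recv())

                self.queue_task = asyncio.create_task(read_socket())

            return await self.outputs_queue.get()

    class FakeSource:

        async def recv(self) -> bytes:
            await asyncio.sleep(0.01)  # pretend network latency
            return b"payload"

    async def main():
        reader = QueueingReader(FakeSource())
        print(await reader.get_output_async())

    asyncio.run(main())
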
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 749f4f5043c97..564eab51bd3a8 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -101,6 +101,7 @@ def add_request(
     def process_outputs(
         self,
         engine_core_outputs: List[EngineCoreOutput],
+        iteration_stats: Optional[IterationStats] = None,
     ) -> OutputProcessorOutput:
         """
         Process the EngineCoreOutputs:
@@ -133,7 +134,8 @@ def process_outputs(
 
         request_outputs: List[RequestOutput] = []
         reqs_to_abort: List[str] = []
-        iteration_stats = IterationStats(self.log_stats)
+        if not iteration_stats:
+            iteration_stats = IterationStats(self.log_stats)
         for engine_core_output in engine_core_outputs:
             req_id = engine_core_output.request_id
             req_state = self.request_states.get(req_id)
@@ -175,8 +177,8 @@ def process_outputs(
             iteration_stats=iteration_stats,
         )
 
+    @staticmethod
     def _make_request_output(
-        self,
         request_state: RequestState,
         detokenizer_output: Optional[DetokenizerOutput],
     ) -> Optional[RequestOutput]:
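
Accepting an optional IterationStats lets the caller in async_llm.py carry one
stats object across all chunks of an engine-core batch rather than creating a
fresh one per chunk. The accumulation pattern, reduced to its core (this
IterationStats and process_chunk are stand-ins, not vLLM's implementations):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class IterationStats:
        num_generation_tokens: int = 0  # illustrative field only

    def process_chunk(chunk: List[int],
                      iteration_stats: Optional[IterationStats] = None
                      ) -> IterationStats:
        # Reuse the caller's stats object if given, else start a fresh one.
        if not iteration_stats:
            iteration_stats = IterationStats()
        iteration_stats.num_generation_tokens += len(chunk)
        return iteration_stats

    stats = None
    for chunk in ([1, 2, 3], [4, 5], [6]):
        stats = process_chunk(chunk, stats)
    print(stats.num_generation_tokens)  # 6: a single total across all chunks
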
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 45450165eaefe..eefcdaf29e753 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -64,6 +64,12 @@ def __init__(
         # recomputing.
         self._kv_block_hashes: List[BlockHashType] = []
 
+        # Read-only views
+        # Prevent directly appending to these lists since
+        # they should also be updated simultaneously.
+        self.output_token_ids = ConstantList(self._output_token_ids)
+        self.all_token_ids = ConstantList(self._all_token_ids)
+
     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
         return cls(
@@ -79,18 +85,6 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             lora_request=request.lora_request,
         )
 
-    @property
-    def output_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the output_token_ids since
-        # all_token_ids should also be updated simultaneously.
-        return ConstantList(self._output_token_ids)
-
-    @property
-    def all_token_ids(self) -> ConstantList[int]:
-        # Prevent directly appending to the all_token_ids since
-        # output_token_ids should also be updated simultaneously
-        return ConstantList(self._all_token_ids)
-
     def append_output_token_ids(
         self,
         token_ids: Union[int, List[int]],
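
The request.py change replaces two properties that wrapped the backing lists in
a new ConstantList on every access with read-only views built once in __init__.
Because the views hold references to the underlying lists (not copies), appends
made through append_output_token_ids remain visible. A minimal stand-in for the
idea (this simplified ConstantList and Request are illustrative, not vLLM's):

    from collections.abc import Sequence
    from typing import List

    class ConstantList(Sequence):

        def __init__(self, backing: List[int]):
            self._backing = backing  # shared reference, so appends show through

        def __getitem__(self, i):
            return self._backing[i]

        def __len__(self) -> int:
            return len(self._backing)

    class Request:

        def __init__(self, prompt_token_ids: List[int]):
            self._output_token_ids: List[int] = []
            self._all_token_ids: List[int] = list(prompt_token_ids)
            # Built once here instead of on every property access.
            self.output_token_ids = ConstantList(self._output_token_ids)
            self.all_token_ids = ConstantList(self._all_token_ids)

        def append_output_token_ids(self, token_id: int) -> None:
            # Both backing lists are updated together; the views track them.
            self._output_token_ids.append(token_id)
            self._all_token_ids.append(token_id)

    req = Request([10, 11])
    req.append_output_token_ids(12)
    print(len(req.all_token_ids), list(req.output_token_ids))  # 3 [12]
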