diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py
new file mode 100644
index 0000000000000..6dff1cfbe7f75
--- /dev/null
+++ b/tests/entrypoints/openai/test_shutdown.py
@@ -0,0 +1,47 @@
+import json
+import os
+
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+
+@pytest.mark.asyncio
+async def test_shutdown_on_engine_failure(tmp_path):
+    # Use a bad adapter to crash the engine
+    # (This test will fail when that bug is fixed)
+    adapter_path = tmp_path / "bad_adapter"
+    os.mkdir(adapter_path)
+    with open(adapter_path / "adapter_model_config.json", "w") as f:
+        json.dump({"not": "real"}, f)
+    with open(adapter_path / "adapter_model.safetensors", "wb") as f:
+        f.write(b"this is fake")
+
+    # dtype, max-len etc set so that this can run in CI
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        "--max-num-seqs",
+        "128",
+        "--enable-lora",
+        "--lora-modules",
+        f"bad-adapter={tmp_path / 'bad_adapter'}",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        client = remote_server.get_async_client()
+
+        with pytest.raises(openai.APIConnectionError):
+            # This crashes the engine
+            await client.completions.create(model="bad-adapter",
+                                            prompt="Hello, my name is")
+
+        # Now the server should shut down
+        return_code = remote_server.proc.wait(timeout=1)
+        assert return_code is not None
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index b4a9520e623ea..ef82c3dfd0b54 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task,
         error_callback(exception)
         raise AsyncEngineDeadError(
             "Task finished unexpectedly. This should never happen! "
-            "Please open an issue on Github. See stack trace above for the"
+            "Please open an issue on Github. See stack trace above for the "
             "actual cause.") from e
 
 
@@ -132,7 +132,9 @@ def propagate_exception(self,
             self._request_streams[request_id].put(exc)
             self.abort_request(request_id)
         else:
-            for rid, stream in self._request_streams.items():
+            # NB: list() used here because self.abort_request pops the stream
+            # out of self._request_streams, so we can't iterate on it directly
+            for rid, stream in list(self._request_streams.items()):
                 stream.put(exc)
                 self.abort_request(rid)
 
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index c5850b2de069e..f6e8a417b648c 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -118,6 +118,7 @@ async def run_server(args: Namespace,
 
     shutdown_task = await serve_http(
         app,
+        engine=engine,
         host=args.host,
         port=args.port,
         log_level=args.log_level,
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 00826762f76a1..8e97ae717660c 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -1,16 +1,21 @@
 import asyncio
 import signal
+from http import HTTPStatus
 from typing import Any
 
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Response
 
+from vllm import envs
+from vllm.engine.async_llm_engine import AsyncEngineDeadError
+from vllm.engine.protocol import AsyncEngineClient
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
+async def serve_http(app: FastAPI, engine: AsyncEngineClient,
+                     **uvicorn_kwargs: Any):
     logger.info("Available routes are:")
     for route in app.routes:
         methods = getattr(route, "methods", None)
@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
 
     config = uvicorn.Config(app, **uvicorn_kwargs)
     server = uvicorn.Server(config)
+    _add_shutdown_handlers(app, server, engine)
 
     loop = asyncio.get_running_loop()
 
@@ -44,3 +50,37 @@ async def dummy_shutdown() -> None:
     except asyncio.CancelledError:
         logger.info("Gracefully stopping http server")
         return server.shutdown()
+
+
+def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
+                           engine: AsyncEngineClient) -> None:
+    """Adds handlers for fatal errors that should crash the server"""
+
+    @app.exception_handler(RuntimeError)
+    async def runtime_error_handler(_, __):
+        """On generic runtime error, check to see if the engine has died.
+        It probably has, in which case the server will no longer be able to
+        handle requests. Trigger a graceful shutdown with a SIGTERM."""
+        if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
+                and not engine.is_running):
+            logger.fatal("AsyncLLMEngine has failed, terminating server "
+                         "process")
+            # See discussions here on shutting down a uvicorn server
+            # https://github.com/encode/uvicorn/discussions/1103
+            # In this case we cannot await the server shutdown here because
+            # this handler must first return to close the connection for
+            # this request.
+            server.should_exit = True
+
+        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
+
+    @app.exception_handler(AsyncEngineDeadError)
+    async def engine_dead_handler(_, __):
+        """Kill the server if the async engine is already dead. It will
+        not handle any further requests."""
+        if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
+            logger.fatal("AsyncLLMEngine is already dead, terminating server "
+                         "process")
+            server.should_exit = True
+
+        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index d44604b12fb69..1a0addfedc55f 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
 
     shutdown_task = await serve_http(
         app,
+        engine=async_engine_client,
         host=args.host,
         port=args.port,
         log_level=args.uvicorn_log_level,
diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py
index d69b202e2d1bb..64a20b33d8f3e 100644
--- a/vllm/entrypoints/openai/rpc/client.py
+++ b/vllm/entrypoints/openai/rpc/client.py
@@ -33,6 +33,7 @@ async def setup(self):
 
         # Wait until server is ready.
         await self.wait_for_server()
+        self._errored = False
 
         # Get the configs.
         self.model_config = await self._get_model_config_rpc()
@@ -169,7 +170,7 @@ async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
             expected_type=SchedulerConfig,
             error_message="Could not get SchedulerConfig from RPC Server")
 
-    async def _get_lora_config_rpc(self):
+    async def _get_lora_config_rpc(self) -> LoRAConfig:
         """Get LoRAConfig from the RPCServer"""
 
         return await self._send_get_data_rpc_request(
@@ -177,7 +178,7 @@ async def _get_lora_config_rpc(self):
             expected_type=LoRAConfig,
             error_message="Could not get LoRAConfig from RPC Server")
 
-    async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
+    async def _is_tracing_enabled_rpc(self) -> bool:
         """Get is_tracing_enabled flag from the RPCServer"""
 
         return await self._send_get_data_rpc_request(
@@ -200,6 +201,18 @@ async def do_log_stats(self):
             request=RPCUtilityRequest.DO_LOG_STATS,
             error_message="RPCRequest DO_LOG_STATS failed.")
 
+    @property
+    def is_running(self) -> bool:
+        return not self._errored
+
+    @property
+    def is_stopped(self) -> bool:
+        return self._errored
+
+    @property
+    def errored(self) -> bool:
+        return self._errored
+
     async def generate(
         self,
         inputs: PromptInputs,
@@ -233,6 +246,15 @@ async def generate(
 
                 request_output = cloudpickle.loads(message)
                 if isinstance(request_output, Exception):
+                    # On exception, check if the server is still healthy.
+                    # Use this to set the sync `is_running` and `errored`
+                    # properties.
+                    try:
+                        await self.check_health()
+                    except Exception:
+                        self._errored = True
+                    # NB: do before raising here so that the flag is set
+                    # by the time the caller receives this exception
                     raise request_output
 
                 finished = request_output.finished
diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py
index 617c9b7070e2c..c4f46de4c2229 100644
--- a/vllm/entrypoints/openai/rpc/server.py
+++ b/vllm/entrypoints/openai/rpc/server.py
@@ -96,14 +96,17 @@ async def is_server_ready(self, identity):
 
     async def abort(self, identity, request: RPCAbortRequest):
         """Abort request and notify the client of success."""
-        # Abort the request in the llm engine.
-        await self.engine.abort(request.request_id)
-
-        # Send confirmation to the client.
-        await self.socket.send_multipart([
-            identity,
-            cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
-        ])
+        try:
+            # Abort the request in the llm engine.
+            await self.engine.abort(request.request_id)
+        except Exception:
+            logger.warning("Failed to abort request %s", request.request_id)
+        finally:
+            # Send confirmation to the client.
+            await self.socket.send_multipart([
+                identity,
+                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
+            ])
 
     async def generate(self, identity, generate_request: RPCGenerateRequest):
         try:
diff --git a/vllm/envs.py b/vllm/envs.py
index 81f30b1d42a13..26d0c33707fea 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -49,6 +49,7 @@
     NVCC_THREADS: Optional[str] = None
     VLLM_USE_PRECOMPILED: bool = False
     VLLM_NO_DEPRECATION_WARNING: bool = False
+    VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
     VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
@@ -335,6 +336,11 @@ def get_default_config_root():
     "VLLM_NO_DEPRECATION_WARNING":
     lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
 
+    # If set, the OpenAI API server will stay alive even after the underlying
+    # AsyncLLMEngine errors and stops serving requests
+    "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
+    lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),
+
    # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
    # the user to specify a max sequence length greater than
    # the max length derived from the model's config.json.
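For reference, the shutdown mechanism added in `_add_shutdown_handlers` boils down to "return a 500 for the failing request, then set `uvicorn.Server.should_exit`". Below is a minimal standalone sketch of that pattern, not vLLM code: the `EngineDeadError` class and the `/generate` route are invented for illustration, and it assumes a recent FastAPI/Starlette where exception handlers can still be registered after the app object is constructed (the same assumption the handlers in this diff rely on).

```python
# Standalone sketch (not vLLM code): the "return 500, then ask uvicorn to
# exit" pattern used by _add_shutdown_handlers above.
import asyncio
from http import HTTPStatus

import uvicorn
from fastapi import FastAPI, Response

app = FastAPI()


class EngineDeadError(RuntimeError):
    """Hypothetical stand-in for vLLM's AsyncEngineDeadError."""


@app.get("/generate")
async def generate() -> Response:
    # Simulate the engine dying while serving a request.
    raise EngineDeadError("engine background loop has died")


async def main() -> None:
    config = uvicorn.Config(app, host="127.0.0.1", port=8000)
    server = uvicorn.Server(config)

    @app.exception_handler(EngineDeadError)
    async def engine_dead_handler(_, __) -> Response:
        # The handler must return first so the in-flight connection is
        # closed; setting should_exit asks uvicorn to stop afterwards.
        server.should_exit = True
        return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

    # serve() returns once should_exit has been set and shutdown completes.
    await server.serve()


if __name__ == "__main__":
    asyncio.run(main())
```

Hitting `GET /generate` once returns HTTP 500 and the process then exits cleanly, which is the behavior `test_shutdown_on_engine_failure` asserts against the real server; the real handlers additionally check `VLLM_KEEP_ALIVE_ON_ENGINE_DEATH` before deciding to shut down.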