[Frontend] Kill the server on engine death #6594

Merged 15 commits on Aug 8, 2024

47 changes: 47 additions & 0 deletions tests/entrypoints/openai/test_shutdown.py
@@ -0,0 +1,47 @@
import json
import os

import openai
import pytest

from ...utils import RemoteOpenAIServer

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.mark.asyncio
async def test_shutdown_on_engine_failure(tmp_path):
# Use a bad adapter to crash the engine
# (This test will fail when that bug is fixed)
adapter_path = tmp_path / "bad_adapter"
os.mkdir(adapter_path)
with open(adapter_path / "adapter_model_config.json", "w") as f:
json.dump({"not": "real"}, f)
with open(adapter_path / "adapter_model.safetensors", "wb") as f:
f.write(b"this is fake")

# dtype, max-len etc set so that this can run in CI
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--enforce-eager",
"--max-num-seqs",
"128",
"--enable-lora",
"--lora-modules",
f"bad-adapter={tmp_path / 'bad_adapter'}",
]

with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
client = remote_server.get_async_client()

with pytest.raises(openai.APIConnectionError):
# This crashes the engine
await client.completions.create(model="bad-adapter",
prompt="Hello, my name is")

# Now the server should shut down
return_code = remote_server.proc.wait(timeout=1)
assert return_code is not None
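
As context for the final assertion: Popen.wait(timeout=...) raises subprocess.TimeoutExpired if the child is still alive and otherwise returns its exit code, so reaching the assert already implies the server process exited. A minimal standalone illustration of that behavior, independent of RemoteOpenAIServer:

```python
import subprocess
import sys

# Child that exits immediately with a known code.
child = subprocess.Popen([sys.executable, "-c", "raise SystemExit(3)"])

# wait() would raise subprocess.TimeoutExpired if the process were still
# running; here it returns promptly with the exit code.
return_code = child.wait(timeout=5)
assert return_code == 3

# A long-running child, by contrast, trips the timeout.
sleeper = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
try:
    sleeper.wait(timeout=0.5)
except subprocess.TimeoutExpired:
    sleeper.kill()
    sleeper.wait()
```
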
6 changes: 4 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -58,7 +58,7 @@ def _log_task_completion(task: asyncio.Task,
error_callback(exception)
raise AsyncEngineDeadError(
"Task finished unexpectedly. This should never happen! "
"Please open an issue on Github. See stack trace above for the"
"Please open an issue on Github. See stack trace above for the "
Comment from the PR author (Collaborator): (This has just triggered me every time I see the log)

"actual cause.") from e


@@ -132,7 +132,9 @@ def propagate_exception(self,
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
else:
for rid, stream in self._request_streams.items():
# NB: list() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for rid, stream in list(self._request_streams.items()):
stream.put(exc)
self.abort_request(rid)

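
The list() comment above is worth spelling out: abort_request removes entries from self._request_streams, and popping from a dict while iterating over it raises a RuntimeError. A standalone sketch with simplified stand-ins (not the real vLLM classes):

```python
# Simplified stand-in for the request-stream bookkeeping in the async engine.
streams = {"req-1": object(), "req-2": object()}

def abort_request(request_id: str) -> None:
    # Like the real abort path, this pops the stream out of the dict.
    streams.pop(request_id, None)

try:
    for rid, _stream in streams.items():
        abort_request(rid)  # mutates `streams` during iteration
except RuntimeError as exc:
    print(f"direct iteration fails: {exc}")  # "dictionary changed size during iteration"

# Snapshotting the items first, as the patch does, makes the loop safe.
streams = {"req-1": object(), "req-2": object()}
for rid, _stream in list(streams.items()):
    abort_request(rid)
assert not streams
```
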
1 change: 1 addition & 0 deletions vllm/entrypoints/api_server.py
@@ -118,6 +118,7 @@ async def run_server(args: Namespace,

shutdown_task = await serve_http(
app,
engine=engine,
host=args.host,
port=args.port,
log_level=args.log_level,
44 changes: 42 additions & 2 deletions vllm/entrypoints/launcher.py
@@ -1,16 +1,21 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any

import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Response

from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):
async def serve_http(app: FastAPI, engine: AsyncEngineClient,
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
methods = getattr(route, "methods", None)
@@ -23,6 +28,7 @@ async def serve_http(app: FastAPI, **uvicorn_kwargs: Any):

config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)

loop = asyncio.get_running_loop()

@@ -44,3 +50,37 @@ async def dummy_shutdown() -> None:
except asyncio.CancelledError:
logger.info("Gracefully stopping http server")
return server.shutdown()


def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
engine: AsyncEngineClient) -> None:
"""Adds handlers for fatal errors that should crash the server"""

@app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "
"process")
# See discussions here on shutting down a uvicorn server
# https://github.com/encode/uvicorn/discussions/1103
# In this case we cannot await the server shutdown here because
# this handler must first return to close the connection for
# this request.
server.should_exit = True

return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)

@app.exception_handler(AsyncEngineDeadError)
async def engine_dead_handler(_, __):
"""Kill the server if the async engine is already dead. It will
not handle any further requests."""
if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH:
logger.fatal("AsyncLLMEngine is already dead, terminating server "
"process")
server.should_exit = True

return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR)
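
The handlers above flip uvicorn's cooperative should_exit flag instead of awaiting server.shutdown(), because the handler must return first so the in-flight response can still be sent. A minimal standalone sketch of the same pattern outside vLLM (the /boom route, port, and handler names are made up for illustration):

```python
import uvicorn
from fastapi import FastAPI, Response

app = FastAPI()
config = uvicorn.Config(app, host="127.0.0.1", port=8001)
server = uvicorn.Server(config)


@app.exception_handler(RuntimeError)
async def crash_handler(_, __):
    # Ask the serve loop to exit once this response has been sent; awaiting
    # server.shutdown() here would block on the request we are handling.
    server.should_exit = True
    return Response(status_code=500)


@app.get("/boom")
async def boom():
    raise RuntimeError("simulated engine death")


if __name__ == "__main__":
    # server.run() returns once should_exit is set and connections drain.
    server.run()
```
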
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -357,6 +357,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:

shutdown_task = await serve_http(
app,
engine=async_engine_client,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
26 changes: 24 additions & 2 deletions vllm/entrypoints/openai/rpc/client.py
@@ -33,6 +33,7 @@ async def setup(self):

# Wait until server is ready.
await self.wait_for_server()
self._errored = False

# Get the configs.
self.model_config = await self._get_model_config_rpc()
@@ -169,15 +170,15 @@ async def _get_scheduler_config_rpc(self) -> SchedulerConfig:
expected_type=SchedulerConfig,
error_message="Could not get SchedulerConfig from RPC Server")

async def _get_lora_config_rpc(self):
async def _get_lora_config_rpc(self) -> LoRAConfig:
"""Get LoRAConfig from the RPCServer"""

return await self._send_get_data_rpc_request(
RPCUtilityRequest.GET_LORA_CONFIG,
expected_type=LoRAConfig,
error_message="Could not get LoRAConfig from RPC Server")

async def _is_tracing_enabled_rpc(self) -> ParallelConfig:
async def _is_tracing_enabled_rpc(self) -> bool:
"""Get is_tracing_enabled flag from the RPCServer"""

return await self._send_get_data_rpc_request(
@@ -200,6 +201,18 @@ async def do_log_stats(self):
request=RPCUtilityRequest.DO_LOG_STATS,
error_message="RPCRequest DO_LOG_STATS failed.")

@property
def is_running(self) -> bool:
return not self._errored

@property
def is_stopped(self) -> bool:
return self._errored

@property
def errored(self) -> bool:
return self._errored

async def generate(
self,
inputs: PromptInputs,
@@ -233,6 +246,15 @@ async def generate(
request_output = cloudpickle.loads(message)

if isinstance(request_output, Exception):
# On exception, check if the server is still healthy.
# Use this to set the sync `is_running` and `errored`
# properties.
try:
await self.check_health()
except Exception:
self._errored = True
# NB: do before raising here so that the flag is set
# by the time the caller receives this exception
raise request_output

finished = request_output.finished
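
The client keeps the engine-health result in a plain boolean so that the synchronous errored/is_running properties stay cheap for callers such as the shutdown handlers in launcher.py. A condensed, self-contained sketch of the ordering concern noted in the NB comment (class and method names are illustrative, not the RPC client's real interface):

```python
import asyncio


class HealthTrackingClient:
    """Keeps a sticky `errored` flag that synchronous code can read."""

    def __init__(self) -> None:
        self._errored = False

    @property
    def errored(self) -> bool:
        return self._errored

    @property
    def is_running(self) -> bool:
        return not self._errored

    async def _probe_engine(self) -> None:
        # Stand-in for the RPC health check; assume the engine is gone.
        raise ConnectionError("engine unreachable")

    async def generate(self) -> None:
        # Simulate the server streaming an exception back for this request.
        request_output: Exception = RuntimeError("engine died mid-request")
        try:
            await self._probe_engine()
        except Exception:
            # Set the flag *before* raising so the caller already observes
            # errored == True inside its own exception handler.
            self._errored = True
        raise request_output


async def main() -> None:
    client = HealthTrackingClient()
    try:
        await client.generate()
    except RuntimeError:
        assert client.errored and not client.is_running


asyncio.run(main())
```
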
19 changes: 11 additions & 8 deletions vllm/entrypoints/openai/rpc/server.py
@@ -96,14 +96,17 @@ async def is_server_ready(self, identity):

async def abort(self, identity, request: RPCAbortRequest):
"""Abort request and notify the client of success."""
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)

# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])
try:
# Abort the request in the llm engine.
await self.engine.abort(request.request_id)
except Exception:
logger.warning("Failed to abort request %s", request.request_id)
Reviewer comment (Member): It might be good for this to be logger.exception(...) so that the stacktrace is logged. But we don't want to trigger another whole round of CI tests just for this; we can address it in follow-on cleanup of the zmq decoupling that's being done anyhow.

finally:
# Send confirmation to the client.
await self.socket.send_multipart([
identity,
cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
])

async def generate(self, identity, generate_request: RPCGenerateRequest):
try:
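
Following the review comment, the suggested follow-up is to swap logger.warning for logger.exception, which logs at ERROR level and appends the active traceback. A sketch of how the handler might then look (reusing the identifiers from the hunk above; not a patch applied in this PR):

```python
    async def abort(self, identity, request: RPCAbortRequest):
        """Abort request and notify the client of success."""
        try:
            # Abort the request in the llm engine.
            await self.engine.abort(request.request_id)
        except Exception:
            # logger.exception records the stack trace alongside the message.
            logger.exception("Failed to abort request %s", request.request_id)
        finally:
            # Send confirmation to the client.
            await self.socket.send_multipart([
                identity,
                cloudpickle.dumps(VLLM_RPC_SUCCESS_STR),
            ])
```
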
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -49,6 +49,7 @@
NVCC_THREADS: Optional[str] = None
VLLM_USE_PRECOMPILED: bool = False
VLLM_NO_DEPRECATION_WARNING: bool = False
VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False
CMAKE_BUILD_TYPE: Optional[str] = None
VERBOSE: bool = False
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
@@ -335,6 +336,11 @@ def get_default_config_root():
"VLLM_NO_DEPRECATION_WARNING":
lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),

# If set, the OpenAI API server will stay alive even after the underlying
# AsyncLLMEngine errors and stops serving requests
"VLLM_KEEP_ALIVE_ON_ENGINE_DEATH":
lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)),

# If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
# the user to specify a max sequence length greater than
# the max length derived from the model's config.json.
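
One note on how the new flag is parsed: unlike the bool(int(os.getenv(..., "0"))) pattern used by VLLM_NO_DEPRECATION_WARNING above, this lambda treats any non-empty string as truthy. A small sketch of the resulting behavior under a few assumed values:

```python
import os


def keep_alive_flag() -> bool:
    # Mirrors the lambda added in this diff: unset -> False (bool(0)),
    # empty string -> False, any non-empty string -> True.
    return bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0))


os.environ.pop("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", None)
assert keep_alive_flag() is False   # unset: default 0 is falsy

os.environ["VLLM_KEEP_ALIVE_ON_ENGINE_DEATH"] = "1"
assert keep_alive_flag() is True    # typical way to opt in

os.environ["VLLM_KEEP_ALIVE_ON_ENGINE_DEATH"] = "0"
assert keep_alive_flag() is True    # "0" is a non-empty string, still truthy
```
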