From 080a4bf0cf98f5e8c4d91c9275ad6ba9de7facc5 Mon Sep 17 00:00:00 2001
From: TJian
Date: Wed, 22 Jan 2025 22:57:57 +0800
Subject: [PATCH] add sync_openai api_server (#365)

---
 vllm/__init__.py                           |   2 +
 vllm/entrypoints/fast_sync_llm.py          | 128 ++++++
 vllm/entrypoints/sync_openai/api_server.py | 428 +++++++++++++++++++++
 3 files changed, 558 insertions(+)
 create mode 100644 vllm/entrypoints/fast_sync_llm.py
 create mode 100644 vllm/entrypoints/sync_openai/api_server.py

diff --git a/vllm/__init__.py b/vllm/__init__.py
index 45252b93e3d54..521393fde98ec 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -4,6 +4,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
+from vllm.entrypoints.fast_sync_llm import FastSyncLLM
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
@@ -21,6 +22,7 @@
     "__version__",
     "__version_tuple__",
     "LLM",
+    "FastSyncLLM",
     "ModelRegistry",
     "PromptType",
     "TextPrompt",
diff --git a/vllm/entrypoints/fast_sync_llm.py b/vllm/entrypoints/fast_sync_llm.py
new file mode 100644
index 0000000000000..a3c1a455eac39
--- /dev/null
+++ b/vllm/entrypoints/fast_sync_llm.py
@@ -0,0 +1,128 @@
+import multiprocessing as mp
+from queue import Empty
+from typing import Union
+
+import vllm.envs as envs
+from vllm.distributed.communication_op import broadcast_tensor_dict
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor
+from vllm.executor.ray_gpu_executor import RayGPUExecutor
+from vllm.inputs import PromptType, TokensPrompt
+from vllm.logger import init_logger
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import SamplingParams
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter
+
+logger = init_logger(__name__)
+
+
+class FastSyncLLM:
+
+    def __init__(
+        self,
+        engine_args: EngineArgs,
+        input_queue: mp.Queue,
+        result_queue: mp.Queue,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        self.engine_args = engine_args
+        self.request_counter = Counter()
+
+        self.input_queue = input_queue
+        self.result_queue = result_queue
+        self.finish = False
+        self.need_restart = False
+        self.llm_engine: LLMEngine
+
+    def _add_request(
+        self,
+        inputs: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        request_id: str,
+    ) -> None:
+        if isinstance(inputs, list):
+            inputs = TokensPrompt(prompt_token_ids=inputs)
+        self.llm_engine.add_request(request_id, inputs, params)
+
+    def _poll_requests(self):
+        while True:
+            if not self.llm_engine.has_unfinished_requests():
+                logger.info("No unfinished requests. Waiting...")
+                (request_id, prompt, sampling_params) = self.input_queue.get()
+                if self.need_restart and isinstance(
+                        self.llm_engine.model_executor,
+                        MultiprocessingGPUExecutor):
+                    logger.info("Restarting worker loops")
+                    for worker in self.llm_engine.model_executor.workers:
+                        worker.execute_method("start_worker_execution_loop")
+                    self.need_restart = False
+
+            else:
+                try:
+                    (request_id, prompt,
+                     sampling_params) = self.input_queue.get_nowait()
+                except Empty:
+                    break
+            self._add_request(prompt, sampling_params, request_id)
+
+    def run_engine(self):
+        self.llm_engine = LLMEngine.from_engine_args(
+            self.engine_args, usage_context=UsageContext.LLM_CLASS)
+        assert not isinstance(
+            self.llm_engine.model_executor,
+            RayGPUExecutor), "Ray is not supported in sync openai mode"
+
+        self.result_queue.put(("Ready", None, None))
+        prompt_lens = {}
+        tokens = {}  # type: ignore
+        log_interval = 100
+        poll_interval = envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
+        try:
+            while True:
+                poll_interval -= 1
+                if (self.input_queue.qsize() >=
+                        envs.VLLM_SYNC_SERVER_ACCUM_REQUESTS
+                        or poll_interval <= 0
+                        or not self.llm_engine.has_unfinished_requests()):
+                    self._poll_requests()
+                    poll_interval = \
+                        envs.VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS
+                step_outputs = self.llm_engine.step()
+                log_interval -= 1
+                if log_interval == 0:
+                    log_interval = 100
+                    logger.info("Step finished. Unfinished requests: %d",
+                                self.llm_engine.get_num_unfinished_requests())
+                if not self.llm_engine.has_unfinished_requests():
+                    logger.info("Broadcast stop")
+                    broadcast_tensor_dict({}, src=0)
+                    self.need_restart = True
+                for output in step_outputs:
+                    assert len(output.outputs) == 1  # type: ignore
+                    first_out = output.outputs[0]  # type: ignore
+                    stats = None
+                    result = first_out.text
+                    tokens[output.request_id] = tokens.get(
+                        output.request_id, 0) + len(first_out.token_ids)
+                    if output.prompt_token_ids is not None:
+                        prompt_lens[output.request_id] = len(
+                            output.prompt_token_ids)
+                    if output.finished:
+                        assert output.request_id in prompt_lens
+                        stats = {
+                            "prompt": prompt_lens[output.request_id],
+                            "tokens": tokens[output.request_id],
+                            "finish_reason": first_out.finish_reason,
+                            "stop_reason": first_out.stop_reason,
+                        }
+                        del prompt_lens[output.request_id]
+                    self.result_queue.put_nowait(
+                        (output.request_id, result, stats))
+
+        except Exception as e:
+            logger.error("Error in run_engine: %s", e)
+            raise e
diff --git a/vllm/entrypoints/sync_openai/api_server.py b/vllm/entrypoints/sync_openai/api_server.py
new file mode 100644
index 0000000000000..1995e71a3ec44
--- /dev/null
+++ b/vllm/entrypoints/sync_openai/api_server.py
@@ -0,0 +1,428 @@
+import asyncio
+import multiprocessing
+import re
+import threading
+import time
+from contextlib import asynccontextmanager
+from http import HTTPStatus
+from typing import Dict, List, Optional, Union
+
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse, StreamingResponse
+from fastapi.routing import Mount
+from prometheus_client import make_asgi_app
+
+import vllm
+import vllm.envs as envs
+from vllm import FastSyncLLM as LLM
+from vllm.config import VllmConfig
+from vllm.engine.arg_utils import EngineArgs
+from vllm.entrypoints.chat_utils import (MultiModalItemTracker,
+                                         _parse_chat_message_content,
+                                         load_chat_template,
+                                         resolve_chat_template_content_format)
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, CompletionRequest,
+    CompletionResponse, CompletionResponseChoice,
+    CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
+    ErrorResponse, ModelCard, ModelList, ModelPermission, UsageInfo)
+from vllm.entrypoints.openai.serving_chat import ConversationMessage
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils import FlexibleArgumentParser, random_uuid
+
+mp = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
+
+logger = init_logger("api_server.py")
+
+
+def put_in_queue(queue, item, loop):
+    try:
+        asyncio.run_coroutine_threadsafe(queue.put(item), loop)
+    except Exception as e:
+        logger.error("Exception in put_in_queue: %s", e)
+        raise e
+
+
+class BackgroundRunner:
+
+    def __init__(self):
+        self.value = 0
+        self.engine_args: EngineArgs
+        self.engine_config: VllmConfig
+        self.input_queue: multiprocessing.Queue = mp.Queue()
+        self.result_queue: multiprocessing.Queue = mp.Queue()
+        self.result_queues: Dict[str, asyncio.Queue] = {}
+        self.t: threading.Thread = threading.Thread(target=self.thread_proc)
+        self.loop = None
+        self.llm: LLM
+        self.proc: multiprocessing.Process
+        self.tokenizer = None
+        self.response_role: str
+        self.chat_template: Optional[str]
+        self.chat_template_content_format = "auto"
+
+    def set_response_role(self, role):
+        self.response_role = role
+
+    def set_engine_args(self, engine_args):
+        self.engine_args = engine_args
+
+    def add_result_queue(self, id, queue):
+        self.result_queues[id] = queue
+
+    def remove_result_queues(self, ids):
+        for id in ids:
+            assert id in self.result_queues
+            del self.result_queues[id]
+        logger.debug("Removed result queue from %d ids. %d remaining",
+                     len(ids), len(self.result_queues))
+
+    def thread_proc(self):
+        while True:
+            req_id, result, stats = self.result_queue.get()
+            put_in_queue(self.result_queues[req_id], (req_id, result, stats),
+                         self.loop)
+
+    async def run_main(self):
+        self.llm = LLM(
+            engine_args=self.engine_args,
+            input_queue=self.input_queue,
+            result_queue=self.result_queue,
+        )
+
+        self.loop = asyncio.get_event_loop()
+        self.proc = mp.Process(  # type: ignore[attr-defined]
+            target=self.llm.run_engine)
+        self.t.start()
+        self.proc.start()
+
+    async def add_request(self, prompt, sampling_params):
+        result_queue: asyncio.Queue = asyncio.Queue()
+        ids = []
+        if isinstance(prompt, str) or (isinstance(prompt, list)
+                                       and isinstance(prompt[0], int)):
+            prompt = [prompt]
+        for p in prompt:
+            id = random_uuid()
+            self.add_result_queue(id, result_queue)
+            self.input_queue.put_nowait((id, p, sampling_params))
+            ids.append(id)
+        return ids, result_queue
+
+
+runner = BackgroundRunner()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    runner.result_queues["Ready"] = asyncio.Queue()
+    asyncio.create_task(runner.run_main())
+    await runner.result_queues["Ready"].get()
+    del runner.result_queues["Ready"]
+    runner.engine_config = runner.engine_args.create_engine_config()
+
+    tokenizer = get_tokenizer(
+        engine_args.tokenizer,
+        tokenizer_mode=engine_args.tokenizer_mode,
+        tokenizer_revision=engine_args.tokenizer_revision,
+        trust_remote_code=engine_args.trust_remote_code,
+        truncation_side="left")
+    runner.tokenizer = tokenizer
+    yield
+
+
+app = FastAPI(lifespan=lifespan)
+
+# Add prometheus asgi middleware to route /metrics requests
+route = Mount("/metrics", make_asgi_app())
+# Workaround for 307 Redirect for /metrics
+route.path_regex = re.compile('^/metrics(?P<path>.*)$')
+app.routes.append(route)
+
+
+@app.get("/v1/models")
+async def show_available_models():
+    models = [
+        ModelCard(id=runner.engine_args.model,
+                  root=runner.engine_args.model,
+                  permission=[ModelPermission()])
+    ]
+    model_list = ModelList(data=models)
+    return JSONResponse(content=model_list.model_dump())
+
+
+@app.get("/version")
+async def show_version():
+    ver = {"version": vllm.__version__}
+    return JSONResponse(content=ver)
+
+
+async def _check_model(request: Union[CompletionRequest,
+                                      ChatCompletionRequest]):
+    model = request.model
+    if model != runner.engine_args.model:
+        return ErrorResponse(message=f"The model {model} does not exist.",
+                             type="NotFoundError",
+                             code=HTTPStatus.NOT_FOUND)
+    return None
+
+
+async def completion_generator(model, result_queue, choices, created_time,
+                               ids):
+    completed = 0
+    try:
+        while True:
+            request_id, token, stats = await result_queue.get()
+
+            choice_idx = choices[request_id]
+            res = CompletionStreamResponse(id=request_id,
+                                           created=created_time,
+                                           model=model,
+                                           choices=[
+                                               CompletionResponseStreamChoice(
+                                                   index=choice_idx,
+                                                   text=token,
+                                                   logprobs=None,
+                                                   finish_reason=None,
+                                                   stop_reason=None)
+                                           ],
+                                           usage=None)
+            if stats is not None:
+                res.usage = UsageInfo()
+                res.usage.completion_tokens = stats.get("tokens", 0)
+                res.usage.prompt_tokens = stats.get("prompt", 0)
+                res.usage.total_tokens = (
+                    res.usage.completion_tokens +  # type: ignore
+                    res.usage.prompt_tokens)
+                res.choices[0].finish_reason = stats["finish_reason"]
+                res.choices[0].stop_reason = stats["stop_reason"]
+                completed += 1
+            response_json = res.model_dump_json(exclude_unset=True)
+            yield f"data: {response_json}\n\n"
+            if completed == len(choices):
+                runner.remove_result_queues(ids)
+                break
+
+        yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error("Error in completion_generator: %s", e)
+        return
+
+
+@app.post("/v1/completions")
+async def completions(request: CompletionRequest, raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
+
+    # Build default sampling params
+    default_sampling_params = (
+        runner.engine_config.model_config.get_diff_sampling_param())
+    sampling_params = request.to_sampling_params(
+        default_max_tokens=runner.engine_config.model_config.max_model_len,
+        logits_processor_pattern=runner.engine_config.model_config.
+        logits_processor_pattern,
+        default_sampling_params=default_sampling_params
+        # TODO: gshtras add - len(prompt_inputs["prompt_token_ids"])
+    )
+    ids, result_queue = await runner.add_request(request.prompt,
+                                                 sampling_params)
+    res = CompletionResponse(model=request.model,
+                             choices=[],
+                             usage=UsageInfo(prompt_tokens=0,
+                                             total_tokens=0,
+                                             completion_tokens=0))
+    choices = {}
+    for i, id in enumerate(ids):
+        res.choices.append(
+            CompletionResponseChoice(index=i,
+                                     text="",
+                                     finish_reason=None,
+                                     stop_reason=None))
+        choices[id] = i
+    completed = 0
+    if request.stream:
+        created_time = int(time.time())
+        return StreamingResponse(content=completion_generator(
+            request.model, result_queue, choices, created_time, ids),
+                                 media_type="text/event-stream")
+    while True:
+        request_id, token, stats = await result_queue.get()
+        choice_idx = choices[request_id]
+        res.choices[choice_idx].text += str(token)
+        if stats is not None:
+            res.usage.completion_tokens += stats["tokens"]  # type: ignore
+            res.usage.prompt_tokens += stats["prompt"]  # type: ignore
+            res.choices[choice_idx].finish_reason = stats["finish_reason"]
+            res.choices[choice_idx].stop_reason = stats["stop_reason"]
+            completed += 1
+            if completed == len(ids):
+                runner.remove_result_queues(ids)
+                break
+        continue
+    res.usage.total_tokens = (  # type: ignore
+        res.usage.completion_tokens + res.usage.prompt_tokens)  # type: ignore
+    return res
+
+
+async def chat_completion_generator(model, result_queue, created_time, id):
+    try:
+        first_token = ChatCompletionStreamResponse(
+            id=id,
+            created=created_time,
+            model=model,
+            choices=[
+                ChatCompletionResponseStreamChoice(
+                    index=0,
+                    delta=DeltaMessage(role=runner.response_role),
+                    logprobs=None,
+                    finish_reason=None,
+                    stop_reason=None)
+            ],
+            usage=None)
+        response_json = first_token.model_dump_json(exclude_unset=True)
+        yield f"data: {response_json}\n\n"
+
+        while True:
+            request_id, token, stats = await result_queue.get()
+            assert request_id == id
+
+            res = ChatCompletionStreamResponse(
+                id=request_id,
+                created=created_time,
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=token),
+                        logprobs=None,
+                        finish_reason=None,
+                        stop_reason=None)
+                ],
+                usage=None)
+            if stats is not None:
+                res.usage = UsageInfo()
+                res.usage.completion_tokens = stats.get("tokens", 0)
+                res.usage.prompt_tokens = stats.get("prompt", 0)
+                res.usage.total_tokens = (
+                    res.usage.completion_tokens +  # type: ignore
+                    res.usage.prompt_tokens)
+                res.choices[0].finish_reason = stats["finish_reason"]
+                res.choices[0].stop_reason = stats["stop_reason"]
+            response_json = res.model_dump_json(exclude_unset=True)
+            yield f"data: {response_json}\n\n"
+            if stats is not None:
+                runner.remove_result_queues([id])
+                break
+
+        yield "data: [DONE]\n\n"
+    except Exception as e:
+        logger.error("Error in completion_generator: %s", e)
+        return
+
+
+@app.post("/v1/chat/completions")
+async def chat_completions(request: ChatCompletionRequest,
+                           raw_request: Request):
+    error_check_ret = await _check_model(request)
+    if error_check_ret is not None:
+        return JSONResponse(content=error_check_ret.model_dump(),
+                            status_code=error_check_ret.code)
+
+    default_sampling_params = (
+        runner.engine_config.model_config.get_diff_sampling_param())
+    sampling_params = request.to_sampling_params(
+        default_max_tokens=runner.engine_config.model_config.max_model_len,
+        logits_processor_pattern=runner.engine_config.model_config.
+        logits_processor_pattern,
+        default_sampling_params=default_sampling_params
+        # TODO: gshtras add len(prompt_inputs["prompt_token_ids"])
+    )
+    conversation: List[ConversationMessage] = []
+
+    res = ChatCompletionResponse(model=request.model,
+                                 choices=[],
+                                 usage=UsageInfo(prompt_tokens=0,
+                                                 total_tokens=0,
+                                                 completion_tokens=0))
+
+    mm_tracker = MultiModalItemTracker(runner.engine_config.model_config,
+                                       runner.tokenizer)
+    chat_template = request.chat_template or runner.chat_template
+    content_format = resolve_chat_template_content_format(
+        chat_template, runner.chat_template_content_format, runner.tokenizer)
+    for msg in request.messages:
+        parsed_msg = _parse_chat_message_content(msg, mm_tracker,
+                                                 content_format)
+        conversation.extend(parsed_msg)
+
+    prompt = runner.tokenizer.apply_chat_template(  # type: ignore
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=False,
+        add_generation_prompt=request.add_generation_prompt,
+    )
+
+    ids, result_queue = await runner.add_request(prompt, sampling_params)
+    assert len(ids) == 1
+
+    if request.stream:
+        created_time = int(time.time())
+        return StreamingResponse(content=chat_completion_generator(
+            request.model, result_queue, created_time, ids[0]),
+                                 media_type="text/event-stream")
+
+    res.choices.append(
+        ChatCompletionResponseChoice(
+            index=0,
+            message=ChatMessage(role=runner.response_role, content=""),
+            finish_reason=None,
+            stop_reason=None))
+
+    while True:
+        _, token, stats = await result_queue.get()
+        assert res.choices[0].message.content is not None
+        res.choices[0].message.content += str(token)
+        if stats is not None:
+            res.usage.completion_tokens += stats["tokens"]  # type: ignore
+            res.usage.prompt_tokens += stats["prompt"]  # type: ignore
+            res.choices[0].finish_reason = stats["finish_reason"]
+            res.choices[0].stop_reason = stats["stop_reason"]
+            runner.remove_result_queues(ids)
+            break
+    res.usage.total_tokens = (  # type: ignore
+        res.usage.completion_tokens + res.usage.prompt_tokens)  # type: ignore
+    return res
+
+
+def parse_args():
+    parser = FlexibleArgumentParser(
+        description="vLLM OpenAI-Compatible RESTful API server.")
+    parser = make_arg_parser(parser)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    engine_args = EngineArgs.from_cli_args(args)
+    runner.set_engine_args(engine_args)
+    runner.set_response_role(args.response_role)
+    runner.chat_template = load_chat_template(args.chat_template)
+    runner.chat_template_content_format = args.chat_template_content_format
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=args.allowed_origins,
+        allow_credentials=args.allow_credentials,
+        allow_methods=args.allowed_methods,
+        allow_headers=args.allowed_headers,
+    )
+
+    uvicorn.run(app, port=args.port, host=args.host)
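
Usage note: the patch wires FastSyncLLM into the sync OpenAI server through a pair of multiprocessing queues. The sketch below is a minimal, illustrative way to drive FastSyncLLM directly over those queues, mirroring what sync_openai/api_server.py does internally; the model name and request id are placeholder assumptions, not part of the patch.

# Minimal sketch: drive FastSyncLLM through its input/result queues,
# mirroring vllm/entrypoints/sync_openai/api_server.py.
# The model name and request id are placeholders (assumptions).
import multiprocessing

import vllm.envs as envs
from vllm import FastSyncLLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs

if __name__ == "__main__":
    ctx = multiprocessing.get_context(envs.VLLM_WORKER_MULTIPROC_METHOD)
    input_queue = ctx.Queue()
    result_queue = ctx.Queue()
    llm = FastSyncLLM(
        engine_args=EngineArgs(model="facebook/opt-125m"),  # placeholder model
        input_queue=input_queue,
        result_queue=result_queue,
    )
    proc = ctx.Process(target=llm.run_engine)
    proc.start()
    result_queue.get()  # wait for the ("Ready", None, None) handshake
    input_queue.put(("req-0", "Hello, my name is",
                     SamplingParams(max_tokens=16)))
    while True:
        request_id, text, stats = result_queue.get()
        print(request_id, text)
        if stats is not None:  # stats accompany the final chunk of a request
            break
    proc.terminate()
    proc.join()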