From 026841447cd2e143b6d5d2fa9621b3bd1d975e25 Mon Sep 17 00:00:00 2001 From: aisensiy Date: Mon, 9 Oct 2023 10:59:22 +0800 Subject: [PATCH] Support CORS for openai api server (#481) * Support CORS for openai api server * Remove unnecessary var * Add CORS support follow the same style with vllm --- lmdeploy/serve/openai/api_server.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e1af990a5e..647c36609c 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -3,11 +3,12 @@ import os import time from http import HTTPStatus -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, List, Optional import fire import uvicorn from fastapi import BackgroundTasks, FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from lmdeploy.serve.async_engine import AsyncEngine @@ -321,7 +322,11 @@ def main(model_path: str, server_name: str = 'localhost', server_port: int = 23333, instance_num: int = 32, - tp: int = 1): + tp: int = 1, + allow_origins: List[str] = ['*'], + allow_credentials: bool = True, + allow_methods: List[str] = ['*'], + allow_headers: List[str] = ['*']): """An example to perform model inference through the command line interface. @@ -331,7 +336,20 @@ def main(model_path: str, server_port (int): server port instance_num (int): number of instances of turbomind model tp (int): tensor parallel + allow_origins (List[str]): a list of allowed origins for CORS + allow_credentials (bool): whether to allow credentials for CORS + allow_methods (List[str]): a list of allowed HTTP methods for CORS + allow_headers (List[str]): a list of allowed HTTP headers for CORS """ + if allow_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=allow_origins, + allow_credentials=allow_credentials, + allow_methods=allow_methods, + allow_headers=allow_headers, + ) + VariableInterface.async_engine = AsyncEngine(model_path=model_path, instance_num=instance_num, tp=tp)