Skip to content

Commit

Permalink
Support CORS for openai api server (#481)
Browse files Browse the repository at this point in the history
* Support CORS for openai api server

* Remove unnecessary var

* Add CORS support follow the same style with vllm
  • Loading branch information
aisensiy authored Oct 9, 2023
1 parent b58a9df commit 0268414
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import os
import time
from http import HTTPStatus
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, List, Optional

import fire
import uvicorn
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from lmdeploy.serve.async_engine import AsyncEngine
Expand Down Expand Up @@ -321,7 +322,11 @@ def main(model_path: str,
server_name: str = 'localhost',
server_port: int = 23333,
instance_num: int = 32,
tp: int = 1):
tp: int = 1,
allow_origins: List[str] = ['*'],
allow_credentials: bool = True,
allow_methods: List[str] = ['*'],
allow_headers: List[str] = ['*']):
"""An example to perform model inference through the command line
interface.
Expand All @@ -331,7 +336,20 @@ def main(model_path: str,
server_port (int): server port
instance_num (int): number of instances of turbomind model
tp (int): tensor parallel
allow_origins (List[str]): a list of allowed origins for CORS
allow_credentials (bool): whether to allow credentials for CORS
allow_methods (List[str]): a list of allowed HTTP methods for CORS
allow_headers (List[str]): a list of allowed HTTP headers for CORS
"""
if allow_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=allow_origins,
allow_credentials=allow_credentials,
allow_methods=allow_methods,
allow_headers=allow_headers,
)

VariableInterface.async_engine = AsyncEngine(model_path=model_path,
instance_num=instance_num,
tp=tp)
Expand Down

0 comments on commit 0268414

Please sign in to comment.