Skip to content

Commit

Permalink
feat: implement slim openai api fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
Avram Tudor committed Feb 26, 2025
1 parent dcc5601 commit 1c49115
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 91 deletions.
128 changes: 64 additions & 64 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ flashrank = "^0.2.10"
langchain = "0.3.17"
langchain-community = "^0.3.16"
langchain-huggingface = "^0.1.2"
langchain-openai = "0.2.10"
langchain-openai = "0.3.7"
oci = "^2.144.0"
prometheus-client = "0.21.0"
prometheus-fastapi-instrumentator = "7.0.0"
Expand Down
4 changes: 2 additions & 2 deletions requirements-vllm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ jsonpointer==3.0.0 ; python_version >= "3.11" and python_version < "3.12"
jsonschema-specifications==2024.10.1 ; python_version >= "3.11" and python_version < "3.12"
jsonschema==4.23.0 ; python_version >= "3.11" and python_version < "3.12"
langchain-community==0.3.16 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.34 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.40 ; python_version >= "3.11" and python_version < "3.12"
langchain-huggingface==0.1.2 ; python_version >= "3.11" and python_version < "3.12"
langchain-openai==0.2.10 ; python_version >= "3.11" and python_version < "3.12"
langchain-openai==0.3.7 ; python_version >= "3.11" and python_version < "3.12"
langchain-text-splitters==0.3.6 ; python_version >= "3.11" and python_version < "3.12"
langchain==0.3.17 ; python_version >= "3.11" and python_version < "3.12"
langsmith==0.3.8 ; python_version >= "3.11" and python_version < "3.12"
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ joblib==1.4.2 ; python_version >= "3.11" and python_version < "3.12"
jsonpatch==1.33 ; python_version >= "3.11" and python_version < "3.12"
jsonpointer==3.0.0 ; python_version >= "3.11" and python_version < "3.12"
langchain-community==0.3.16 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.34 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.40 ; python_version >= "3.11" and python_version < "3.12"
langchain-huggingface==0.1.2 ; python_version >= "3.11" and python_version < "3.12"
langchain-openai==0.2.10 ; python_version >= "3.11" and python_version < "3.12"
langchain-openai==0.3.7 ; python_version >= "3.11" and python_version < "3.12"
langchain-text-splitters==0.3.6 ; python_version >= "3.11" and python_version < "3.12"
langchain==0.3.17 ; python_version >= "3.11" and python_version < "3.12"
langsmith==0.3.8 ; python_version >= "3.11" and python_version < "3.12"
Expand Down
46 changes: 25 additions & 21 deletions skynet/modules/ttt/openai_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skynet.auth.bearer import JWTBearer
from skynet.env import bypass_auth, llama_n_ctx, llama_path, openai_api_base_url, openai_api_port, use_oci, use_vllm
from skynet.logs import get_logger
from skynet.modules.ttt.openai_api.slim_router import router as slim_router
from skynet.utils import create_app, dependencies, responses

log = get_logger(__name__)
Expand All @@ -20,6 +21,7 @@

def initialize():
if not use_vllm:
app.include_router(slim_router)
return

from vllm.entrypoints.openai.api_server import router as vllm_router
Expand Down Expand Up @@ -72,29 +74,31 @@ async def is_ready():

bearer = JWTBearer()

if use_vllm:

@app.middleware('http')
async def proxy_middleware(request: Request, call_next):
if request.url.path in whitelisted_routes:
return await call_next(request)
@app.middleware('http')
async def proxy_middleware(request: Request, call_next):
if request.url.path in whitelisted_routes:
return await call_next(request)

if not bypass_auth:
try:
await bearer.__call__(request)
except HTTPException as e:
return JSONResponse(content=responses.get(e.status_code), status_code=e.status_code)

if not bypass_auth:
try:
await bearer.__call__(request)
url = f'{openai_api_base_url}{request.url.path.replace("/openai", "")}'
response = await http_client.request(
request.method, url, headers=request.headers, data=await request.body()
)

return StreamingResponse(response.content, status_code=response.status, headers=response.headers)
except ClientConnectorError as e:
return JSONResponse(content=str(e), status_code=500)
except HTTPException as e:
return JSONResponse(content=responses.get(e.status_code), status_code=e.status_code)

try:
url = f'{openai_api_base_url}{request.url.path.replace("/openai", "")}'
response = await http_client.request(request.method, url, headers=request.headers, data=await request.body())

return StreamingResponse(response.content, status_code=response.status, headers=response.headers)
except ClientConnectorError as e:
return JSONResponse(content=str(e), status_code=500)
except HTTPException as e:
return JSONResponse(content=e.detail, status_code=e.status_code)
except Exception as e:
return JSONResponse(content=str(e), status_code=500)

return JSONResponse(content=e.detail, status_code=e.status_code)
except Exception as e:
return JSONResponse(content=str(e), status_code=500)

__all__ = ['app', 'initialize', 'is_ready']
__all__ = ['app', 'initialize', 'is_ready']
38 changes: 38 additions & 0 deletions skynet/modules/ttt/openai_api/slim_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import List, Optional

from fastapi import Request
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from skynet.logs import get_logger
from skynet.modules.ttt.processor import process_chat_completion
from skynet.utils import get_customer_id, get_router

router = get_router()
log = get_logger(__name__)


class ChatMessage(BaseModel):
    """A single chat message in an OpenAI-style completion response.

    Only the text content is modeled; the slim API returns plain text, so
    fields like `role` or `tool_calls` are intentionally omitted.
    """

    # Generated text of the message; None when the model produced no output.
    content: Optional[str] = None


class ChatCompletionResponseChoice(BaseModel):
    """One choice entry in a chat completion response (OpenAI response shape)."""

    # The message produced for this choice.
    message: ChatMessage


class ChatCompletionResponse(BaseModel):
    """Top-level chat completion response, mirroring the OpenAI API envelope.

    The slim router always emits exactly one choice, but the field is a list
    to stay wire-compatible with OpenAI clients.
    """

    # Completion choices; this implementation populates a single entry.
    choices: List[ChatCompletionResponseChoice]


class ChatCompletionRequest(BaseModel):
    """Incoming chat completion request (subset of the OpenAI request schema).

    Only the fields the slim fallback actually consumes are declared; any
    other fields sent by a client are handled by pydantic's default behavior.
    """

    # Optional cap on generated tokens, forwarded to the model as a kwarg.
    max_completion_tokens: Optional[int] = None
    # OpenAI-style message dicts (role/content), passed through to the LLM.
    messages: List[ChatCompletionMessageParam]


@router.post('/v1/chat/completions')
async def create_chat_completion(chat_request: ChatCompletionRequest, request: Request):
    """Minimal OpenAI-compatible `/v1/chat/completions` endpoint.

    Resolves the customer from the incoming request, delegates generation to
    `process_chat_completion`, and wraps the returned text in a single-choice
    response matching the OpenAI response envelope.

    Args:
        chat_request: Parsed request body (messages plus optional token cap).
        request: Raw FastAPI request, used only to extract the customer id.

    Returns:
        ChatCompletionResponse with exactly one choice containing the reply.
    """
    customer_id = get_customer_id(request)

    completion_text = await process_chat_completion(
        chat_request.messages,
        customer_id,
        max_completion_tokens=chat_request.max_completion_tokens,
    )

    choice = ChatCompletionResponseChoice(message=ChatMessage(content=completion_text))

    return ChatCompletionResponse(choices=[choice])
14 changes: 13 additions & 1 deletion skynet/modules/ttt/processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from operator import itemgetter
from typing import Optional
from typing import List, Optional

from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import ChatPromptTemplate
Expand All @@ -9,6 +9,7 @@
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.output_parsers import StrOutputParser
from openai.types.chat import ChatCompletionMessageParam

from skynet.env import llama_n_ctx
from skynet.logs import get_logger
Expand Down Expand Up @@ -180,3 +181,14 @@ async def process(payload: DocumentPayload, job_type: JobType, customer_id: str
raise ValueError(f'Invalid job type {job_type}')

return result


async def process_chat_completion(
    messages: List[ChatCompletionMessageParam], customer_id: Optional[str] = None, **model_kwargs
) -> str:
    """Run raw chat messages through the customer's selected LLM.

    Args:
        messages: OpenAI-style chat messages to send to the model.
        customer_id: Optional customer identifier used by `LLMSelector` to
            pick the backing model.
        **model_kwargs: Extra model options forwarded to the selector
            (e.g. ``max_completion_tokens``).

    Returns:
        The model's reply rendered as plain text.
    """
    model = LLMSelector.select(customer_id, **model_kwargs)

    # Pipe the model output through a string parser so callers get text,
    # not a message object.
    reply = await (model | StrOutputParser()).ainvoke(messages)

    return reply

0 comments on commit 1c49115

Please sign in to comment.