Skip to content

Commit f6deb00

Browse files
committed
implement streaming parameter
1 parent 1dee9c4 commit f6deb00

File tree

4 files changed

+59
-10
lines changed

4 files changed

+59
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ python-multipart = "^0.0.6"
2121
python-dotenv = "^1.0.0"
2222
faker = "^18.11.1"
2323
requests = "^2.31.0"
24+
openai = "^0.27.8"
2425

2526
[tool.poetry.group.dev.dependencies]
2627
mypy = "^1.4.0"
2728
black = "^23.3.0"
2829
isort = "^5.12.0"
2930
pytest = "^7.3.2"
30-
openai = "^0.27.8"
3131
flake8 = "^6.0.0"
3232
types-python-jose = "^3.3.4.7"
3333
types-passlib = "^1.7.7.12"

vector_embedding_server/openai_like_api_models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ class Usage(BaseModel):
1515
total_tokens: int
1616

1717

18+
class CompletionUsage(Usage):
19+
completion_tokens: int
20+
21+
1822
class EmbeddingResponse(BaseModel):
1923
object: str
2024
data: list[EmbeddingData]
@@ -85,4 +89,4 @@ class ChatCompletionResponse(BaseModel):
8589
created: int
8690

8791
choices: list[Choice]
88-
usage: Usage
92+
usage: CompletionUsage

vector_embedding_server/server.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import json
22
import os
33
from pathlib import Path
4+
from typing import Iterator, cast
45

5-
import requests
6+
import openai
67
from dotenv import load_dotenv
78
from fastapi import Depends, FastAPI, Request
8-
from fastapi.responses import HTMLResponse
9+
from fastapi.responses import HTMLResponse, StreamingResponse
910
from fastapi.templating import Jinja2Templates
1011

1112
from vector_embedding_server.auth import (
@@ -23,6 +24,7 @@
2324
EmbeddingResponse,
2425
Usage,
2526
)
27+
from vector_embedding_server.streaming_models import ChatCompletionStreamingResponse
2628

2729
load_dotenv()
2830

@@ -31,6 +33,9 @@
3133
HASHED_PASSWORD = os.environ["HASHED_PASSWORD"]
3234
LANGUAGE_MODEL_SERVER = os.environ["LANGUAGE_MODEL_SERVER"]
3335

36+
openai.api_base = f"{LANGUAGE_MODEL_SERVER}/v1"
37+
openai.api_key = "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"
38+
3439

3540
FAKE_USERS_DB = {
3641
USERNAME: User(
@@ -84,16 +89,30 @@ async def create_embedding(
8489

8590

8691
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
87-
def chat_completion_proxy(
92+
async def chat_completion_proxy(
8893
chat_completion_input: ChatCompletionInput,
8994
current_user: str = Depends(get_current_user_wrapper(FAKE_USERS_DB)),
9095
) -> ChatCompletionResponse:
91-
response = requests.post(
92-
url=f"{LANGUAGE_MODEL_SERVER}/v1/chat/completions",
93-
json=json.loads(chat_completion_input.json()),
96+
response = openai.ChatCompletion.create( # type: ignore
97+
**json.loads(chat_completion_input.json())
98+
)
99+
if not chat_completion_input.stream:
100+
return ChatCompletionResponse(**response)
101+
102+
def event_stream() -> Iterator[bytes]:
103+
for chunk in response:
104+
resp = ChatCompletionStreamingResponse(**chunk)
105+
if resp.choices[0].finish_reason is None:
106+
yield ("data: " + resp.json() + "\r\n\r\n").encode("utf-8")
107+
else:
108+
yield ("data: " + resp.json() + "\r\n\r\ndata: [DONE]\r\n\r\n").encode(
109+
"utf-8"
110+
)
111+
112+
return cast(
113+
ChatCompletionResponse,
114+
StreamingResponse(event_stream(), media_type="text/event-stream"),
94115
)
95-
response.raise_for_status()
96-
return ChatCompletionResponse.parse_obj(response.json())
97116

98117

99118
@app.get("/docs", response_class=HTMLResponse)
vector_embedding_server/streaming_models.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel
4+
5+
from .openai_like_api_models import CompletionUsage, MessageRole
6+
7+
8+
class StreamingMessage(BaseModel):
9+
role: Optional[MessageRole]
10+
content: str
11+
12+
13+
class StreamingChoice(BaseModel):
14+
index: int
15+
message: StreamingMessage
16+
finish_reason: Optional[str]
17+
delta: StreamingMessage
18+
19+
20+
class ChatCompletionStreamingResponse(BaseModel):
21+
id: str
22+
object: str
23+
created: int
24+
25+
choices: list[StreamingChoice]
26+
usage: Optional[CompletionUsage]

0 commit comments

Comments (0)