Skip to content

Commit

Permalink
Fixed the bug where the Claude role could not be obtained and the SSE…
Browse files Browse the repository at this point in the history
… format was incorrect.
  • Loading branch information
yym68686 committed Jul 9, 2024
1 parent 8866240 commit 52bcfe4
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 16 deletions.
12 changes: 12 additions & 0 deletions json_str/gpt/mess_sse.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"!"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" How"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" assist"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":" today"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}],"usage":null}
data: {"id":"chatcmpl-9j98vC0GPtmMAdOsSgh1TGhFQAsZC","object":"chat.completion.chunk","created":1720546933,"model":"gpt-4o-2024-05-13","system_fingerprint":"fp_ce0793330f","choices":[],"usage":{"prompt_tokens":178,"completion_tokens":10,"total_tokens":188}}
49 changes: 38 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import json
import httpx
import logging
import yaml
import traceback
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Depends
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.responses import StreamingResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

Expand Down Expand Up @@ -48,7 +49,16 @@ def load_config():
return []

config = load_config()
# print(config)
# Normalize every provider's 'model' list into a dict. The dict key is the
# model name looked up from the incoming request; the value is the model name
# that is then passed on for that provider.
for index, provider in enumerate(config):
    model_dict = {}
    for model in provider['model']:
        # isinstance (rather than an exact type() comparison) so that str/dict
        # subclasses are normalized too instead of being silently dropped.
        if isinstance(model, str):
            # Plain string entry: the name maps to itself.
            model_dict[model] = model
        elif isinstance(model, dict):
            # Mapping entry {name: alias}: inverted, so the alias becomes the
            # lookup key and the original name becomes the value.
            model_dict.update({value: key for key, value in model.items()})
    provider['model'] = model_dict
    # Store the normalized provider back at its position in config.
    config[index] = provider
# print(json.dumps(config, indent=4, ensure_ascii=False))

async def process_request(request: RequestModel, provider: Dict):
print("provider: ", provider['provider'])
Expand All @@ -64,15 +74,16 @@ async def process_request(request: RequestModel, provider: Dict):

url, headers, payload = await get_payload(request, engine, provider)

request_info = {
"url": url,
"headers": headers,
"payload": payload
}
print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")
# request_info = {
# "url": url,
# "headers": headers,
# "payload": payload
# }
# print(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")

if request.stream:
return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, request.model), media_type="text/event-stream")
model = provider['model'][request.model]
return StreamingResponse(fetch_response_stream(app.state.client, url, headers, payload, engine, model), media_type="text/event-stream")
else:
return await fetch_response(app.state.client, url, headers, payload)

Expand All @@ -81,7 +92,11 @@ def __init__(self):
self.last_provider_index = -1

def get_matching_providers(self, model_name):
return [provider for provider in config if model_name in provider['model']]
# for provider in config:
# print("provider", model_name, list(provider['model'].keys()))
# if model_name in provider['model'].keys():
# print("provider", provider)
return [provider for provider in config if model_name in provider['model'].keys()]

async def request_model(self, request: RequestModel, token: str):
model_name = request.model
Expand Down Expand Up @@ -122,6 +137,18 @@ async def try_all_providers(self, request: RequestModel, providers: List[Dict],

model_handler = ModelRequestHandler()

@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each incoming request, then hand it to the next handler.

    The method and URL are logged for every request. For POST, PUT and PATCH
    requests the body is read and logged as well.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request and returns the response.

    Returns:
        The response produced by ``call_next``, unmodified.
    """
    # Log the request line.
    logging.info("Request: %s %s", request.method, request.url)
    # Log the request body (if any).
    if request.method in ["POST", "PUT", "PATCH"]:
        # NOTE(review): the body is consumed here before call_next; confirm
        # the installed framework version still exposes it to the route.
        body = await request.body()
        # errors="replace": a body that is not valid UTF-8 must not raise
        # inside the logging path and fail the request.
        logging.info("Request Body: %s", body.decode("utf-8", errors="replace"))

    response = await call_next(request)
    return response

def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials
if token not in api_keys_db:
Expand All @@ -137,7 +164,7 @@ def get_all_models():
unique_models = set()

for provider in config:
for model in provider['model']:
for model in provider['model'].keys():
if model not in unique_models:
unique_models.add(model)
model_info = {
Expand Down
10 changes: 6 additions & 4 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ async def get_gemini_payload(request, engine, provider):
'Content-Type': 'application/json'
}
url = provider['base_url']
model = provider['model'][request.model]
if request.stream:
gemini_stream = "streamGenerateContent"
url = url.format(model=request.model, stream=gemini_stream, api_key=provider['api'])
url = url.format(model=model, stream=gemini_stream, api_key=provider['api'])

messages = []
for msg in request.messages:
Expand Down Expand Up @@ -112,7 +113,6 @@ async def get_gpt_payload(request, engine, provider):
'Content-Type': 'application/json'
}
url = provider['base_url']
url = url.format(model=request.model, stream=request.stream, api_key=provider['api'])

messages = []
for msg in request.messages:
Expand All @@ -133,8 +133,9 @@ async def get_gpt_payload(request, engine, provider):
else:
messages.append({"role": msg.role, "content": content})

model = provider['model'][request.model]
payload = {
"model": request.model,
"model": model,
"messages": messages,
}

Expand Down Expand Up @@ -222,8 +223,9 @@ async def get_claude_payload(request, engine, provider):
elif msg.role == "system":
system_prompt = content

model = provider['model'][request.model]
payload = {
"model": request.model,
"model": model,
"messages": messages,
"system": system_prompt,
}
Expand Down
8 changes: 7 additions & 1 deletion response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import httpx

async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None):
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
sample_data = {
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
"object": "chat.completion.chunk",
Expand All @@ -24,6 +24,8 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
if tools_id and function_call_name:
sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"id":tools_id,"type":"function","function":{"name":function_call_name,"arguments":""}}]}
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
if role:
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
json_data = json.dumps(sample_data, ensure_ascii=False)

# 构建SSE响应
Expand Down Expand Up @@ -91,6 +93,10 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
message = resp.get("message")
if message:
tokens_use = resp.get("usage")
role = message.get("role")
if role:
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
yield sse_string
if tokens_use:
total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
# print("\n\rtotal_tokens", total_tokens)
Expand Down

0 comments on commit 52bcfe4

Please sign in to comment.