Skip to content

Commit

Permalink
🐛 Bug: Fix the bug where the official Claude API does not correctly p…
Browse files Browse the repository at this point in the history
…ass the token count.
  • Loading branch information
yym68686 committed Sep 9, 2024
1 parent 9874f60 commit 1de140d
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from log_config import logger


async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
sample_data = {
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
"object": "chat.completion.chunk",
Expand All @@ -29,6 +29,10 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
if role:
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
if total_tokens:
total_tokens = prompt_tokens + completion_tokens
sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,"total_tokens": total_tokens}
sample_data["choices"] = []
json_data = json.dumps(sample_data, ensure_ascii=False)

# 构建SSE响应
Expand Down Expand Up @@ -68,7 +72,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
json_data = json.loads( "{" + line + "}")
content = json_data.get('text', '')
content = "\n".join(content.split("\\n"))
sse_string = await generate_sse_response(timestamp, model, content)
sse_string = await generate_sse_response(timestamp, model, content=content)
yield sse_string
except json.JSONDecodeError:
logger.error(f"无法解析JSON: {line}")
Expand Down Expand Up @@ -114,7 +118,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
json_data = json.loads( "{" + line + "}")
content = json_data.get('text', '')
content = "\n".join(content.split("\\n"))
sse_string = await generate_sse_response(timestamp, model, content)
sse_string = await generate_sse_response(timestamp, model, content=content)
yield sse_string
except json.JSONDecodeError:
logger.error(f"无法解析JSON: {line}")
Expand Down Expand Up @@ -163,6 +167,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
yield error_message
return
buffer = ""
input_tokens = 0
async for chunk in response.aiter_text():
# logger.info(f"chunk: {repr(chunk)}")
buffer += chunk
Expand All @@ -171,20 +176,25 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
# logger.info(line)

if line.startswith("data:"):
line = line[5:]
if line.startswith(" "):
line = line[1:]
line = line.lstrip("data: ")
resp: dict = json.loads(line)
message = resp.get("message")
if message:
tokens_use = resp.get("usage")
role = message.get("role")
if role:
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
yield sse_string
tokens_use = message.get("usage")
if tokens_use:
total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
# print("\n\rtotal_tokens", total_tokens)
input_tokens = tokens_use.get("input_tokens", 0)
usage = resp.get("usage")
if usage:
output_tokens = usage.get("output_tokens", 0)
total_tokens = input_tokens + output_tokens
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
yield sse_string
# print("\n\rtotal_tokens", total_tokens)

tool_use = resp.get("content_block")
tools_id = None
function_call_name = None
Expand Down

0 comments on commit 1de140d

Please sign in to comment.