diff --git a/response.py b/response.py index 962c731..8ee5a98 100644 --- a/response.py +++ b/response.py @@ -5,7 +5,7 @@ from log_config import logger -async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None): +async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0): sample_data = { "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc", "object": "chat.completion.chunk", @@ -29,6 +29,10 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]} if role: sample_data["choices"][0]["delta"] = {"role": role, "content": ""} + if total_tokens: + total_tokens = prompt_tokens + completion_tokens + sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,"total_tokens": total_tokens} + sample_data["choices"] = [] json_data = json.dumps(sample_data, ensure_ascii=False) # 构建SSE响应 @@ -68,7 +72,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model): json_data = json.loads( "{" + line + "}") content = json_data.get('text', '') content = "\n".join(content.split("\\n")) - sse_string = await generate_sse_response(timestamp, model, content) + sse_string = await generate_sse_response(timestamp, model, content=content) yield sse_string except json.JSONDecodeError: logger.error(f"无法解析JSON: {line}") @@ -114,7 +118,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod json_data = json.loads( "{" + line + "}") content = json_data.get('text', '') content = "\n".join(content.split("\\n")) - sse_string = await generate_sse_response(timestamp, model, content) + sse_string = await generate_sse_response(timestamp, model, content=content) yield sse_string except json.JSONDecodeError: logger.error(f"无法解析JSON: {line}") @@ -163,6 +167,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model): yield error_message return buffer = "" + input_tokens = 0 async for chunk in response.aiter_text(): # logger.info(f"chunk: {repr(chunk)}") buffer += chunk @@ -171,20 +176,25 @@ async def fetch_claude_response_stream(client, url, headers, payload, model): # logger.info(line) if line.startswith("data:"): - line = line[5:] - if line.startswith(" "): - line = line[1:] + line = line.lstrip("data: ") resp: dict = json.loads(line) message = resp.get("message") if message: - tokens_use = resp.get("usage") role = message.get("role") if role: sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role) yield sse_string + tokens_use = message.get("usage") if tokens_use: - total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"] - # print("\n\rtotal_tokens", total_tokens) + input_tokens = tokens_use.get("input_tokens", 0) + usage = resp.get("usage") + if usage: + output_tokens = usage.get("output_tokens", 0) + total_tokens = input_tokens + output_tokens + sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens) + yield sse_string + # print("\n\rtotal_tokens", total_tokens) + tool_use = resp.get("content_block") tools_id = None function_call_name = None