Add feature: support vertex claude API using tool use functionality.
yym68686 committed Sep 1, 2024
1 parent 3b159d8 commit 7d44776
Showing 5 changed files with 160 additions and 91 deletions.
52 changes: 49 additions & 3 deletions main.py
@@ -5,7 +5,7 @@
from contextlib import asynccontextmanager

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException, Depends
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

@@ -40,6 +40,37 @@ async def lifespan(app: FastAPI):

app = FastAPI(lifespan=lifespan)

# from time import time
# from collections import defaultdict
# import asyncio

# class StatsMiddleware:
# def __init__(self):
# self.request_counts = defaultdict(int)
# self.request_times = defaultdict(float)
# self.ip_counts = defaultdict(lambda: defaultdict(int))
# self.lock = asyncio.Lock()

# async def __call__(self, request: Request, call_next):
# start_time = time()
# response = await call_next(request)
# process_time = time() - start_time

# endpoint = f"{request.method} {request.url.path}"
# client_ip = request.client.host

# async with self.lock:
# self.request_counts[endpoint] += 1
# self.request_times[endpoint] += process_time
# self.ip_counts[endpoint][client_ip] += 1

# return response
# # Create a StatsMiddleware instance
# stats_middleware = StatsMiddleware()

# # Register the StatsMiddleware
# app.add_middleware(StatsMiddleware)

# Configure the CORS middleware
app.add_middleware(
CORSMiddleware,
@@ -219,9 +250,24 @@ def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(32)
return JSONResponse(content={"api_key": api_key})

# @app.get("/stats")
# async def get_stats(token: str = Depends(verify_api_key)):
# async with stats_middleware.lock:
# return {
# "request_counts": dict(stats_middleware.request_counts),
# "average_request_times": {
# endpoint: total_time / count
# for endpoint, total_time in stats_middleware.request_times.items()
# for count in [stats_middleware.request_counts[endpoint]]
# },
# "ip_counts": {
# endpoint: dict(ips)
# for endpoint, ips in stats_middleware.ip_counts.items()
# }
# }

# async def on_fetch(request, env):
# import asgi

# return await asgi.fetch(app, request, env)

if __name__ == '__main__':
@@ -232,5 +278,5 @@ def generate_api_key():
port=8000,
reload=True,
ws="none",
log_level="warning"
# log_level="warning"
)
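For reference, the commented-out stats middleware above would not work as registered: app.add_middleware() instantiates the class it is given with the wrapped app, while the sketched StatsMiddleware.__init__ takes no arguments. Below is a minimal sketch of the same idea using Starlette's BaseHTTPMiddleware; it is not part of this commit, and the names and counters simply mirror the commented code.

import asyncio
from collections import defaultdict
from time import time

from fastapi import FastAPI, Request
from starlette.middleware.base import BaseHTTPMiddleware

class StatsMiddleware(BaseHTTPMiddleware):
    def __init__(self, app):
        super().__init__(app)
        self.request_counts = defaultdict(int)   # requests per endpoint
        self.request_times = defaultdict(float)  # cumulative latency per endpoint
        self.ip_counts = defaultdict(lambda: defaultdict(int))  # per-IP hits
        self.lock = asyncio.Lock()

    async def dispatch(self, request: Request, call_next):
        start_time = time()
        response = await call_next(request)
        process_time = time() - start_time

        endpoint = f"{request.method} {request.url.path}"
        client_ip = request.client.host if request.client else "unknown"

        async with self.lock:
            self.request_counts[endpoint] += 1
            self.request_times[endpoint] += process_time
            self.ip_counts[endpoint][client_ip] += 1

        return response

app = FastAPI()
app.add_middleware(StatsMiddleware)

Because add_middleware() builds the instance internally, a /stats endpoint like the commented one would also need the instance exposed somewhere reachable, e.g. stored on app.state.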
143 changes: 58 additions & 85 deletions request.py
@@ -363,7 +363,7 @@ async def get_vertex_gemini_payload(request, engine, provider):

async def get_vertex_claude_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
'Content-Type': 'application/json',
}
if provider.get("client_email") and provider.get("private_key"):
access_token = get_access_token(provider['client_email'], provider['private_key'])
@@ -386,12 +386,10 @@ async def get_vertex_claude_payload(request, engine, provider):
url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)

messages = []
systemInstruction = None
function_arguments = None
system_prompt = None
for msg in request.messages:
if msg.role == "assistant":
msg.role = "model"
tool_calls = None
tool_call_id = None
if isinstance(msg.content, list):
content = []
for item in msg.content:
@@ -402,109 +400,84 @@
image_message = await get_image_message(item.image_url.url, engine)
content.append(image_message)
else:
content = [{"text": msg.content}]
content = msg.content
tool_calls = msg.tool_calls
tool_call_id = msg.tool_call_id

if tool_calls:
tool_call = tool_calls[0]
function_arguments = {
"functionCall": {
tool_calls_list = []
for tool_call in tool_calls:
tool_calls_list.append({
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"args": json.loads(tool_call.function.arguments)
}
}
messages.append(
{
"role": "model",
"parts": [function_arguments]
}
)
elif msg.role == "tool":
function_call_name = function_arguments["functionCall"]["name"]
messages.append(
{
"role": "function",
"parts": [{
"functionResponse": {
"name": function_call_name,
"response": {
"name": function_call_name,
"content": {
"result": msg.content,
}
}
}
}]
}
)
"input": json.loads(tool_call.function.arguments),
})
messages.append({"role": msg.role, "content": tool_calls_list})
elif tool_call_id:
messages.append({"role": "user", "content": [{
"type": "tool_result",
"tool_use_id": tool_call.id,
"content": content
}]})
elif msg.role != "system":
messages.append({"role": msg.role, "parts": content})
messages.append({"role": msg.role, "content": content})
elif msg.role == "system":
systemInstruction = {"parts": content}
system_prompt = content

conversation_len = len(messages) - 1
message_index = 0
while message_index < conversation_len:
if messages[message_index]["role"] == messages[message_index + 1]["role"]:
if messages[message_index].get("content"):
if isinstance(messages[message_index]["content"], list):
messages[message_index]["content"].extend(messages[message_index + 1]["content"])
elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
content_list = [{"type": "text", "text": messages[message_index]["content"]}]
content_list.extend(messages[message_index + 1]["content"])
messages[message_index]["content"] = content_list
else:
messages[message_index]["content"] += messages[message_index + 1]["content"]
messages.pop(message_index + 1)
conversation_len = conversation_len - 1
else:
message_index = message_index + 1

model = provider['model'][request.model]
payload = {
"contents": messages,
# "safetySettings": [
# {
# "category": "HARM_CATEGORY_HARASSMENT",
# "threshold": "BLOCK_NONE"
# },
# {
# "category": "HARM_CATEGORY_HATE_SPEECH",
# "threshold": "BLOCK_NONE"
# },
# {
# "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
# "threshold": "BLOCK_NONE"
# },
# {
# "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
# "threshold": "BLOCK_NONE"
# }
# ]
"generationConfig": {
"temperature": 0.5,
"max_output_tokens": 8192,
"top_k": 40,
"top_p": 0.95
},
"anthropic_version": "vertex-2023-10-16",
"messages": messages,
"system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
}
if systemInstruction:
payload["system_instruction"] = systemInstruction

miss_fields = [
'model',
'messages',
'stream',
'tool_choice',
'temperature',
'top_p',
'max_tokens',
'presence_penalty',
'frequency_penalty',
'n',
'user',
'include_usage',
'logprobs',
'top_logprobs'
]

for field, value in request.model_dump(exclude_unset=True).items():
if field not in miss_fields and value is not None:
if field == "tools":
payload.update({
"tools": [{
"function_declarations": [tool["function"] for tool in value]
}],
"tool_config": {
"function_calling_config": {
"mode": "AUTO"
}
}
})
else:
payload[field] = value
payload[field] = value

if request.tools and provider.get("tools"):
tools = []
for tool in request.tools:
json_tool = await gpt2claude_tools_json(tool.dict()["function"])
tools.append(json_tool)
payload["tools"] = tools
if "tool_choice" in payload:
payload["tool_choice"] = {
"type": "auto"
}

if provider.get("tools") == False:
payload.pop("tools", None)
payload.pop("tool_choice", None)

return url, headers, payload

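For orientation, here is a sketch of the payload shape this function now assembles for a single tool-use exchange. It is illustrative only: the tool_use id and the weather tool are invented, and the exact tool schema comes from gpt2claude_tools_json, so the "tools" list is elided.

payload = {
    "anthropic_version": "vertex-2023-10-16",
    "system": "You are Claude, a large language model trained by Anthropic.",
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        # An assistant turn carrying tool_calls becomes a list of tool_use blocks.
        {"role": "assistant", "content": [
            {"type": "tool_use", "id": "toolu_01", "name": "get_weather",
             "input": {"city": "Paris"}},
        ]},
        # A tool message is re-emitted as a user turn holding a tool_result block;
        # the while-loop above then merges any consecutive same-role turns.
        {"role": "user", "content": [
            {"type": "tool_result", "tool_use_id": "toolu_01",
             "content": "22°C, clear"},
        ]},
    ],
    # With request.tools set and provider tools enabled, "tools" and a
    # {"type": "auto"} tool_choice are appended as in the code above.
}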
53 changes: 51 additions & 2 deletions response.py
@@ -84,6 +84,55 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
yield sse_string

async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
timestamp = datetime.timestamp(datetime.now())
async with client.stream('POST', url, headers=headers, json=payload) as response:
if response.status_code != 200:
error_message = await response.aread()
error_str = error_message.decode('utf-8', errors='replace')
try:
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
buffer = ""
receiving_function_call = False
function_full_response = "{"
need_function_call = False
async for chunk in response.aiter_text():
buffer += chunk
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
logger.info(f"{line}")
if line and '\"text\": \"' in line:
try:
json_data = json.loads("{" + line + "}")
content = json_data.get('text', '')
content = "\n".join(content.split("\\n"))
sse_string = await generate_sse_response(timestamp, model, content)
yield sse_string
except json.JSONDecodeError:
logger.error(f"无法解析JSON: {line}")

if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
revicing_function_call = True
need_function_call = True
if ']' in line:
revicing_function_call = False
continue

function_full_response += line

if need_function_call:
function_call = json.loads(function_full_response)
function_call_name = function_call["name"]
function_call_id = function_call["id"]
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
yield sse_string
function_full_response = json.dumps(function_call["input"])
sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
yield sse_string

async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
redirect_count = 0
while redirect_count < max_redirects:
@@ -202,10 +251,10 @@ async def fetch_response(client, url, headers, payload):

async def fetch_response_stream(client, url, headers, payload, engine, model):
try:
if engine == "gemini" or engine == "vertex":
if engine == "gemini" or (engine == "vertex" and "gemini" in model):
async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
yield chunk
elif engine == "claude":
elif engine == "claude" or (engine == "vertex" and "claude" in model):
async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
yield chunk
elif engine == "gpt":
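Assuming url, headers, and payload come from get_vertex_claude_payload, a minimal caller for the updated routing might look like the sketch below. stream_chat and the model string are hypothetical; the client is assumed to be an httpx.AsyncClient, matching the aiter_text() usage above.

import httpx

async def stream_chat(url, headers, payload):
    async with httpx.AsyncClient(timeout=60) as client:
        # With engine == "vertex", the model name now selects the parser:
        # "gemini" in the model routes to the Gemini stream parser,
        # "claude" routes to the Claude stream parser.
        async for chunk in fetch_response_stream(
            client, url, headers, payload,
            engine="vertex",
            model="claude-3-5-sonnet@20240620",
        ):
            print(chunk)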
2 changes: 1 addition & 1 deletion test/provider_test.py
@@ -80,7 +80,7 @@ def test_request_model(test_client, api_key, get_model):

response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
for line in response.iter_lines():
print(line)
print(line.lstrip("data: "))
assert response.status_code == 200

if __name__ == "__main__":
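One caveat on the new print line: str.lstrip() treats its argument as a set of characters rather than a literal prefix, so it can also strip leading payload characters. A quick sketch of the difference (str.removeprefix requires Python 3.9+):

line = "data: dateline"
print(line.lstrip("data: "))        # "eline"    -- drops any of {d, a, t, ' ', ':'}
print(line.removeprefix("data: "))  # "dateline" -- drops the exact prefix only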
1 change: 1 addition & 0 deletions utils.py
@@ -80,6 +80,7 @@ async def error_handling_wrapper(generator, status_code=200):
try:
first_item = await generator.__anext__()
first_item_str = first_item
# logger.info("first_item_str: %s", first_item_str)
if isinstance(first_item_str, (bytes, bytearray)):
first_item_str = first_item_str.decode("utf-8")
if isinstance(first_item_str, str):
