From 3b159d8a82fbec9f3b704f505b70bce6c3d12d05 Mon Sep 17 00:00:00 2001 From: yym68686 Date: Sun, 1 Sep 2024 02:36:00 +0800 Subject: [PATCH] Add Gemini region load balancing. --- README.md | 6 +- request.py | 163 +++++++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 5 +- utils.py | 37 ++++++++++- 4 files changed, 201 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 288b919..6861828 100644 --- a/README.md +++ b/README.md @@ -55,10 +55,14 @@ providers: - provider: vertex project_id: gen-lang-client-xxxxxxxxxxxxxx # 描述: 您的Google Cloud项目ID。格式: 字符串,通常由小写字母、数字和连字符组成。获取方式: 在Google Cloud Console的项目选择器中可以找到您的项目ID。 private_key: "-----BEGIN PRIVATE KEY-----\nxxxxx\n-----END PRIVATE" # 描述: Google Cloud Vertex AI服务账号的私钥。格式: 一个JSON格式的字符串,包含服务账号的私钥信息。获取方式: 在Google Cloud Console中创建服务账号,生成JSON格式的密钥文件,然后将其内容设置为此环境变量的值。 - client_email: xxxxxxxxxx@developer.gserviceaccount.com # 描述: Google Cloud Vertex AI服务账号的电子邮件地址。格式: 通常是形如 "service-account-name@project-id.iam.gserviceaccount.com" 的字符串。获取方式: 在创建服务账号时生成,也可以在Google Cloud Console的"IAM与管理"部分查看服务账号详情获得。 + client_email: xxxxxxxxxx@xxxxxxx.gserviceaccount.com # 描述: Google Cloud Vertex AI服务账号的电子邮件地址。格式: 通常是形如 "service-account-name@project-id.iam.gserviceaccount.com" 的字符串。获取方式: 在创建服务账号时生成,也可以在Google Cloud Console的"IAM与管理"部分查看服务账号详情获得。 model: - gemini-1.5-pro - gemini-1.5-flash + - claude-3-5-sonnet@20240620: claude-3-5-sonnet + - claude-3-opus@20240229: claude-3-opus + - claude-3-sonnet@20240229: claude-3-sonnet + - claude-3-haiku@20240307: claude-3-haiku tools: true - provider: other-provider diff --git a/request.py b/request.py index 215cd8b..5ce9d1d 100644 --- a/request.py +++ b/request.py @@ -1,6 +1,6 @@ import json from models import RequestModel -from log_config import logger +from utils import c35s, c3s, c3o, c3h, CircularList async def get_image_message(base64_image, engine = None): if "gpt" == engine: @@ -222,19 +222,168 @@ def get_access_token(client_email, private_key): 
# Round-robin pool of Vertex AI regions used for Gemini requests.  Created once
# at module level so that successive calls actually rotate through the regions;
# the original built a fresh CircularList per call, so .next() always returned
# the first region and no load balancing ever happened.
_gemini_locations = CircularList([
    "us-central1", "us-east4", "us-west1",
    "us-west4", "europe-west1", "europe-west2",
])


async def get_vertex_gemini_payload(request, engine, provider):
    """Build the (url, headers, payload) triple for a Vertex AI Gemini call.

    Args:
        request: parsed RequestModel carrying an OpenAI-style chat request.
        engine: backend identifier, forwarded to the message converters.
        provider: provider config dict; 'client_email'/'private_key' drive
            auth, 'project_id' and the 'model' mapping drive the URL.

    Returns:
        Tuple of (url, headers, payload) ready to POST to Vertex AI.
    """
    headers = {
        'Content-Type': 'application/json'
    }
    if provider.get("client_email") and provider.get("private_key"):
        access_token = get_access_token(provider['client_email'], provider['private_key'])
        headers['Authorization'] = f"Bearer {access_token}"
    if provider.get("project_id"):
        project_id = provider.get("project_id")

    # Fix: the original assigned the endpoint name only when streaming, which
    # raised NameError on non-stream requests.  Vertex exposes both
    # :generateContent and :streamGenerateContent.
    gemini_stream = "streamGenerateContent" if request.stream else "generateContent"
    model = provider['model'][request.model]
    url = (
        "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}"
        "/locations/{LOCATION}/publishers/google/models/{MODEL_ID}:{stream}"
    ).format(
        LOCATION=_gemini_locations.next(),
        PROJECT_ID=project_id,
        MODEL_ID=model,
        stream=gemini_stream,
    )

    messages = []
    systemInstruction = None
    function_arguments = None
    for msg in request.messages:
        # Gemini uses "model" where OpenAI uses "assistant".
        if msg.role == "assistant":
            msg.role = "model"
        tool_calls = None
        if isinstance(msg.content, list):
            content = []
            for item in msg.content:
                if item.type == "text":
                    text_message = await get_text_message(msg.role, item.text, engine)
                    content.append(text_message)
                elif item.type == "image_url":
                    image_message = await get_image_message(item.image_url.url, engine)
                    content.append(image_message)
        else:
            content = [{"text": msg.content}]
            tool_calls = msg.tool_calls

        if tool_calls:
            # Only the first tool call is forwarded; remembered so the
            # subsequent "tool" message can echo the function name back.
            tool_call = tool_calls[0]
            function_arguments = {
                "functionCall": {
                    "name": tool_call.function.name,
                    "args": json.loads(tool_call.function.arguments)
                }
            }
            messages.append(
                {
                    "role": "model",
                    "parts": [function_arguments]
                }
            )
        elif msg.role == "tool":
            function_call_name = function_arguments["functionCall"]["name"]
            messages.append(
                {
                    "role": "function",
                    "parts": [{
                        "functionResponse": {
                            "name": function_call_name,
                            "response": {
                                "name": function_call_name,
                                "content": {
                                    "result": msg.content,
                                }
                            }
                        }
                    }]
                }
            )
        elif msg.role != "system":
            messages.append({"role": msg.role, "parts": content})
        elif msg.role == "system":
            systemInstruction = {"parts": content}

    payload = {
        "contents": messages,
        "generationConfig": {
            "temperature": 0.5,
            "max_output_tokens": 8192,
            "top_k": 40,
            "top_p": 0.95
        },
    }
    if systemInstruction:
        payload["system_instruction"] = systemInstruction

    # OpenAI-request fields that either have no Gemini equivalent or were
    # already folded into the payload above.
    miss_fields = [
        'model',
        'messages',
        'stream',
        'tool_choice',
        'temperature',
        'top_p',
        'max_tokens',
        'presence_penalty',
        'frequency_penalty',
        'n',
        'user',
        'include_usage',
        'logprobs',
        'top_logprobs'
    ]

    for field, value in request.model_dump(exclude_unset=True).items():
        if field not in miss_fields and value is not None:
            if field == "tools":
                payload.update({
                    "tools": [{
                        "function_declarations": [tool["function"] for tool in value]
                    }],
                    "tool_config": {
                        "function_calling_config": {
                            "mode": "AUTO"
                        }
                    }
                })
            else:
                payload[field] = value

    return url, headers, payload
provider.get("project_id") - url = "https://us-central1-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/us-central1/publishers/google/models/{MODEL_ID}:{stream}".format(PROJECT_ID=project_id, MODEL_ID=model, stream=gemini_stream) + + model = provider['model'][request.model] + if "claude-3-5-sonnet" in model: + location = c35s + elif "claude-3-opus" in model: + location = c3o + elif "claude-3-sonnet" in model: + location = c3s + elif "claude-3-haiku" in model: + location = c3h + + if request.stream: + claude_stream = "streamRawPredict" + url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream) messages = [] systemInstruction = None @@ -620,8 +769,10 @@ async def get_claude_payload(request, engine, provider): async def get_payload(request: RequestModel, engine, provider): if engine == "gemini": return await get_gemini_payload(request, engine, provider) - elif engine == "vertex": - return await get_vertex_payload(request, engine, provider) + elif engine == "vertex" and "gemini" in provider['model'][request.model]: + return await get_vertex_gemini_payload(request, engine, provider) + elif engine == "vertex" and "claude" in provider['model'][request.model]: + return await get_vertex_claude_payload(request, engine, provider) elif engine == "claude": return await get_claude_payload(request, engine, provider) elif engine == "gpt": diff --git a/requirements.txt b/requirements.txt index b9da9a1..4031709 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -httpx[http2] pyyaml pytest uvicorn -fastapi \ No newline at end of file +fastapi +httpx[http2] +cryptography \ No newline at end of file diff --git a/utils.py b/utils.py index bd588e2..bcc8e0f 100644 --- a/utils.py +++ b/utils.py @@ -185,4 +185,39 @@ def get_all_models(config): } all_models.append(model_info) - return 
# Vertex AI regions currently available per Anthropic model, per
# https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude
#   claude-3.5-sonnet: us-east5, europe-west1
#   claude-3-sonnet:   us-east5, us-central1, asia-southeast1
#   claude-3-opus:     us-east5
#   claude-3-haiku:    us-east5, us-central1, europe-west1, europe-west4
from collections import deque


class CircularList:
    """Endless round-robin iterator over a fixed collection of items."""

    def __init__(self, items):
        # A deque gives O(1) rotation, which is all next() needs.
        self.queue = deque(items)

    def next(self):
        """Return the next item in rotation; None when the list is empty."""
        if not self.queue:
            return None
        # Rotate the head to the tail, then hand back that same element —
        # equivalent to popleft() followed by append().
        self.queue.rotate(-1)
        return self.queue[-1]


c35s = CircularList(["us-east5", "europe-west1"])
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
c3o = CircularList(["us-east5"])
c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])