Skip to content

Commit

Permalink
✨ Feature: Add feature: Add support for weighted load balancing.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 5, 2024
1 parent 73a667f commit 3ec7a0b
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ node_modules
.wrangler
.pytest_cache
*.jpg
*.json
*.json
*.png
13 changes: 12 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
- 支持三种负载均衡,默认同时开启。1. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级负载均衡,提高沉浸式翻译体验。
- 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮询负载均衡。
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。

Expand Down Expand Up @@ -93,6 +93,17 @@ api_keys:
preferences:
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮询后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true

# 渠道级加权负载均衡配置示例
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
model:
- gcp1/*: 5 # 冒号后面就是权重,权重仅支持正整数。
- gcp2/*: 3 # 数字的大小代表权重,数字越大,请求的概率越大。
- gcp3/*: 2 # 在该示例中,所有渠道加起来一共有 10 个权重,即 10 个请求里面有 5 个请求会请求 gcp1/* 模型,3 个请求会请求 gcp2/* 模型,2 个请求会请求 gcp3/* 模型。

preferences:
USE_ROUND_ROBIN: true # USE_ROUND_ROBIN 必须为 true。当上面的渠道后面没有权重时,会按照原始的渠道顺序请求;如果有权重,会按照加权后的顺序请求。
AUTO_RETRY: true
```
## 环境变量
Expand Down
53 changes: 47 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from models import RequestModel, ImageGenerationRequest
from request import get_payload
from response import fetch_response, fetch_response_stream
from utils import error_handling_wrapper, post_all_models, load_config
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder

from typing import List, Dict, Union
from urllib.parse import urlparse
Expand Down Expand Up @@ -224,6 +224,29 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],

raise e

def weighted_round_robin(weights):
    """Expand a provider -> weight mapping into one full weighted rotation.

    On every step each provider's running counter grows by its own weight,
    the provider with the largest ``counter / weight`` ratio is picked (ties
    go to the provider that comes first in ``weights``), and the picked
    provider's counter is reduced by the total weight. This interleaves
    providers instead of emitting each one in a contiguous run.

    Args:
        weights: Mapping of provider name to its integer weight. Entries
            whose weight is not positive are ignored, because they can never
            be scheduled and a zero weight would otherwise raise
            ``ZeroDivisionError`` in the ratio computation.

    Returns:
        A list of provider names of length ``sum(positive weights)`` in which
        every provider appears exactly ``weight`` times. Empty when no
        provider has a positive weight.
    """
    # Work on a filtered copy so the caller's mapping is left untouched.
    weights = {name: weight for name, weight in weights.items() if weight > 0}
    provider_names = list(weights)
    current_weights = dict.fromkeys(provider_names, 0)
    num_selections = total_weight = sum(weights.values())
    weighted_provider_list = []

    for _ in range(num_selections):
        for name in provider_names:
            current_weights[name] += weights[name]

        # max() returns the first maximal element, which keeps the original
        # "first provider wins a tie" behaviour.
        selected = max(
            provider_names,
            key=lambda name: current_weights[name] / weights[name],
        )

        weighted_provider_list.append(selected)
        current_weights[selected] -= total_weight

    return weighted_provider_list

import asyncio
class ModelRequestHandler:
def __init__(self):
Expand Down Expand Up @@ -297,13 +320,31 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques

# 检查是否启用轮询
api_index = api_list.index(token)
weights = safe_get(config, 'api_keys', api_index, "weights")
if weights:
# 步骤 1: 提取 matching_providers 中的所有 provider 值
providers = set(provider['provider'] for provider in matching_providers)
weight_keys = set(weights.keys())

# 步骤 3: 计算交集
intersection = providers.intersection(weight_keys)
weights = dict(filter(lambda item: item[0] in intersection, weights.items()))
weighted_provider_name_list = weighted_round_robin(weights)
new_matching_providers = []
for provider_name in weighted_provider_name_list:
for provider in matching_providers:
if provider['provider'] == provider_name:
new_matching_providers.append(provider)
matching_providers = new_matching_providers
# import json
# print("matching_providers", json.dumps(matching_providers, indent=4, ensure_ascii=False, default=circular_list_encoder))

use_round_robin = True
auto_retry = True
if config['api_keys'][api_index].get("preferences"):
if config['api_keys'][api_index]["preferences"].get("USE_ROUND_ROBIN") == False:
use_round_robin = False
if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
auto_retry = False
if safe_get(config, 'api_keys', api_index, "preferences", "USE_ROUND_ROBIN") == False:
use_round_robin = False
if safe_get(config, 'api_keys', api_index, "preferences", "AUTO_RETRY") == False:
auto_retry = False

return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)

Expand Down
49 changes: 49 additions & 0 deletions test/test_matplotlib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from collections import defaultdict

import matplotlib.font_manager as fm
# Point matplotlib at the PingFang font file under /System/Library/Fonts,
# presumably so the Chinese title and axis labels used by create_pic render
# instead of showing missing-glyph boxes. NOTE(review): this path is
# hard-coded, so the script only works where that file exists.
font_path = '/System/Library/Fonts/PingFang.ttc'
prop = fm.FontProperties(fname=font_path)
plt.rcParams['font.family'] = prop.get_name()

# Load the recorded statistics at import time; the path is relative to the
# current working directory, so the script must be run from the repo root.
with open('./test/states.json') as f:
    data = json.load(f)
    # Indexed by endpoint key in create_pic; each entry is iterated as a
    # list of ISO-8601 timestamp strings.
    request_arrivals = data["request_arrivals"]

def create_pic(request_arrivals, key):
    """Save a bar chart of one endpoint's request volume over the last 24 hours.

    The 24 one-hour buckets are anchored on the most recent timestamp of the
    endpoint (not on the wall clock); timestamps older than 24 hours before
    that one are ignored.

    Args:
        request_arrivals: Mapping from an endpoint key to a list of ISO-8601
            timestamp strings (parsed with ``datetime.fromisoformat``).
        key: Endpoint to plot. Used for the lookup, in the chart title, and,
            with every "/" removed, as the output file name.

    Raises:
        KeyError: If ``key`` is not present in ``request_arrivals``.
        ValueError: If the endpoint has no timestamps (``max`` of an empty
            list) or a timestamp is not valid ISO-8601.

    Side effects:
        Writes ``<key without slashes>.png`` to the current working directory.
    """
    # Parse the timestamp strings into datetime objects.
    datetimes = [datetime.fromisoformat(t) for t in request_arrivals[key]]
    # The newest arrival anchors the 24-hour window.
    latest_time = max(datetimes)

    # Bucket start times, oldest first: latest-24h, latest-23h, ..., latest-1h.
    time_range = [latest_time - timedelta(hours=i) for i in range(24, 0, -1)]

    # Count requests per bucket: scan from the newest bucket backwards and
    # stop at the first bucket whose start is not after the timestamp.
    hourly_counts = defaultdict(int)
    for dt in datetimes:
        for bucket_start in reversed(time_range):
            if dt >= bucket_start:
                hourly_counts[bucket_start] += 1
                break

    # Labels show only the hour; bucket boundaries actually carry the
    # minutes/seconds of latest_time.
    hours = [t.strftime('%Y-%m-%d %H:00') for t in time_range]
    counts = [hourly_counts[t] for t in time_range]

    fig = plt.figure(figsize=(15, 6))
    try:
        plt.bar(hours, counts)
        plt.title(f'{key} 端点请求量 (过去24小时)')
        plt.xlabel('时间')
        plt.ylabel('请求数')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        plt.savefig(f'{key.replace("/", "")}.png')
    finally:
        # pyplot keeps every figure registered until it is closed; release it
        # so repeated calls do not accumulate open figures.
        plt.close(fig)

# Script entry point: render the chart for the chat-completions endpoint only.
if __name__ == '__main__':
    create_pic(request_arrivals, 'POST /v1/chat/completions')
33 changes: 33 additions & 0 deletions test/test_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
def weighted_round_robin(weights):
    """Expand a provider -> weight mapping into one full weighted rotation.

    On every step each provider's running counter grows by its own weight,
    the provider with the largest ``counter / weight`` ratio is picked (ties
    go to the provider that comes first in ``weights``), and the picked
    provider's counter is reduced by the total weight. This interleaves
    providers instead of emitting each one in a contiguous run.

    Args:
        weights: Mapping of provider name to its integer weight. Entries
            whose weight is not positive are ignored, because they can never
            be scheduled and a zero weight would otherwise raise
            ``ZeroDivisionError`` in the ratio computation.

    Returns:
        A list of provider names of length ``sum(positive weights)`` in which
        every provider appears exactly ``weight`` times. Empty when no
        provider has a positive weight.
    """
    # Work on a filtered copy so the caller's mapping is left untouched.
    weights = {name: weight for name, weight in weights.items() if weight > 0}
    provider_names = list(weights)
    current_weights = dict.fromkeys(provider_names, 0)
    num_selections = total_weight = sum(weights.values())
    weighted_provider_list = []

    for _ in range(num_selections):
        for name in provider_names:
            current_weights[name] += weights[name]

        # max() returns the first maximal element, which keeps the original
        # "first provider wins a tie" behaviour.
        selected = max(
            provider_names,
            key=lambda name: current_weights[name] / weights[name],
        )

        weighted_provider_list.append(selected)
        current_weights[selected] -= total_weight

    return weighted_provider_list

# Sample weights and the subset of provider names to keep.
weights = {'a': 5, 'b': 3, 'c': 2}
index = {'a', 'c'}

# Keep only the weights whose provider is in `index`, in the order of `weights`.
result = {name: weight for name, weight in weights.items() if name in index}
print(result)

# NOTE(review): the rotation is built from the unfiltered `weights`, not from
# `result` — confirm this is intended.
weighted_provider_list = weighted_round_robin(weights)
print(weighted_provider_list)
42 changes: 40 additions & 2 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,25 @@ def update_config(config_data):
config_data['providers'][index] = provider

api_keys_db = config_data['api_keys']

for index, api_key in enumerate(config_data['api_keys']):
weights_dict = {}
models = []
for model in api_key.get('model'):
if isinstance(model, dict):
key, value = list(model.items())[0]
provider_name = key.split("/")[0]
if "/" in key:
weights_dict.update({provider_name: int(value)})
models.append(key)
if isinstance(model, str):
models.append(model)
config_data['api_keys'][index]['weights'] = weights_dict
config_data['api_keys'][index]['model'] = models
api_keys_db[index]['model'] = models

api_list = [item["api"] for item in api_keys_db]
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False))
# logger.info(json.dumps(config_data, indent=4, ensure_ascii=False, default=circular_list_encoder))
return config_data, api_keys_db, api_list

# 读取YAML配置文件
Expand Down Expand Up @@ -214,6 +231,12 @@ def get_all_models(config):
# us-central1
# europe-west1
# europe-west4

def circular_list_encoder(obj):
    """Convert a ``CircularList`` into a JSON-serializable dict.

    Shaped to be passed as the ``default=`` hook of ``json.dumps``: it returns
    ``obj.to_dict()`` for ``CircularList`` instances and raises ``TypeError``
    for every other type.

    Raises:
        TypeError: If ``obj`` is not a ``CircularList``.
    """
    if not isinstance(obj, CircularList):
        raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
    return obj.to_dict()

from collections import deque
class CircularList:
def __init__(self, items):
Expand All @@ -226,6 +249,13 @@ def next(self):
self.queue.append(item)
return item

def to_dict(self):
return {
'queue': list(self.queue)
}



c35s = CircularList(["us-east5", "europe-west1"])
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
c3o = CircularList(["us-east5"])
Expand Down Expand Up @@ -256,4 +286,12 @@ def __init__(
else:
self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)

def safe_get(data, *keys, default=None):
    """Follow ``keys`` through nested containers without raising.

    Dicts and lists are indexed with ``data[key]``; any other object is
    queried through its ``.get(key)`` method.

    Args:
        data: The root container.
        *keys: Successive dict keys / list indexes to follow.
        default: Keyword-only value returned when a step fails. Defaults to
            ``None``, which matches the behaviour callers had before this
            parameter existed.

    Returns:
        The value found at the end of the path (``data`` itself when no keys
        are given), or ``default`` as soon as a step raises ``KeyError``,
        ``IndexError``, ``AttributeError`` or ``TypeError``.
    """
    for key in keys:
        try:
            if isinstance(data, (dict, list)):
                data = data[key]
            else:
                # Covers mapping-like objects; None or scalars raise
                # AttributeError here and fall through to `default`.
                data = data.get(key)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return data

0 comments on commit 3ec7a0b

Please sign in to comment.