diff --git a/.gitignore b/.gitignore index b471fc8..c026a68 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ node_modules .wrangler .pytest_cache *.jpg -*.json \ No newline at end of file +*.json +*.png \ No newline at end of file diff --git a/README.md b/README.md index 63d8b11..2d48828 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ - 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。 - 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。 - 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。 -- 支持三种负载均衡,默认同时开启。1. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级负载均衡,提高沉浸式翻译体验。 +- 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮询负载均衡。 - 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。 - 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。 @@ -93,6 +93,17 @@ api_keys: preferences: USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。 AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true + + # 渠道级加权负载均衡配置示例 + - api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo + model: + - gcp1/*: 5 # 冒号后面就是权重,权重仅支持正整数。 + - gcp2/*: 3 # 数字的大小代表权重,数字越大,请求的概率越大。 + - gcp3/*: 2 # 在该示例中,所有渠道加起来一共有 10 个权重,即 10 个请求里面有 5 个请求会请求 gcp1/* 模型,3 个请求会请求 gcp2/* 模型,2 个请求会请求 gcp3/* 模型。 + + preferences: + USE_ROUND_ROBIN: true # USE_ROUND_ROBIN 必须为 true。上面的渠道后面没有权重时,会按照原始的渠道顺序请求;如果有权重,会按照加权后的顺序请求。 + AUTO_RETRY: true ``` ## 环境变量 diff --git a/main.py b/main.py index 1b2b73a..4280506 100644 --- a/main.py +++ b/main.py @@ -12,7 +12,7 @@ from models import RequestModel, ImageGenerationRequest from request import get_payload from response import fetch_response, fetch_response_stream -from utils import
def weighted_round_robin(weights):
    """Expand a name -> weight mapping into one weighted round-robin cycle.

    Every name with a positive weight appears exactly ``weight`` times in the
    result, and the occurrences are interleaved instead of being grouped
    together.

    Args:
        weights: Mapping from provider name to its integer weight. Entries
            whose weight is zero or negative are ignored, because they would
            otherwise cause a division by zero in the ratio computation.

    Returns:
        A list of provider names whose length is the sum of the positive
        weights. The list is empty when no positive weight is present.
    """
    # Keep only usable weights; the ratio below divides by the weight.
    active = {name: weight for name, weight in weights.items() if weight > 0}
    total_weight = sum(active.values())
    current_weights = dict.fromkeys(active, 0)
    weighted_provider_list = []

    for _ in range(total_weight):
        # Each round every candidate earns credit equal to its weight.
        for name, weight in active.items():
            current_weights[name] += weight

        # Pick the candidate with the highest credit relative to its weight.
        # max() returns the first maximal item, so ties resolve in mapping
        # order, the same as a strict ">" comparison in a manual scan.
        selected = max(active, key=lambda name: current_weights[name] / active[name])

        weighted_provider_list.append(selected)
        # Charge the winner a full cycle so the other names catch up.
        current_weights[selected] -= total_weight

    return weighted_provider_list
def create_pic(request_arrivals, key):
    """Save a bar chart of one endpoint's hourly request counts.

    The chart covers the 24 hours that precede the most recent request of
    the endpoint and is written to ``<key with "/" removed>.png`` in the
    current working directory.

    Args:
        request_arrivals: Mapping from endpoint key to a list of ISO-8601
            timestamp strings, one per request.
        key: Endpoint whose requests are plotted.

    Returns:
        None. Nothing is drawn or saved when the endpoint has no timestamps.

    Raises:
        KeyError: If ``key`` is not present in ``request_arrivals``.
    """
    # Convert the timestamp strings to datetime objects.
    datetimes = [datetime.fromisoformat(stamp) for stamp in request_arrivals[key]]
    if not datetimes:
        # max() below raises ValueError on an empty sequence.
        return

    latest_time = max(datetimes)

    # 24 bucket start times, oldest first: latest-24h ... latest-1h.
    time_range = [latest_time - timedelta(hours=i) for i in range(24, 0, -1)]

    # Count each request in the newest bucket that starts at or before it.
    # Requests older than the oldest bucket start are not counted.
    hourly_counts = defaultdict(int)
    for dt in datetimes:
        for bucket_start in reversed(time_range):
            if dt >= bucket_start:
                hourly_counts[bucket_start] += 1
                break

    # NOTE(review): labels print the minutes as ":00" although bucket starts
    # carry the minutes of the latest request - confirm this is intended.
    hours = [t.strftime('%Y-%m-%d %H:00') for t in time_range]
    counts = [hourly_counts[t] for t in time_range]

    fig = plt.figure(figsize=(15, 6))
    try:
        plt.bar(hours, counts)
        plt.title(f'{key} 端点请求量 (过去24小时)')
        plt.xlabel('时间')
        plt.ylabel('请求数')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()

        plt.savefig(f'{key.replace("/", "")}.png')
    finally:
        # Release the figure so repeated calls do not pile up open figures.
        plt.close(fig)
def safe_get(data, *keys, default=None):
    """Walk a nested dict/list structure without raising on a broken path.

    Args:
        data: Root object to start from.
        *keys: Dict keys or list indices, applied one after another.
        default: Value returned when a lookup step raises. Defaults to
            ``None``, so existing calls behave as before.

    Returns:
        The value found at the end of the path, ``data`` itself when no keys
        are given, or ``default`` when a step raises ``KeyError``,
        ``IndexError``, ``AttributeError`` or ``TypeError``. For objects that
        are neither ``dict`` nor ``list``, the result of their ``.get(key)``
        is used as is.
    """
    for key in keys:
        try:
            if isinstance(data, (dict, list)):
                data = data[key]
            else:
                # Duck-typed fallback; objects without .get raise
                # AttributeError, which is handled below.
                data = data.get(key)
        except (KeyError, IndexError, AttributeError, TypeError):
            return default
    return data