diff --git a/main.py b/main.py index e84998f..7c22d03 100644 --- a/main.py +++ b/main.py @@ -275,20 +275,26 @@ def get_matching_providers(self, model_name, token): for model in config['api_keys'][api_index]['model']: if "/" in model: provider_name = model.split("/")[0] - model = model.split("/")[1] + model_name_split = "/".join(model.split("/")[1:]) models_list = [] for provider in config['providers']: if provider['provider'] == provider_name: models_list.extend(list(provider['model'].keys())) # print("models_list", models_list) # print("model_name", model_name) + + # 处理带斜杠的模型名 + for provider in config['providers']: + if model in provider['model'].keys(): + provider_rules.append(provider['provider'] + "/" + model) + # print("model", model) - if (model and model_name in models_list) or (model == "*" and model_name in models_list): + if (model_name_split and model_name in models_list) or (model_name_split == "*" and model_name in models_list): provider_rules.append(provider_name) else: for provider in config['providers']: if model in provider['model'].keys(): - provider_rules.append(provider['provider'] + "/" + model) + provider_rules.append(provider['provider'] + "/" + model_name_split) provider_list = [] # print("provider_rules", provider_rules) @@ -297,7 +303,7 @@ def get_matching_providers(self, model_name, token): # print("provider", provider, provider['provider'] == item, item) if "/" in item: if provider['provider'] == item.split("/")[0]: - if model_name in provider['model'].keys() and item.split("/")[1] == model_name: + if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name: provider_list.append(provider) elif provider['provider'] == item: if model_name in provider['model'].keys(): @@ -422,15 +428,13 @@ async def is_rate_limited(self, key: str, limit: int, period: int) -> bool: rate_limiter = InMemoryRateLimiter() -async def get_user_rate_limit(token: str = None): +async def get_user_rate_limit(api_index: str = None): # 这里应该实现根据 token 获取用户速率限制的逻辑 # 示例: 返回 (次数, 秒数) config = app.state.config - api_list = app.state.api_list - api_index = api_list.index(token) raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT") - if not token or not raw_rate_limit: + if not api_index or not raw_rate_limit: return (60, 60) rate_limit = parse_rate_limit(raw_rate_limit) @@ -439,8 +443,14 @@ async def get_user_rate_limit(token: str = None): security = HTTPBearer() async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)): token = credentials.credentials if credentials else None - # print("token", token) - limit, period = await get_user_rate_limit(token) + api_list = app.state.api_list + try: + api_index = api_list.index(token) + except ValueError: + print("error: Invalid or missing API Key:", token) + api_index = None + token = None + limit, period = await get_user_rate_limit(api_index) # 使用 IP 地址和 token(如果有)作为限制键 client_ip = request.client.host