Skip to content

Commit

Permalink
🐛 Bug: 1. Fix the bug where the API key is not found when rate limiting.
Browse files Browse the repository at this point in the history
2. Fix the bug where the characters before the slash in the model name with a slash are parsed as the channel name.
  • Loading branch information
yym68686 committed Sep 10, 2024
1 parent 95ca783 commit cfd3f47
Showing 1 changed file with 20 additions and 10 deletions.
30 changes: 20 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,20 +275,26 @@ def get_matching_providers(self, model_name, token):
for model in config['api_keys'][api_index]['model']:
if "/" in model:
provider_name = model.split("/")[0]
model = model.split("/")[1]
model_name_split = "/".join(model.split("/")[1:])
models_list = []
for provider in config['providers']:
if provider['provider'] == provider_name:
models_list.extend(list(provider['model'].keys()))
# print("models_list", models_list)
# print("model_name", model_name)

# 处理带斜杠的模型名
for provider in config['providers']:
if model in provider['model'].keys():
provider_rules.append(provider['provider'] + "/" + model)

# print("model", model)
if (model and model_name in models_list) or (model == "*" and model_name in models_list):
if (model_name_split and model_name in models_list) or (model_name_split == "*" and model_name in models_list):
provider_rules.append(provider_name)
else:
for provider in config['providers']:
if model in provider['model'].keys():
provider_rules.append(provider['provider'] + "/" + model)
provider_rules.append(provider['provider'] + "/" + model_name_split)

provider_list = []
# print("provider_rules", provider_rules)
Expand All @@ -297,7 +303,7 @@ def get_matching_providers(self, model_name, token):
# print("provider", provider, provider['provider'] == item, item)
if "/" in item:
if provider['provider'] == item.split("/")[0]:
if model_name in provider['model'].keys() and item.split("/")[1] == model_name:
if model_name in provider['model'].keys() and "/".join(item.split("/")[1:]) == model_name:
provider_list.append(provider)
elif provider['provider'] == item:
if model_name in provider['model'].keys():
Expand Down Expand Up @@ -422,15 +428,13 @@ async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:

rate_limiter = InMemoryRateLimiter()

async def get_user_rate_limit(token: str = None):
async def get_user_rate_limit(api_index: str = None):
# 这里应该实现根据 token 获取用户速率限制的逻辑
# 示例: 返回 (次数, 秒数)
config = app.state.config
api_list = app.state.api_list
api_index = api_list.index(token)
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")

if not token or not raw_rate_limit:
if not api_index or not raw_rate_limit:
return (60, 60)

rate_limit = parse_rate_limit(raw_rate_limit)
Expand All @@ -439,8 +443,14 @@ async def get_user_rate_limit(token: str = None):
security = HTTPBearer()
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials if credentials else None
# print("token", token)
limit, period = await get_user_rate_limit(token)
api_list = app.state.api_list
try:
api_index = api_list.index(token)
except ValueError:
print("error: Invalid or missing API Key:", token)
api_index = None
token = None
limit, period = await get_user_rate_limit(api_index)

# 使用 IP 地址和 token(如果有)作为限制键
client_ip = request.client.host
Expand Down

0 comments on commit cfd3f47

Please sign in to comment.