Skip to content

Commit

Permalink
✨ Feature: Add feature: Add support for rate limiting.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 6, 2024
1 parent 7477ff7 commit b812da1
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 8 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@
- 同时支持 Anthropic、Gemini、Vertex API。Vertex 同时支持 Claude 和 Gemini API。
- 支持 OpenAI、 Anthropic、Gemini、Vertex 原生 tool use 函数调用。
- 支持 OpenAI、Anthropic、Gemini、Vertex 原生识图 API。
- 支持四种负载均衡。1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
- 支持四种负载均衡。
1. 支持渠道级加权负载均衡,可以根据不同的渠道权重分配请求。默认不开启,需要配置渠道权重。
2. 支持 Vertex 区域级负载均衡,支持 Vertex 高并发,最高可将 Gemini,Claude 并发提高 (API数量 * 区域数量) 倍。自动开启不需要额外配置。
3. 除了 Vertex 区域级负载均衡,所有 API 均支持渠道级顺序负载均衡,提高沉浸式翻译体验。自动开启不需要额外配置。
4. 支持单个渠道多个 API Key 自动开启 API key 级别的轮训负载均衡。
- 支持自动重试,当一个 API 渠道响应失败时,自动重试下一个 API 渠道。
- 支持细粒度的权限控制。支持使用通配符设置 API key 可用渠道的特定模型。
- 支持限流,可以设置每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min。

## Configuration

Expand Down Expand Up @@ -93,6 +98,7 @@ api_keys:
preferences:
USE_ROUND_ROBIN: true # 是否使用轮询负载均衡,true 为使用,false 为不使用,默认为 true。开启轮训后每次请求模型按照 model 配置的顺序依次请求。与 providers 里面原始的渠道顺序无关。因此你可以设置每个 API key 请求顺序不一样。
AUTO_RETRY: true # 是否自动重试,自动重试下一个提供商,true 为自动重试,false 为不自动重试,默认为 true
RATE_LIMIT: 2/min # 支持限流,每分钟最多请求次数,可以设置为整数,如 2/min,2 次每分钟、5/hour,5 次每小时、10/day,10 次每天,10/month,10 次每月,10/year,10 次每年。默认60/min,选填

# 渠道级加权负载均衡配置示例
- api: sk-KjjI60Yf0JFWtxxxxxxxxxxxxxxwmRWpWpQRo
Expand Down
82 changes: 75 additions & 7 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from log_config import logger

import re
import httpx
import secrets
import time as time_module
from contextlib import asynccontextmanager

from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -14,6 +16,7 @@
from response import fetch_response, fetch_response_stream
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder

from collections import defaultdict
from typing import List, Dict, Union
from urllib.parse import urlparse

Expand Down Expand Up @@ -374,8 +377,73 @@ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRe

model_handler = ModelRequestHandler()

# 安全性依赖
def parse_rate_limit(limit_string):
# 定义时间单位到秒的映射
time_units = {
's': 1, 'sec': 1, 'second': 1,
'm': 60, 'min': 60, 'minute': 60,
'h': 3600, 'hr': 3600, 'hour': 3600,
'd': 86400, 'day': 86400,
'mo': 2592000, 'month': 2592000,
'y': 31536000, 'year': 31536000
}

# 使用正则表达式匹配数字和单位
match = re.match(r'^(\d+)/(\w+)$', limit_string)
if not match:
raise ValueError(f"Invalid rate limit format: {limit_string}")

count, unit = match.groups()
count = int(count)

# 转换单位到秒
if unit not in time_units:
raise ValueError(f"Unknown time unit: {unit}")

seconds = time_units[unit]

return (count, seconds)

class InMemoryRateLimiter:
def __init__(self):
self.requests = defaultdict(list)

async def is_rate_limited(self, key: str, limit: int, period: int) -> bool:
now = time_module.time()
self.requests[key] = [req for req in self.requests[key] if req > now - period]
if len(self.requests[key]) >= limit:
return True
self.requests[key].append(now)
return False

rate_limiter = InMemoryRateLimiter()

async def get_user_rate_limit(token: str = None):
# 这里应该实现根据 token 获取用户速率限制的逻辑
# 示例: 返回 (次数, 秒数)
config = app.state.config
api_list = app.state.api_list
api_index = api_list.index(token)
raw_rate_limit = safe_get(config, 'api_keys', api_index, "preferences", "RATE_LIMIT")

if not token or not raw_rate_limit:
return (60, 60)

rate_limit = parse_rate_limit(raw_rate_limit)
return rate_limit

security = HTTPBearer()
async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
token = credentials.credentials if credentials else None
# print("token", token)
limit, period = await get_user_rate_limit(token)

# 使用 IP 地址和 token(如果有)作为限制键
client_ip = request.client.host
rate_limit_key = f"{client_ip}:{token}" if token else client_ip

if await rate_limiter.is_rate_limited(rate_limit_key, limit, period):
raise HTTPException(status_code=429, detail="Too many requests")

def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
api_list = app.state.api_list
Expand All @@ -395,36 +463,36 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
raise HTTPException(status_code=403, detail="Permission denied")
return token

@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
return await model_handler.request_model(request, token)

@app.options("/v1/chat/completions")
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
async def options_handler():
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})

@app.get("/v1/models")
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
async def list_models(token: str = Depends(verify_api_key)):
models = post_all_models(token, app.state.config, app.state.api_list)
return JSONResponse(content={
"object": "list",
"data": models
})

@app.post("/v1/images/generations")
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
async def images_generations(
request: ImageGenerationRequest,
token: str = Depends(verify_api_key)
):
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")

@app.get("/generate-api-key")
@app.get("/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(36)
return JSONResponse(content={"api_key": api_key})

# 在 /stats 路由中返回成功和失败百分比
@app.get("/stats")
@app.get("/stats", dependencies=[Depends(rate_limit_dependency)])
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
middleware = app.middleware_stack.app
if isinstance(middleware, StatsMiddleware):
Expand Down
41 changes: 41 additions & 0 deletions test/test_rate_limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import re

def parse_rate_limit(limit_string):
# 定义时间单位到秒的映射
time_units = {
's': 1, 'sec': 1, 'second': 1,
'm': 60, 'min': 60, 'minute': 60,
'h': 3600, 'hr': 3600, 'hour': 3600,
'd': 86400, 'day': 86400,
'mo': 2592000, 'month': 2592000,
'y': 31536000, 'year': 31536000
}

# 使用正则表达式匹配数字和单位
match = re.match(r'^(\d+)/(\w+)$', limit_string)
if not match:
raise ValueError(f"Invalid rate limit format: {limit_string}")

count, unit = match.groups()
count = int(count)

# 转换单位到秒
if unit not in time_units:
raise ValueError(f"Unknown time unit: {unit}")

seconds = time_units[unit]

return (count, seconds)

# 测试函数
test_cases = [
"2/min", "5/hour", "10/day", "1/second", "3/mo", "1/year",
"20/s", "15/m", "8/h", "100/d", "50/mo", "2/y"
]

for case in test_cases:
try:
result = parse_rate_limit(case)
print(f"{case} => {result}")
except ValueError as e:
print(f"Error parsing {case}: {str(e)}")

0 comments on commit b812da1

Please sign in to comment.