Skip to content

Commit

Permalink
🐛 Bug: Fix the bug where the old rate-limit code was not removed.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Nov 25, 2024
1 parent 6bc5b55 commit 67e1a23
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 46 deletions.
42 changes: 10 additions & 32 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
save_api_yaml,
get_model_dict,
post_all_models,
get_user_rate_limit,
circular_list_encoder,
error_handling_wrapper,
rate_limiter,
Expand Down Expand Up @@ -1199,27 +1198,6 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques

security = HTTPBearer()

async def rate_limit_dependency(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Raise HTTP 429 when the caller has exceeded its rate limit.

    The bearer token is matched against ``app.state.api_list``, first as an
    exact entry and then as a string that starts with one of the entries.
    When no entry matches (or no token was supplied) the request is limited
    by client IP only and ``get_user_rate_limit`` is called with
    ``api_index=None``.

    Args:
        request: Incoming request; only ``request.client`` is read.
        credentials: Bearer credentials resolved by the ``security`` scheme.

    Raises:
        HTTPException: status 429 when ``rate_limiter.is_rate_limited``
            reports that the key is over its limits.
    """
    token = credentials.credentials if credentials else None
    api_list = app.state.api_list

    api_index = None
    # Guard against a missing token: calling ``.startswith`` on None would
    # raise AttributeError instead of falling back to IP-only limiting.
    if token is not None:
        try:
            api_index = api_list.index(token)
        except ValueError:
            # Token is not an exact entry of api_list; check whether it
            # starts with any of the configured keys.
            api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)

    if api_index is None:
        print("error: Invalid or missing API Key:", token)
        token = None

    # Use the client IP and the token (when recognised) as the limiter key.
    # ``request.client`` can be None, so fall back to one shared bucket
    # rather than failing on ``.host``.
    client_ip = request.client.host if request.client else "unknown"
    rate_limit_key = f"{client_ip}:{token}" if token else client_ip

    limits = await get_user_rate_limit(app, api_index)
    if await rate_limiter.is_rate_limited(rate_limit_key, limits):
        raise HTTPException(status_code=429, detail="Too many requests")

def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
api_list = app.state.api_list
token = credentials.credentials
Expand Down Expand Up @@ -1250,44 +1228,44 @@ def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(sec
raise HTTPException(status_code=403, detail="Permission denied")
return token

# POST /v1/chat/completions: resolves the caller's key index through the
# verify_api_key dependency and returns whatever
# model_handler.request_model(request, api_index) produces.
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/chat/completions")
async def request_model(request: RequestModel, api_index: int = Depends(verify_api_key)):
return await model_handler.request_model(request, api_index)

# OPTIONS /v1/chat/completions: always answers 200 with the JSON body
# {"detail": "OPTIONS allowed"}; the handler itself takes no parameters.
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
@app.options("/v1/chat/completions")
async def options_handler():
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})

# GET /v1/models: resolves the caller's key index through verify_api_key,
# builds the model list with post_all_models(api_index, app.state.config) and
# returns it as {"object": "list", "data": models}.
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
@app.get("/v1/models")
async def list_models(api_index: int = Depends(verify_api_key)):
models = post_all_models(api_index, app.state.config)
return JSONResponse(content={
"object": "list",
"data": models
})

# POST /v1/images/generations: resolves the caller's key index through
# verify_api_key and forwards the request to model_handler.request_model,
# passing endpoint="/v1/images/generations".
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/images/generations")
async def images_generations(
request: ImageGenerationRequest,
api_index: int = Depends(verify_api_key)
):
return await model_handler.request_model(request, api_index, endpoint="/v1/images/generations")

# POST /v1/embeddings: resolves the caller's key index through verify_api_key
# and forwards the request to model_handler.request_model, passing
# endpoint="/v1/embeddings".
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/embeddings")
async def embeddings(
request: EmbeddingRequest,
api_index: int = Depends(verify_api_key)
):
return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")

# POST /v1/audio/speech: resolves the caller's key index through
# verify_api_key and forwards the request to model_handler.request_model,
# passing endpoint="/v1/audio/speech".
# NOTE(review): both decorator lines below are present in this view; they look
# like the before/after pair of the diff (with and without
# rate_limit_dependency) -- confirm that only one is meant to be live.
# NOTE(review): api_index is annotated `str` here but `int` on the sibling
# handlers that use the same verify_api_key dependency -- confirm which is right.
@app.post("/v1/audio/speech", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/audio/speech")
async def audio_speech(
request: TextToSpeechRequest,
api_index: str = Depends(verify_api_key)
):
return await model_handler.request_model(request, api_index, endpoint="/v1/audio/speech")

@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/moderations")
async def moderations(
request: ModerationRequest,
api_index: int = Depends(verify_api_key)
Expand All @@ -1296,7 +1274,7 @@ async def moderations(

from fastapi import UploadFile, File, Form, HTTPException
import io
@app.post("/v1/audio/transcriptions", dependencies=[Depends(rate_limit_dependency)])
@app.post("/v1/audio/transcriptions")
async def audio_transcriptions(
file: UploadFile = File(...),
model: str = Form(...),
Expand All @@ -1322,7 +1300,7 @@ async def audio_transcriptions(
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Error processing audio file: {str(e)}")

@app.get("/v1/generate-api-key", dependencies=[Depends(rate_limit_dependency)])
@app.get("/v1/generate-api-key")
def generate_api_key():
# Define the character set (only alphanumeric)
chars = string.ascii_letters + string.digits
Expand All @@ -1336,7 +1314,7 @@ def generate_api_key():
from sqlalchemy import func, desc, case
from fastapi import Query

@app.get("/v1/stats", dependencies=[Depends(rate_limit_dependency)])
@app.get("/v1/stats")
async def get_stats(
request: Request,
token: str = Depends(verify_admin_api_key),
Expand Down
14 changes: 0 additions & 14 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,6 @@ async def is_rate_limited(self, key: str, limits) -> bool:

rate_limiter = InMemoryRateLimiter()

async def get_user_rate_limit(app, api_index: int = None):
    """Return the rate limits that apply to the API key at ``api_index``.

    Args:
        app: Application object; ``app.state.config`` is read for the key's
            ``preferences.rate_limit`` setting.
        api_index: Index of the key under ``api_keys`` in the config, or
            ``None`` when the caller was not matched to any key.

    Returns:
        The result of ``parse_rate_limit`` on the configured value, or the
        default ``[(30, 60)]`` when there is no key index or no configured
        limit. Per the original note, each tuple is (count, seconds).
    """
    default_limits = [(30, 60)]

    # No key to look up: return the default without touching the config.
    if api_index is None:
        return default_limits

    raw_rate_limit = safe_get(app.state.config, 'api_keys', api_index, "preferences", "rate_limit")
    if not raw_rate_limit:
        return default_limits

    return parse_rate_limit(raw_rate_limit)

import asyncio

class ThreadSafeCircularList:
Expand Down

0 comments on commit 67e1a23

Please sign in to comment.