Skip to content

Commit

Permalink
✨ Feature: Add features: Add API channel success rate statistics, cha…
Browse files Browse the repository at this point in the history
…nnel status records.
  • Loading branch information
yym68686 committed Sep 5, 2024
1 parent 44caf41 commit 73a667f
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 31 deletions.
85 changes: 63 additions & 22 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats.
self.request_times = defaultdict(float)
self.ip_counts = defaultdict(lambda: defaultdict(int))
self.request_arrivals = defaultdict(list)
self.channel_success_counts = defaultdict(int)
self.channel_failure_counts = defaultdict(int)
self.lock = asyncio.Lock()
self.exclude_paths = set(exclude_paths or [])
self.save_interval = save_interval
Expand Down Expand Up @@ -101,18 +103,40 @@ async def save_stats(self):
"request_counts": dict(self.request_counts),
"request_times": dict(self.request_times),
"ip_counts": {k: dict(v) for k, v in self.ip_counts.items()},
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()},
"channel_success_counts": dict(self.channel_success_counts),
"channel_failure_counts": dict(self.channel_failure_counts),
"channel_success_percentages": self.calculate_success_percentages(),
"channel_failure_percentages": self.calculate_failure_percentages()
}

filename = self.filename
async with aiofiles.open(filename, mode='w') as f:
await f.write(json.dumps(stats, indent=2))

self.last_save_time = current_time
# print(f"Stats saved to {filename}")

def calculate_success_percentages(self):
percentages = {}
for channel, success_count in self.channel_success_counts.items():
total_count = success_count + self.channel_failure_counts[channel]
if total_count > 0:
percentages[channel] = success_count / total_count * 100
else:
percentages[channel] = 0
return percentages

def calculate_failure_percentages(self):
percentages = {}
for channel, failure_count in self.channel_failure_counts.items():
total_count = failure_count + self.channel_success_counts[channel]
if total_count > 0:
percentages[channel] = failure_count / total_count * 100
else:
percentages[channel] = 0
return percentages

async def cleanup_old_data(self):
# cutoff_time = datetime.now() - timedelta(seconds=30)
cutoff_time = datetime.now() - timedelta(hours=24)
async with self.lock:
for endpoint in list(self.request_arrivals.keys()):
Expand All @@ -139,10 +163,10 @@ async def cleanup(self):

app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"])

# 在 process_request 函数中更新成功和失败计数
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
url = provider['base_url']
parsed_url = urlparse(url)
# print(parsed_url)
engine = None
if parsed_url.netloc == 'generativelanguage.googleapis.com':
engine = "gemini"
Expand All @@ -160,6 +184,12 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
and "gemini" not in provider['model'][request.model]:
engine = "openrouter"

if "claude" in provider['model'][request.model] and engine == "vertex":
engine = "vertex-claude"

if "gemini" in provider['model'][request.model] and engine == "vertex":
engine = "vertex-gemini"

if endpoint == "/v1/images/generations":
engine = "dalle"
request.stream = False
Expand All @@ -171,21 +201,28 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],

url, headers, payload = await get_payload(request, engine, provider)

# request_info = {
# "url": url,
# "headers": headers,
# "payload": payload
# }
# import json
# logger.info(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}")

if request.stream:
model = provider['model'][request.model]
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
return StreamingResponse(wrapped_generator, media_type="text/event-stream")
else:
return await anext(fetch_response(app.state.client, url, headers, payload))
try:
if request.stream:
model = provider['model'][request.model]
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
else:
response = await anext(fetch_response(app.state.client, url, headers, payload))

# 更新成功计数
async with app.middleware_stack.app.lock:
app.middleware_stack.app.channel_success_counts[provider['provider']] += 1

return response
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
logger.error(f"Error with provider {provider['provider']}: {str(e)}")

# 更新失败计数
async with app.middleware_stack.app.lock:
app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1

raise e

import asyncio
class ModelRequestHandler:
Expand Down Expand Up @@ -270,10 +307,10 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques

return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)

# 在 try_all_providers 函数中处理失败的情况
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
num_providers = len(providers)
start_index = self.last_provider_index + 1 if use_round_robin else 0

for i in range(num_providers + 1):
self.last_provider_index = (start_index + i) % num_providers
provider = providers[self.last_provider_index]
Expand All @@ -287,7 +324,6 @@ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRe
else:
raise HTTPException(status_code=500, detail="Error: Current provider response failed!")


raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")

model_handler = ModelRequestHandler()
Expand Down Expand Up @@ -341,6 +377,7 @@ def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(36)
return JSONResponse(content={"api_key": api_key})

# 在 /stats 路由中返回成功和失败百分比
@app.get("/stats")
async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)):
middleware = app.middleware_stack.app
Expand All @@ -350,7 +387,11 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)
"request_counts": dict(middleware.request_counts),
"request_times": dict(middleware.request_times),
"ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()},
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}
"request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()},
"channel_success_counts": dict(middleware.channel_success_counts),
"channel_failure_counts": dict(middleware.channel_failure_counts),
"channel_success_percentages": middleware.calculate_success_percentages(),
"channel_failure_percentages": middleware.calculate_failure_percentages()
}
return JSONResponse(content=stats)
return {"error": "StatsMiddleware not found"}
Expand Down
12 changes: 6 additions & 6 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def get_image_message(base64_image, engine = None):
"url": base64_image,
}
}
if "claude" == engine:
if "claude" == engine or "vertex-claude" == engine:
return {
"type": "image",
"source": {
Expand All @@ -19,7 +19,7 @@ async def get_image_message(base64_image, engine = None):
"data": base64_image.split(",")[1],
}
}
if "gemini" == engine:
if "gemini" == engine or "vertex-gemini" == engine:
return {
"inlineData": {
"mimeType": "image/jpeg",
Expand All @@ -29,9 +29,9 @@ async def get_image_message(base64_image, engine = None):
raise ValueError("Unknown engine")

async def get_text_message(role, message, engine = None):
if "gpt" == engine or "claude" == engine or "openrouter" == engine:
if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine:
return {"type": "text", "text": message}
if "gemini" == engine:
if "gemini" == engine or "vertex-gemini" == engine:
return {"text": message}
raise ValueError("Unknown engine")

Expand Down Expand Up @@ -794,9 +794,9 @@ async def get_dalle_payload(request, engine, provider):
async def get_payload(request: RequestModel, engine, provider):
if engine == "gemini":
return await get_gemini_payload(request, engine, provider)
elif engine == "vertex" and "gemini" in provider['model'][request.model]:
elif engine == "vertex-gemini":
return await get_vertex_gemini_payload(request, engine, provider)
elif engine == "vertex" and "claude" in provider['model'][request.model]:
elif engine == "vertex-claude":
return await get_vertex_claude_payload(request, engine, provider)
elif engine == "claude":
return await get_claude_payload(request, engine, provider)
Expand Down
4 changes: 2 additions & 2 deletions response.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,10 @@ async def fetch_response(client, url, headers, payload):

async def fetch_response_stream(client, url, headers, payload, engine, model):
try:
if engine == "gemini" or (engine == "vertex" and "gemini" in model):
if engine == "gemini" or engine == "vertex-gemini":
async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
yield chunk
elif engine == "claude" or (engine == "vertex" and "claude" in model):
elif engine == "claude" or engine == "vertex-claude":
async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
yield chunk
elif engine == "gpt":
Expand Down
2 changes: 1 addition & 1 deletion test/test_nostream.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_model_response(image_base64):
# "stream": True,
"tools": tools,
"tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}},
"max_tokens": 300
"max_tokens": 1000
}

try:
Expand Down

0 comments on commit 73a667f

Please sign in to comment.