diff --git a/main.py b/main.py index 405c29c..1b2b73a 100644 --- a/main.py +++ b/main.py @@ -58,6 +58,8 @@ def __init__(self, app, exclude_paths=None, save_interval=3600, filename="stats. self.request_times = defaultdict(float) self.ip_counts = defaultdict(lambda: defaultdict(int)) self.request_arrivals = defaultdict(list) + self.channel_success_counts = defaultdict(int) + self.channel_failure_counts = defaultdict(int) self.lock = asyncio.Lock() self.exclude_paths = set(exclude_paths or []) self.save_interval = save_interval @@ -101,7 +103,11 @@ async def save_stats(self): "request_counts": dict(self.request_counts), "request_times": dict(self.request_times), "ip_counts": {k: dict(v) for k, v in self.ip_counts.items()}, - "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()} + "request_arrivals": {k: [t.isoformat() for t in v] for k, v in self.request_arrivals.items()}, + "channel_success_counts": dict(self.channel_success_counts), + "channel_failure_counts": dict(self.channel_failure_counts), + "channel_success_percentages": self.calculate_success_percentages(), + "channel_failure_percentages": self.calculate_failure_percentages() } filename = self.filename @@ -109,10 +115,28 @@ async def save_stats(self): await f.write(json.dumps(stats, indent=2)) self.last_save_time = current_time - # print(f"Stats saved to {filename}") + + def calculate_success_percentages(self): + percentages = {} + for channel, success_count in self.channel_success_counts.items(): + total_count = success_count + self.channel_failure_counts[channel] + if total_count > 0: + percentages[channel] = success_count / total_count * 100 + else: + percentages[channel] = 0 + return percentages + + def calculate_failure_percentages(self): + percentages = {} + for channel, failure_count in self.channel_failure_counts.items(): + total_count = failure_count + self.channel_success_counts[channel] + if total_count > 0: + percentages[channel] = failure_count / total_count * 100 + else: + percentages[channel] = 0 + return percentages async def cleanup_old_data(self): - # cutoff_time = datetime.now() - timedelta(seconds=30) cutoff_time = datetime.now() - timedelta(hours=24) async with self.lock: for endpoint in list(self.request_arrivals.keys()): @@ -139,10 +163,10 @@ async def cleanup(self): app.add_middleware(StatsMiddleware, exclude_paths=["/stats", "/generate-api-key"]) +# 在 process_request 函数中更新成功和失败计数 async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None): url = provider['base_url'] parsed_url = urlparse(url) - # print(parsed_url) engine = None if parsed_url.netloc == 'generativelanguage.googleapis.com': engine = "gemini" @@ -160,6 +184,12 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest], and "gemini" not in provider['model'][request.model]: engine = "openrouter" + if "claude" in provider['model'][request.model] and engine == "vertex": + engine = "vertex-claude" + + if "gemini" in provider['model'][request.model] and engine == "vertex": + engine = "vertex-gemini" + if endpoint == "/v1/images/generations": engine = "dalle" request.stream = False @@ -171,21 +201,28 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest], url, headers, payload = await get_payload(request, engine, provider) - # request_info = { - # "url": url, - # "headers": headers, - # "payload": payload - # } - # import json - # logger.info(f"Request details: {json.dumps(request_info, indent=4, ensure_ascii=False)}") - - if request.stream: - model = provider['model'][request.model] - generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model) - wrapped_generator = await error_handling_wrapper(generator, status_code=500) - return StreamingResponse(wrapped_generator, media_type="text/event-stream") - else: - return await anext(fetch_response(app.state.client, url, headers, payload)) + try: + if request.stream: + model = provider['model'][request.model] + generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model) + wrapped_generator = await error_handling_wrapper(generator, status_code=500) + response = StreamingResponse(wrapped_generator, media_type="text/event-stream") + else: + response = await anext(fetch_response(app.state.client, url, headers, payload)) + + # 更新成功计数 + async with app.middleware_stack.app.lock: + app.middleware_stack.app.channel_success_counts[provider['provider']] += 1 + + return response + except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e: + logger.error(f"Error with provider {provider['provider']}: {str(e)}") + + # 更新失败计数 + async with app.middleware_stack.app.lock: + app.middleware_stack.app.channel_failure_counts[provider['provider']] += 1 + + raise e import asyncio class ModelRequestHandler: @@ -270,10 +307,10 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint) + # 在 try_all_providers 函数中处理失败的情况 async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None): num_providers = len(providers) start_index = self.last_provider_index + 1 if use_round_robin else 0 - for i in range(num_providers + 1): self.last_provider_index = (start_index + i) % num_providers provider = providers[self.last_provider_index] @@ -287,7 +324,6 @@ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRe else: raise HTTPException(status_code=500, detail="Error: Current provider response failed!") - raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}") model_handler = ModelRequestHandler() @@ -341,6 +377,7 @@ def generate_api_key(): api_key = "sk-" + secrets.token_urlsafe(36) return JSONResponse(content={"api_key": api_key}) +# 在 /stats 路由中返回成功和失败百分比 @app.get("/stats") async def get_stats(request: Request, token: str = Depends(verify_admin_api_key)): middleware = app.middleware_stack.app @@ -350,7 +387,11 @@ async def get_stats(request: Request, token: str = Depends(verify_admin_api_key) "request_counts": dict(middleware.request_counts), "request_times": dict(middleware.request_times), "ip_counts": {k: dict(v) for k, v in middleware.ip_counts.items()}, - "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()} + "request_arrivals": {k: [t.isoformat() for t in v] for k, v in middleware.request_arrivals.items()}, + "channel_success_counts": dict(middleware.channel_success_counts), + "channel_failure_counts": dict(middleware.channel_failure_counts), + "channel_success_percentages": middleware.calculate_success_percentages(), + "channel_failure_percentages": middleware.calculate_failure_percentages() } return JSONResponse(content=stats) return {"error": "StatsMiddleware not found"} diff --git a/request.py b/request.py index 765c66f..cc20f79 100644 --- a/request.py +++ b/request.py @@ -10,7 +10,7 @@ async def get_image_message(base64_image, engine = None): "url": base64_image, } } - if "claude" == engine: + if "claude" == engine or "vertex-claude" == engine: return { "type": "image", "source": { @@ -19,7 +19,7 @@ async def get_image_message(base64_image, engine = None): "data": base64_image.split(",")[1], } } - if "gemini" == engine: + if "gemini" == engine or "vertex-gemini" == engine: return { "inlineData": { "mimeType": "image/jpeg", @@ -29,9 +29,9 @@ async def get_image_message(base64_image, engine = None): raise ValueError("Unknown engine") async def get_text_message(role, message, engine = None): - if "gpt" == engine or "claude" == engine or "openrouter" == engine: + if "gpt" == engine or "claude" == engine or "openrouter" == engine or "vertex-claude" == engine: return {"type": "text", "text": message} - if "gemini" == engine: + if "gemini" == engine or "vertex-gemini" == engine: return {"text": message} raise ValueError("Unknown engine") @@ -794,9 +794,9 @@ async def get_dalle_payload(request, engine, provider): async def get_payload(request: RequestModel, engine, provider): if engine == "gemini": return await get_gemini_payload(request, engine, provider) - elif engine == "vertex" and "gemini" in provider['model'][request.model]: + elif engine == "vertex-gemini": return await get_vertex_gemini_payload(request, engine, provider) - elif engine == "vertex" and "claude" in provider['model'][request.model]: + elif engine == "vertex-claude": return await get_vertex_claude_payload(request, engine, provider) elif engine == "claude": return await get_claude_payload(request, engine, provider) diff --git a/response.py b/response.py index 357869d..6d92e81 100644 --- a/response.py +++ b/response.py @@ -248,10 +248,10 @@ async def fetch_response(client, url, headers, payload): async def fetch_response_stream(client, url, headers, payload, engine, model): try: - if engine == "gemini" or (engine == "vertex" and "gemini" in model): + if engine == "gemini" or engine == "vertex-gemini": async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model): yield chunk - elif engine == "claude" or (engine == "vertex" and "claude" in model): + elif engine == "claude" or engine == "vertex-claude": async for chunk in fetch_claude_response_stream(client, url, headers, payload, model): yield chunk elif engine == "gpt": diff --git a/test/test_nostream.py b/test/test_nostream.py index 0ae7642..febb248 100644 --- a/test/test_nostream.py +++ b/test/test_nostream.py @@ -66,7 +66,7 @@ def get_model_response(image_base64): # "stream": True, "tools": tools, "tool_choice": {"type": "function", "function": {"name": "extract_underlined_text"}}, - "max_tokens": 300 + "max_tokens": 1000 } try: