From 3972d748d7ac6cebb8e2fdb878a74927ec9a9fd9 Mon Sep 17 00:00:00 2001 From: yym68686 Date: Thu, 12 Sep 2024 04:06:33 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20Bug:=20Fix=20the=20bug=20where?= =?UTF-8?q?=20error=20codes=20are=20not=20accurately=20returned=20to=20the?= =?UTF-8?q?=20client.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 20 ++++++++++++++++---- response.py | 2 +- utils.py | 8 +++++--- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/main.py b/main.py index 64dc24c..e49ea1d 100644 --- a/main.py +++ b/main.py @@ -218,7 +218,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest], if request.stream: model = provider['model'][request.model] generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model) - wrapped_generator = await error_handling_wrapper(generator, status_code=500) + wrapped_generator = await error_handling_wrapper(generator) response = StreamingResponse(wrapped_generator, media_type="text/event-stream") else: response = await anext(fetch_response(app.state.client, url, headers, payload)) @@ -369,6 +369,8 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques # 在 try_all_providers 函数中处理失败的情况 async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None): + status_code = 500 + error_message = None num_providers = len(providers) start_index = self.last_provider_index + 1 if use_round_robin else 0 for i in range(num_providers + 1): @@ -377,14 +379,24 @@ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRe try: response = await process_request(request, provider, endpoint) return response - except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e: + except HTTPException as e: logger.error(f"Error with provider {provider['provider']}: {str(e)}") + status_code = e.status_code + error_message = e.detail + + if auto_retry: + continue + else: + raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}") + except (Exception, asyncio.CancelledError, httpx.ReadError) as e: + logger.error(f"Error with provider {provider['provider']}: {str(e)}") + error_message = str(e) if auto_retry: continue else: - raise HTTPException(status_code=500, detail="Error: Current provider response failed!") + raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}") - raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}") + raise HTTPException(status_code=status_code, detail=f"All {request.model} error: {error_message}") model_handler = ModelRequestHandler() diff --git a/response.py b/response.py index d33d2fc..be2f1c3 100644 --- a/response.py +++ b/response.py @@ -48,7 +48,7 @@ async def check_response(response, error_log): error_json = json.loads(error_str) except json.JSONDecodeError: error_json = error_str - return {"error": f"{error_log} HTTP Error {response.status_code}", "details": error_json} + return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json} return None async def fetch_gemini_response_stream(client, url, headers, payload, model): diff --git a/utils.py b/utils.py index 366a3a2..2882f9e 100644 --- a/utils.py +++ b/utils.py @@ -104,7 +104,7 @@ def ensure_string(item): return str(item) import asyncio -async def error_handling_wrapper(generator, status_code=200): +async def error_handling_wrapper(generator): try: first_item = await generator.__anext__() first_item_str = first_item @@ -126,7 +126,9 @@ async def error_handling_wrapper(generator, status_code=200): raise StopAsyncIteration if isinstance(first_item_str, dict) and 'error' in first_item_str: # 如果第一个 yield 的项是错误信息,抛出 HTTPException - raise HTTPException(status_code=status_code, detail=f"{first_item_str}"[:300]) + status_code = first_item_str.get('status_code', 500) + detail = first_item_str.get('details', f"{first_item_str}") + raise HTTPException(status_code=status_code, detail=f"{detail}"[:300]) # 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项 async def new_generator(): @@ -141,7 +143,7 @@ async def new_generator(): return new_generator() except StopAsyncIteration: - raise HTTPException(status_code=status_code, detail="data: {'error': 'No data returned'}") + raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}") def post_all_models(token, config, api_list): all_models = []