Skip to content

Commit

Permalink
🐛 Bug: Fix the bug where error codes are not accurately returned to the client.
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 11, 2024
1 parent 14428d9 commit 3972d74
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
20 changes: 16 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
if request.stream:
model = provider['model'][request.model]
generator = fetch_response_stream(app.state.client, url, headers, payload, engine, model)
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
wrapped_generator = await error_handling_wrapper(generator)
response = StreamingResponse(wrapped_generator, media_type="text/event-stream")
else:
response = await anext(fetch_response(app.state.client, url, headers, payload))
Expand Down Expand Up @@ -369,6 +369,8 @@ async def request_model(self, request: Union[RequestModel, ImageGenerationReques

# 在 try_all_providers 函数中处理失败的情况
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
status_code = 500
error_message = None
num_providers = len(providers)
start_index = self.last_provider_index + 1 if use_round_robin else 0
for i in range(num_providers + 1):
Expand All @@ -377,14 +379,24 @@ async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRe
try:
response = await process_request(request, provider, endpoint)
return response
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
except HTTPException as e:
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
status_code = e.status_code
error_message = e.detail

if auto_retry:
continue
else:
raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")
except (Exception, asyncio.CancelledError, httpx.ReadError) as e:
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
error_message = str(e)
if auto_retry:
continue
else:
raise HTTPException(status_code=500, detail="Error: Current provider response failed!")
raise HTTPException(status_code=500, detail=f"Error: Current provider response failed: {error_message}")

raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")
raise HTTPException(status_code=status_code, detail=f"All {request.model} error: {error_message}")

model_handler = ModelRequestHandler()

Expand Down
2 changes: 1 addition & 1 deletion response.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def check_response(response, error_log):
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
return {"error": f"{error_log} HTTP Error {response.status_code}", "details": error_json}
return {"error": f"{error_log} HTTP Error", "status_code": response.status_code, "details": error_json}
return None

async def fetch_gemini_response_stream(client, url, headers, payload, model):
Expand Down
8 changes: 5 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def ensure_string(item):
return str(item)

import asyncio
async def error_handling_wrapper(generator, status_code=200):
async def error_handling_wrapper(generator):
try:
first_item = await generator.__anext__()
first_item_str = first_item
Expand All @@ -126,7 +126,9 @@ async def error_handling_wrapper(generator, status_code=200):
raise StopAsyncIteration
if isinstance(first_item_str, dict) and 'error' in first_item_str:
# 如果第一个 yield 的项是错误信息,抛出 HTTPException
raise HTTPException(status_code=status_code, detail=f"{first_item_str}"[:300])
status_code = first_item_str.get('status_code', 500)
detail = first_item_str.get('details', f"{first_item_str}")
raise HTTPException(status_code=status_code, detail=f"{detail}"[:300])

# 如果不是错误,创建一个新的生成器,首先yield第一个项,然后yield剩余的项
async def new_generator():
Expand All @@ -141,7 +143,7 @@ async def new_generator():
return new_generator()

except StopAsyncIteration:
raise HTTPException(status_code=status_code, detail="data: {'error': 'No data returned'}")
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")

def post_all_models(token, config, api_list):
all_models = []
Expand Down

0 comments on commit 3972d74

Please sign in to comment.