Skip to content

Commit

Permalink
Merge pull request #3 from omegaduncan/main
Browse files Browse the repository at this point in the history
fix: Improve round-robin provider selection and add Claude max_tokens auto-fill
  • Loading branch information
yym68686 authored Sep 2, 2024
2 parents 5b7f732 + 47c28b9 commit a11cc62
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
11 changes: 7 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,13 @@ async def request_model(self, request: RequestModel, token: str):

# 检查是否启用轮询
api_index = api_list.index(token)
use_round_robin = False
auto_retry = False
use_round_robin = True
auto_retry = True
if config['api_keys'][api_index].get("preferences"):
use_round_robin = config['api_keys'][api_index]["preferences"].get("USE_ROUND_ROBIN")
auto_retry = config['api_keys'][api_index]["preferences"].get("AUTO_RETRY")
if config['api_keys'][api_index]["preferences"].get("USE_ROUND_ROBIN") == False:
use_round_robin = False
if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
auto_retry = False

return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry)

Expand All @@ -207,6 +209,7 @@ async def try_all_providers(self, request: RequestModel, providers: List[Dict],
else:
raise HTTPException(status_code=500, detail="Error: Current provider response failed!")


raise HTTPException(status_code=500, detail=f"All providers failed: {request.model}")

model_handler = ModelRequestHandler()
Expand Down
7 changes: 7 additions & 0 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,13 @@ async def get_vertex_claude_payload(request, engine, provider):
"system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
}

# 檢查是否需要添加 max_tokens
if 'max_tokens' not in payload:
if "claude-3-5-sonnet" in model:
payload['max_tokens'] = 8192
elif "claude-3" in model: # 處理其他 Claude 3 模型
payload['max_tokens'] = 4096

miss_fields = [
'model',
'messages',
Expand Down

0 comments on commit a11cc62

Please sign in to comment.