Skip to content

Commit

Permalink
✨ Feature: Add feature: support wildcard matching like gpt* to match …
Browse files Browse the repository at this point in the history
…models such as gpt-3.5 and gpt-4.
  • Loading branch information
yym68686 committed Oct 24, 2024
1 parent 0ebec22 commit 60014c4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 24 deletions.
56 changes: 32 additions & 24 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from log_config import logger

import re
import copy
import httpx
import secrets
from time import time
Expand Down Expand Up @@ -652,44 +653,51 @@ def get_matching_providers(self, model_name, token):
# print("model_name", model_name)
# print("model_name_split", model_name_split)
# print("model", model)

# api_keys 中 model 为 provider_name/* 时,表示所有模型都匹配
if model_name_split == "*":
if model_name in models_list:
provider_rules.append(provider_name)
elif model_name_split == model_name:
if model_name in models_list:
provider_rules.append(provider_name)

# 如果请求模型名: gpt-4* ,则匹配所有以模型名开头且不以 * 结尾的模型
for models_list_model in models_list:
if model_name.endswith("*") and models_list_model.startswith(model_name.rstrip("*")):
provider_rules.append(provider_name + "/" + models_list_model)

# api_keys 中 model 为 provider_name/model_name 时,表示模型名完全匹配
elif model_name_split == model_name \
or (model_name.endswith("*") and model_name_split.startswith(model_name.rstrip("*"))): # api_keys 中 model 为 provider_name/model_name 时,请求模型名: model_name*
if model_name_split in models_list:
provider_rules.append(provider_name + "/" + model_name_split)

else:
for provider in config['providers']:
for provider in config["providers"]:
model_dict = get_model_dict(provider)
if model in model_dict.keys():
provider_rules.append(provider['provider'] + "/" + model)
provider_rules.append(provider["provider"] + "/" + model)

provider_list = []
# print("provider_rules", provider_rules)
for item in provider_rules:
for provider in config['providers']:
# print("provider", provider, provider['provider'] == item, item)
if "/" in item:
if provider['provider'] == item.split("/")[0]:
model_dict = get_model_dict(provider)
if model_name in model_dict.keys() and "/".join(item.split("/")[1:]) == model_name:
provider_list.append(provider)
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
elif provider['provider'] == item:
if provider['provider'] == item.split("/")[0]:
new_provider = copy.deepcopy(provider)
model_dict = get_model_dict(provider)
# print("model_dict", model_dict)
model_name_split = "/".join(item.split("/")[1:])
if model_name in model_dict.keys():
provider_list.append(provider)
else:
pass

# if provider['provider'] == item:
# if "/" in item:
# if item.split("/")[1] == model_name:
# provider_list.append(provider)
# else:
# model_dict = get_model_dict(provider)
# if model_name in model_dict.keys():
# provider_list.append(provider)
if "/" in item and model_name_split == model_name:
new_provider["model"] = [{model_dict[model_name]: model_name}]
# 如果 item 不包含 /,则直接匹配 provider,说明整个渠道所有模型都能用
provider_list.append(new_provider)

elif model_name.endswith("*") and "/" in item and model_name_split.startswith(model_name.rstrip("*")):
# old: new
new_provider["model"] = [{model_dict[model_name_split]: model_name}]
provider_list.append(new_provider)

# print("provider_list", provider_list)
return provider_list

async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
Expand Down
12 changes: 12 additions & 0 deletions test/test_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
a = [
{"a": 1, "b": 2, "c": 3},
{"a": 4, "b": 5, "c": 6},
{"a": 7, "b": 8, "c": 9}
]
import copy
for item in a:
new_item = copy.deepcopy(item)
new_item["a"] = 10
del new_item["b"]
# print(item)
print(a)

0 comments on commit 60014c4

Please sign in to comment.