Skip to content

Commit

Permalink
Fix the bug of model matching format error
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Jul 10, 2024
1 parent b8a7df8 commit 783c658
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时的代码
app.state.client = httpx.AsyncClient()
timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0)
app.state.client = httpx.AsyncClient(timeout=timeout)
yield
# 关闭时的代码
await app.state.client.aclose()
Expand All @@ -35,27 +36,28 @@ async def lifespan(app: FastAPI):
def load_config():
try:
with open('api.yaml', 'r') as f:
return yaml.safe_load(f)
conf = yaml.safe_load(f)
for index, provider in enumerate(conf['providers']):
model_dict = {}
for model in provider['model']:
if type(model) == str:
model_dict[model] = model
if type(model) == dict:
model_dict.update({value: key for key, value in model.items()})
provider['model'] = model_dict
conf['providers'][index] = provider
api_keys_db = conf['api_keys']
api_list = [item["api"] for item in api_keys_db]
print(json.dumps(conf, indent=4, ensure_ascii=False))
return conf, api_keys_db, api_list
except FileNotFoundError:
print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。")
return []
except yaml.YAMLError:
print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。")
return []

config = load_config()
for index, provider in enumerate(config['providers']):
model_dict = {}
for model in provider['model']:
if type(model) == str:
model_dict[model] = model
if type(model) == dict:
model_dict.update({value: key for key, value in model.items()})
provider['model'] = model_dict
config['providers'][index] = provider
api_keys_db = config['api_keys']
api_list = [item["api"] for item in api_keys_db]
print(json.dumps(config, indent=4, ensure_ascii=False))
config, api_keys_db, api_list = load_config()

async def process_request(request: RequestModel, provider: Dict):
print("provider: ", provider['provider'])
Expand Down Expand Up @@ -102,7 +104,10 @@ def get_matching_providers(self, model_name, token):
if "/" in model:
provider_name = model.split("/")[0]
model = model.split("/")[1]
if (model and model_name == model) or (model == "*"):
for provider in config['providers']:
if provider['provider'] == provider_name:
models_list = provider['model'].keys()
if (model and model_name == model) or (model == "*" and model_name in models_list):
provider_rules.append(provider_name)
provider_list = []
for provider in config['providers']:
Expand Down Expand Up @@ -250,6 +255,11 @@ def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(32)
return {"api_key": api_key}

async def on_fetch(request, env):
import asgi

return await asgi.fetch(app, request, env)

if __name__ == '__main__':
import uvicorn
uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True)

0 comments on commit 783c658

Please sign in to comment.