From 783c658f9feaf3b2862d42fd15735beeb1fdf923 Mon Sep 17 00:00:00 2001 From: yym68686 Date: Wed, 10 Jul 2024 14:58:03 +0800 Subject: [PATCH] Fix the bug of model matching format error --- main.py | 42 ++++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index 2ce7914..5276785 100644 --- a/main.py +++ b/main.py @@ -21,7 +21,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): # 启动时的代码 - app.state.client = httpx.AsyncClient() + timeout = httpx.Timeout(connect=10.0, read=30.0, write=30.0, pool=30.0) + app.state.client = httpx.AsyncClient(timeout=timeout) yield # 关闭时的代码 await app.state.client.aclose() @@ -35,7 +36,20 @@ async def lifespan(app: FastAPI): def load_config(): try: with open('api.yaml', 'r') as f: - return yaml.safe_load(f) + conf = yaml.safe_load(f) + for index, provider in enumerate(conf['providers']): + model_dict = {} + for model in provider['model']: + if type(model) == str: + model_dict[model] = model + if type(model) == dict: + model_dict.update({value: key for key, value in model.items()}) + provider['model'] = model_dict + conf['providers'][index] = provider + api_keys_db = conf['api_keys'] + api_list = [item["api"] for item in api_keys_db] + print(json.dumps(conf, indent=4, ensure_ascii=False)) + return conf, api_keys_db, api_list except FileNotFoundError: print("配置文件 'config.yaml' 未找到。请确保文件存在于正确的位置。") return [] @@ -43,19 +57,7 @@ def load_config(): print("配置文件 'config.yaml' 格式不正确。请检查YAML格式。") return [] -config = load_config() -for index, provider in enumerate(config['providers']): - model_dict = {} - for model in provider['model']: - if type(model) == str: - model_dict[model] = model - if type(model) == dict: - model_dict.update({value: key for key, value in model.items()}) - provider['model'] = model_dict - config['providers'][index] = provider -api_keys_db = config['api_keys'] -api_list = [item["api"] for item in api_keys_db] -print(json.dumps(config, indent=4, ensure_ascii=False)) +config, api_keys_db, api_list = load_config() async def process_request(request: RequestModel, provider: Dict): print("provider: ", provider['provider']) @@ -102,7 +104,10 @@ def get_matching_providers(self, model_name, token): if "/" in model: provider_name = model.split("/")[0] model = model.split("/")[1] - if (model and model_name == model) or (model == "*"): + for provider in config['providers']: + if provider['provider'] == provider_name: + models_list = provider['model'].keys() + if (model and model_name == model) or (model == "*" and model_name in models_list): provider_rules.append(provider_name) provider_list = [] for provider in config['providers']: @@ -250,6 +255,11 @@ def generate_api_key(): api_key = "sk-" + secrets.token_urlsafe(32) return {"api_key": api_key} +async def on_fetch(request, env): + import asgi + + return await asgi.fetch(app, request, env) + if __name__ == '__main__': import uvicorn uvicorn.run("__main__:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file