diff --git a/main.py b/main.py index 31cf49e..f801002 100644 --- a/main.py +++ b/main.py @@ -201,6 +201,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest], if "gemini" in provider['model'][request.model] and engine == "vertex": engine = "vertex-gemini" + if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]: + engine = "o1" + request.stream = False + if endpoint == "/v1/images/generations": engine = "dalle" request.stream = False diff --git a/request.py b/request.py index 92541ad..788fd50 100644 --- a/request.py +++ b/request.py @@ -737,6 +737,62 @@ async def get_cloudflare_payload(request, engine, provider): return url, headers, payload +async def get_o1_payload(request, engine, provider): + headers = { + 'Content-Type': 'application/json' + } + if provider.get("api"): + headers['Authorization'] = f"Bearer {provider['api'].next()}" + + url = provider['base_url'] + + messages = [] + for msg in request.messages: + if isinstance(msg.content, list): + content = [] + for item in msg.content: + if item.type == "text": + text_message = await get_text_message(msg.role, item.text, engine) + content.append(text_message) + else: + content = msg.content + + if isinstance(content, list): + for item in content: + if item["type"] == "text": + messages.append({"role": msg.role, "content": item["text"]}) + else: + messages.append({"role": msg.role, "content": content}) + + model = provider['model'][request.model] + payload = { + "model": model, + "messages": messages, + } + + miss_fields = [ + 'model', + 'messages', + 'tools', + 'tool_choice', + 'temperature', + 'top_p', + 'max_tokens', + 'presence_penalty', + 'frequency_penalty', + 'n', + 'user', + 'include_usage', + 'logprobs', + 'top_logprobs' + ] + + for field, value in request.model_dump(exclude_unset=True).items(): + if field not in miss_fields and value is not None: + payload[field] = value + + return url, headers, payload + async def gpt2claude_tools_json(json_dict): import copy json_dict = copy.deepcopy(json_dict) @@ -929,6 +985,8 @@ async def get_payload(request: RequestModel, engine, provider): return await get_openrouter_payload(request, engine, provider) elif engine == "cloudflare": return await get_cloudflare_payload(request, engine, provider) + elif engine == "o1": + return await get_o1_payload(request, engine, provider) elif engine == "dalle": return await get_dalle_payload(request, engine, provider) else: diff --git a/utils.py b/utils.py index 2882f9e..a62d9bc 100644 --- a/utils.py +++ b/utils.py @@ -53,6 +53,15 @@ def update_config(config_data): async def load_config(app=None): import yaml try: + # with open('./api.yaml', 'r') as f: + # tokens = yaml.scan(f) + # for token in tokens: + # if isinstance(token, yaml.ScalarToken): + # value = token.value + # # 如果plain为False,表示字符串被引号包裹 + # is_quoted = not token.plain + # print(f"值: {value}, 是否被引号包裹: {is_quoted}") + with open('./api.yaml', 'r') as f: # 判断是否为空文件 conf = yaml.safe_load(f)