Skip to content

Commit

Permalink
🤖 Models: Add support for o1-mini o1-preview model
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Sep 13, 2024
1 parent f4d6dda commit 1126d73
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
4 changes: 4 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest],
if "gemini" in provider['model'][request.model] and engine == "vertex":
engine = "vertex-gemini"

if "o1-preview" in provider['model'][request.model] or "o1-mini" in provider['model'][request.model]:
engine = "o1"
request.stream = False

if endpoint == "/v1/images/generations":
engine = "dalle"
request.stream = False
Expand Down
58 changes: 58 additions & 0 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,62 @@ async def get_cloudflare_payload(request, engine, provider):

return url, headers, payload

async def get_o1_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
if provider.get("api"):
headers['Authorization'] = f"Bearer {provider['api'].next()}"

url = provider['base_url']

messages = []
for msg in request.messages:
if isinstance(msg.content, list):
content = []
for item in msg.content:
if item.type == "text":
text_message = await get_text_message(msg.role, item.text, engine)
content.append(text_message)
else:
content = msg.content

if isinstance(content, list):
for item in content:
if item["type"] == "text":
messages.append({"role": msg.role, "content": item["text"]})
else:
messages.append({"role": msg.role, "content": content})

model = provider['model'][request.model]
payload = {
"model": model,
"messages": messages,
}

miss_fields = [
'model',
'messages',
'tools',
'tool_choice',
'temperature',
'top_p',
'max_tokens',
'presence_penalty',
'frequency_penalty',
'n',
'user',
'include_usage',
'logprobs',
'top_logprobs'
]

for field, value in request.model_dump(exclude_unset=True).items():
if field not in miss_fields and value is not None:
payload[field] = value

return url, headers, payload

async def gpt2claude_tools_json(json_dict):
import copy
json_dict = copy.deepcopy(json_dict)
Expand Down Expand Up @@ -929,6 +985,8 @@ async def get_payload(request: RequestModel, engine, provider):
return await get_openrouter_payload(request, engine, provider)
elif engine == "cloudflare":
return await get_cloudflare_payload(request, engine, provider)
elif engine == "o1":
return await get_o1_payload(request, engine, provider)
elif engine == "dalle":
return await get_dalle_payload(request, engine, provider)
else:
Expand Down
9 changes: 9 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def update_config(config_data):
async def load_config(app=None):
import yaml
try:
# with open('./api.yaml', 'r') as f:
# tokens = yaml.scan(f)
# for token in tokens:
# if isinstance(token, yaml.ScalarToken):
# value = token.value
# # 如果plain为False,表示字符串被引号包裹
# is_quoted = not token.plain
# print(f"值: {value}, 是否被引号包裹: {is_quoted}")

with open('./api.yaml', 'r') as f:
# 判断是否为空文件
conf = yaml.safe_load(f)
Expand Down

0 comments on commit 1126d73

Please sign in to comment.