Skip to content

Commit

Permalink
✨ Feature: 1. Support o1 model streaming output
Browse files Browse the repository at this point in the history
2. Support felo o1 model reverse API
  • Loading branch information
yym68686 committed Nov 26, 2024
1 parent 7534fec commit 5bec9fd
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 65 deletions.
5 changes: 1 addition & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,7 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A

if "claude" not in original_model \
and "gpt" not in original_model \
and "o1" not in original_model \
and "gemini" not in original_model \
and parsed_url.netloc != 'api.cloudflare.com' \
and parsed_url.netloc != 'api.cohere.com':
Expand All @@ -845,10 +846,6 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
if "gemini" in original_model and engine == "vertex":
engine = "vertex-gemini"

if "o1-preview" in original_model or "o1-mini" in original_model:
engine = "o1"
request.stream = False

if endpoint == "/v1/images/generations":
engine = "dalle"
request.stream = False
Expand Down
63 changes: 2 additions & 61 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ async def get_gpt_payload(request, engine, provider):
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
if provider['base_url'] == "https://api-ext.felo.ai/one-ai/completions":
if provider['base_url'] == "https://api-ext.felo.ai/one-ai/completions" or provider['base_url'] == "https://api-ext.felo.ai/trail/v1/chat/completions":
headers['Authorization'] = f"{await provider_api_circular_list[provider['provider']].next(model)}"
else:
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
Expand Down Expand Up @@ -679,7 +679,7 @@ async def get_gpt_payload(request, engine, provider):
if field not in miss_fields and value is not None:
payload[field] = value

if provider.get("tools") == False:
if provider.get("tools") == False or "o1" in model:
payload.pop("tools", None)
payload.pop("tool_choice", None)

Expand Down Expand Up @@ -869,63 +869,6 @@ async def get_cloudflare_payload(request, engine, provider):

return url, headers, payload

async def get_o1_payload(request, engine, provider):
headers = {
'Content-Type': 'application/json'
}
model_dict = get_model_dict(provider)
model = model_dict[request.model]
if provider.get("api"):
headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

url = provider['base_url']

messages = []
for msg in request.messages:
if isinstance(msg.content, list):
content = []
for item in msg.content:
if item.type == "text":
text_message = await get_text_message(msg.role, item.text, engine)
content.append(text_message)
else:
content = msg.content

if isinstance(content, list) and msg.role != "system":
for item in content:
if item["type"] == "text":
messages.append({"role": msg.role, "content": item["text"]})
elif msg.role != "system":
messages.append({"role": msg.role, "content": content})

payload = {
"model": model,
"messages": messages,
}

miss_fields = [
'model',
'messages',
'tools',
'tool_choice',
'temperature',
'top_p',
'max_tokens',
'presence_penalty',
'frequency_penalty',
'n',
'user',
'include_usage',
'logprobs',
'top_logprobs'
]

for field, value in request.model_dump(exclude_unset=True).items():
if field not in miss_fields and value is not None:
payload[field] = value

return url, headers, payload

async def gpt2claude_tools_json(json_dict):
import copy
json_dict = copy.deepcopy(json_dict)
Expand Down Expand Up @@ -1213,8 +1156,6 @@ async def get_payload(request: RequestModel, engine, provider):
return await get_openrouter_payload(request, engine, provider)
elif engine == "cloudflare":
return await get_cloudflare_payload(request, engine, provider)
elif engine == "o1":
return await get_o1_payload(request, engine, provider)
elif engine == "cohere":
return await get_cohere_payload(request, engine, provider)
elif engine == "dalle":
Expand Down

0 comments on commit 5bec9fd

Please sign in to comment.