Add feature: support OpenAI dall-e-3 image generation
yym68686 committed Sep 2, 2024
1 parent 46ce910 commit 2ec0842
Showing 6 changed files with 124 additions and 60 deletions.
23 changes: 14 additions & 9 deletions README.md
@@ -12,20 +12,23 @@

## Introduction

This is a project for unified management of large language model APIs: it lets you call multiple backend services through a single unified API interface, converts everything to the OpenAI format, and supports load balancing. Currently supported backend services include OpenAI, Anthropic, Gemini, Vertex, DeepBricks, OpenRouter, and more.
For personal use, one/new-api is overly complex, with many commercial features an individual never needs. If you don't want a complicated frontend and would like broader model support, give uni-api a try. This is a project for unified management of large language model APIs: it lets you call multiple backend services through a single unified API interface, converts everything to the OpenAI format, and supports load balancing. Currently supported backend services include OpenAI, Anthropic, Gemini, Vertex, DeepBricks, OpenRouter, and more.

## Features

- Unified management of multiple backend services
- Load balancing support
- Function calling for OpenAI, Anthropic, Gemini, and Vertex
- Support for multiple models
- Support for multiple API keys
- Vertex regional load balancing and high Vertex concurrency
- No frontend; API channels are configured purely through a config file. Writing a single file is enough to run your own API gateway, and the documentation includes a detailed, beginner-friendly configuration guide.
- Unified management of multiple backend services, supporting OpenAI, Deepseek, DeepBricks, OpenRouter, and other providers whose APIs follow the OpenAI format. Supports OpenAI dall-e-3 image generation.
- Also supports the Anthropic, Gemini, and Vertex APIs; Vertex supports both the Claude and Gemini APIs.
- Native tool use (function calling) for OpenAI, Anthropic, Gemini, and Vertex.
- Native image-recognition APIs for OpenAI, Anthropic, Gemini, and Vertex.
- Load balancing, including Vertex regional load balancing and high Vertex concurrency: Gemini and Claude concurrency can be raised by up to (number of API keys × number of regions). Beyond Vertex regional load balancing, every API supports channel-level load balancing, improving the immersive-translation experience.
- Automatic retry: when one API channel fails to respond, the next one is tried automatically.
- Fine-grained permission control: wildcards can be used to restrict an API key to specific models of specific channels.
- Support for multiple API keys.

## Configuration

Using the api.yaml configuration file, you can configure multiple models, and each model can be configured with multiple backend services, with load balancing support. Below is an example api.yaml configuration file:

```yaml
providers:
@@ -35,6 +38,7 @@ providers:
    model: # fill in at least one model
      - gpt-4o # model name that can be used, required
      - claude-3-5-sonnet-20240620: claude-3-5-sonnet # rename a model: claude-3-5-sonnet-20240620 is the provider's model name and claude-3-5-sonnet is the new name, so a concise name can replace the original longer one, optional
      - dall-e-3

  - provider: anthropic
    base_url: https://api.anthropic.com/v1/messages
@@ -86,7 +90,7 @@ api_keys:
    model:
      - anthropic/claude-3-5-sonnet # model name that can be used; only the claude-3-5-sonnet model from the provider named anthropic is allowed, not the claude-3-5-sonnet of any other provider
    preferences:
      USE_ROUND_ROBIN: true # whether to use round-robin load balancing; true enables it, false disables it, default true
      USE_ROUND_ROBIN: true # whether to use round-robin load balancing; true enables it, false disables it, default true. With round robin enabled, each request tries models in the order configured under model, independent of the original channel order under providers, so every API key can have its own request order (see the sketch after this example)
      AUTO_RETRY: true # whether to retry automatically, moving on to the next provider on failure; true enables it, false disables it, default true
```
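As a reading aid (not part of the commit), here is a minimal sketch of the rotation that USE_ROUND_ROBIN and AUTO_RETRY control; the channel names are hypothetical, and the real logic lives in `ModelRequestHandler.try_all_providers` in main.py:

```python
# Minimal sketch: round-robin start index plus fall-through retry order.
# Hypothetical channel names; see try_all_providers for the real loop.
providers = ["openai", "anthropic", "vertex"]  # order from the key's model list
last_index = -1  # index of the channel used by the previous request

def pick_order(use_round_robin: bool) -> list[str]:
    start = last_index + 1 if use_round_robin else 0
    return [providers[(start + i) % len(providers)] for i in range(len(providers))]

# With USE_ROUND_ROBIN: true each request starts one channel later;
# with AUTO_RETRY: true a failed channel falls through to the next in line.
print(pick_order(True))   # first request: ['openai', 'anthropic', 'vertex']
print(pick_order(False))  # always starts from the first channel
```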
@@ -152,6 +156,7 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
-d '{"model": "gpt-4o","messages": [{"role": "user", "content": "Hello"}],"stream": true}'
```
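The new endpoint added by this commit can be exercised the same way; a hedged Python sketch (assuming a local uni-api instance on 127.0.0.1:8000 and a key configured in api.yaml):

```python
import httpx

# Sketch only: the URL, key, and prompt are placeholders; the request body
# follows the OpenAI images API shape accepted by ImageGenerationRequest.
response = httpx.post(
    "http://127.0.0.1:8000/v1/images/generations",
    headers={"Authorization": "Bearer sk-..."},
    json={"model": "dall-e-3", "prompt": "a lighthouse at dusk", "n": 1, "size": "1024x1024"},
)
print(response.json())
```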


## Star History

<a href="https://github.com/yym68686/uni-api/stargazers">
31 changes: 21 additions & 10 deletions main.py
@@ -5,16 +5,16 @@
from contextlib import asynccontextmanager

from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi import FastAPI, HTTPException, Depends
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

from models import RequestModel
from models import RequestModel, ImageGenerationRequest
from utils import error_handling_wrapper, get_all_models, post_all_models, load_config
from request import get_payload
from response import fetch_response, fetch_response_stream

from typing import List, Dict
from typing import List, Dict, Union
from urllib.parse import urlparse

@asynccontextmanager
@@ -80,7 +80,7 @@ async def lifespan(app: FastAPI):
allow_headers=["*"], # allow all header fields
)

async def process_request(request: RequestModel, provider: Dict):
async def process_request(request: Union[RequestModel, ImageGenerationRequest], provider: Dict, endpoint=None):
url = provider['base_url']
parsed_url = urlparse(url)
# print(parsed_url)
@@ -101,6 +101,10 @@ async def process_request(request: RequestModel, provider: Dict):
and "gemini" not in provider['model'][request.model]:
engine = "openrouter"

    if endpoint == "/v1/images/generations":
        engine = "dalle"
        request.stream = False

if provider.get("engine"):
engine = provider["engine"]

@@ -122,7 +126,7 @@
wrapped_generator = await error_handling_wrapper(generator, status_code=500)
return StreamingResponse(wrapped_generator, media_type="text/event-stream")
else:
return await fetch_response(app.state.client, url, headers, payload)
return await anext(fetch_response(app.state.client, url, headers, payload))

import asyncio
class ModelRequestHandler:
@@ -171,7 +175,7 @@ def get_matching_providers(self, model_name, token):
# print(json.dumps(provider, indent=4, ensure_ascii=False))
return provider_list

async def request_model(self, request: RequestModel, token: str):
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest], token: str, endpoint=None):
config = app.state.config
# api_keys_db = app.state.api_keys_db
api_list = app.state.api_list
@@ -193,17 +197,17 @@ async def request_model(self, request: RequestModel, token: str):
if config['api_keys'][api_index]["preferences"].get("AUTO_RETRY") == False:
auto_retry = False

return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry)
return await self.try_all_providers(request, matching_providers, use_round_robin, auto_retry, endpoint)

async def try_all_providers(self, request: RequestModel, providers: List[Dict], use_round_robin: bool, auto_retry: bool):
async def try_all_providers(self, request: Union[RequestModel, ImageGenerationRequest], providers: List[Dict], use_round_robin: bool, auto_retry: bool, endpoint: str = None):
num_providers = len(providers)
start_index = self.last_provider_index + 1 if use_round_robin else 0

for i in range(num_providers + 1):
self.last_provider_index = (start_index + i) % num_providers
provider = providers[self.last_provider_index]
try:
response = await process_request(request, provider)
response = await process_request(request, provider, endpoint)
return response
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError) as e:
logger.error(f"Error with provider {provider['provider']}: {str(e)}")
@@ -228,7 +232,7 @@ def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)
return token

@app.post("/v1/chat/completions")
async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
async def request_model(request: Union[RequestModel, ImageGenerationRequest], token: str = Depends(verify_api_key)):
return await model_handler.request_model(request, token)

@app.options("/v1/chat/completions")
@@ -251,6 +255,13 @@ async def list_models():
"data": models
})

@app.post("/v1/images/generations")
async def images_generations(
request: ImageGenerationRequest,
token: str = Depends(verify_api_key)
):
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")

@app.get("/generate-api-key")
def generate_api_key():
api_key = "sk-" + secrets.token_urlsafe(32)
7 changes: 7 additions & 0 deletions models.py
@@ -1,6 +1,13 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union

class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    n: int
    size: str
    stream: bool = False

class FunctionParameter(BaseModel):
    type: str
    properties: Dict[str, Dict[str, str]]
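For illustration only (not part of the commit), the new model validates an OpenAI-style image request before it is routed; a minimal self-contained sketch:

```python
from pydantic import BaseModel

class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    n: int
    size: str
    stream: bool = False  # image generation is forced non-streaming downstream

req = ImageGenerationRequest(
    model="dall-e-3", prompt="a lighthouse at dusk", n=1, size="1024x1024"
)
print(req.stream)  # False
```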
23 changes: 22 additions & 1 deletion request.py
@@ -1,6 +1,6 @@
import json
from models import RequestModel
from utils import c35s, c3s, c3o, c3h, gem, CircularList
from utils import c35s, c3s, c3o, c3h, gem, BaseAPI

async def get_image_message(base64_image, engine = None):
if "gpt" == engine:
@@ -748,6 +748,25 @@ async def get_claude_payload(request, engine, provider):

return url, headers, payload

async def get_dalle_payload(request, engine, provider):
    model = provider['model'][request.model]
    headers = {
        "Content-Type": "application/json",
    }
    if provider.get("api"):
        headers['Authorization'] = f"Bearer {provider['api']}"
    url = provider['base_url']
    url = BaseAPI(url).image_url

    payload = {
        "model": model,
        "prompt": request.prompt,
        "n": request.n,
        "size": request.size
    }

    return url, headers, payload
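As an illustrative trace (assumptions only, not part of the commit): for a provider whose base_url is https://api.openai.com/v1/chat/completions and whose model map contains dall-e-3, the helper would return roughly:

```python
# Illustrative values only; the api key and prompt are placeholders.
url = "https://api.openai.com/v1/images/generations"  # via BaseAPI(url).image_url
headers = {"Content-Type": "application/json", "Authorization": "Bearer sk-..."}
payload = {"model": "dall-e-3", "prompt": "a lighthouse at dusk", "n": 1, "size": "1024x1024"}
```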

async def get_payload(request: RequestModel, engine, provider):
if engine == "gemini":
return await get_gemini_payload(request, engine, provider)
@@ -761,5 +780,7 @@ async def get_payload(request: RequestModel, engine, provider):
return await get_gpt_payload(request, engine, provider)
elif engine == "openrouter":
return await get_openrouter_payload(request, engine, provider)
elif engine == "dalle":
return await get_dalle_payload(request, engine, provider)
else:
raise ValueError("Unknown payload")
72 changes: 33 additions & 39 deletions response.py
@@ -36,17 +36,24 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f

return sse_response

async def check_response(response, error_log):
    if response.status_code != 200:
        error_message = await response.aread()
        error_str = error_message.decode('utf-8', errors='replace')
        try:
            error_json = json.loads(error_str)
        except json.JSONDecodeError:
            error_json = error_str
        return {"error": f"{error_log} HTTP Error {response.status_code}", "details": error_json}
    return None

async def fetch_gemini_response_stream(client, url, headers, payload, model):
timestamp = datetime.timestamp(datetime.now())
async with client.stream('POST', url, headers=headers, json=payload) as response:
if response.status_code != 200:
error_message = await response.aread()
error_str = error_message.decode('utf-8', errors='replace')
try:
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
error_message = await check_response(response, "fetch_gemini_response_stream")
if error_message:
yield error_message
return
buffer = ""
revicing_function_call = False
function_full_response = "{"
@@ -87,14 +94,11 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
timestamp = datetime.timestamp(datetime.now())
async with client.stream('POST', url, headers=headers, json=payload) as response:
if response.status_code != 200:
error_message = await response.aread()
error_str = error_message.decode('utf-8', errors='replace')
try:
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
error_message = await check_response(response, "fetch_vertex_claude_response_stream")
if error_message:
yield error_message
return

buffer = ""
revicing_function_call = False
function_full_response = "{"
@@ -138,14 +142,9 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
while redirect_count < max_redirects:
# logger.info(f"fetch_gpt_response_stream: {url}")
async with client.stream('POST', url, headers=headers, json=payload) as response:
if response.status_code != 200:
error_message = await response.aread()
error_str = error_message.decode('utf-8', errors='replace')
try:
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
error_message = await check_response(response, "fetch_gpt_response_stream")
if error_message:
yield error_message
return

buffer = ""
@@ -185,14 +184,10 @@ async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects
async def fetch_claude_response_stream(client, url, headers, payload, model):
timestamp = datetime.timestamp(datetime.now())
async with client.stream('POST', url, headers=headers, json=payload) as response:
if response.status_code != 200:
error_message = await response.aread()
error_str = error_message.decode('utf-8', errors='replace')
try:
error_json = json.loads(error_str)
except json.JSONDecodeError:
error_json = error_str
yield {"error": f"fetch_claude_response_stream HTTP Error {response.status_code}", "details": error_json}
error_message = await check_response(response, "fetch_claude_response_stream")
if error_message:
yield error_message
return
buffer = ""
async for chunk in response.aiter_text():
# logger.info(f"chunk: {repr(chunk)}")
@@ -241,13 +236,12 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
yield sse_string

async def fetch_response(client, url, headers, payload):
    try:
        response = await client.post(url, headers=headers, json=payload)
        return response.json()
    except httpx.ConnectError as e:
        return {"error": f"500", "details": "fetch_response Connect Error"}
    except httpx.ReadTimeout as e:
        return {"error": f"500", "details": "fetch_response Read Response Timeout"}
    response = await client.post(url, headers=headers, json=payload)
    error_message = await check_response(response, "fetch_response")
    if error_message:
        yield error_message
        return
    yield response.json()
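fetch_response is now an async generator rather than a plain coroutine, which is why main.py consumes it with `await anext(fetch_response(...))`. A minimal sketch of that pattern (assuming Python 3.10+, where the anext() builtin exists; the names here are hypothetical stand-ins):

```python
import asyncio

async def single_result():
    # Stand-in for the rewritten fetch_response: yields exactly one item.
    yield {"ok": True}

async def main():
    # anext() pulls the one yielded value, mirroring main.py's
    # `return await anext(fetch_response(...))`.
    print(await anext(single_result()))

asyncio.run(main())
```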

async def fetch_response_stream(client, url, headers, payload, engine, model):
try:
28 changes: 27 additions & 1 deletion utils.py
@@ -222,4 +222,30 @@ def next(self):
c3s = CircularList(["us-east5", "us-central1", "asia-southeast1"])
c3o = CircularList(["us-east5"])
c3h = CircularList(["us-east5", "us-central1", "europe-west1", "europe-west4"])
gem = CircularList(["us-central1", "us-east4", "us-west1", "us-west4", "europe-west1", "europe-west2"])

class BaseAPI:
    def __init__(
        self,
        api_url: str = "https://api.openai.com/v1/chat/completions",
    ):
        if api_url == "":
            api_url = "https://api.openai.com/v1/chat/completions"
        self.source_api_url: str = api_url
        from urllib.parse import urlparse, urlunparse
        parsed_url = urlparse(self.source_api_url)
        if parsed_url.scheme == "":
            raise Exception("Error: API_URL is not set")
        if parsed_url.path != '/':
            before_v1 = parsed_url.path.split("/v1")[0]
        else:
            before_v1 = ""
        self.base_url: str = urlunparse(parsed_url[:2] + (before_v1,) + ("",) * 3)
        self.v1_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1",) + ("",) * 3)
        self.v1_models: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/models",) + ("",) * 3)
        if parsed_url.netloc == "api.deepseek.com":
            self.chat_url: str = urlunparse(parsed_url[:2] + ("/chat/completions",) + ("",) * 3)
        else:
            self.chat_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/chat/completions",) + ("",) * 3)
        self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
        self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
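A quick illustrative use of the helper (assumptions only, not part of the commit), showing how one configured base_url fans out into the endpoints above:

```python
# Illustrative: derive endpoint URLs from a single configured base_url,
# using the BaseAPI class defined above.
api = BaseAPI("https://api.openai.com/v1/chat/completions")
print(api.base_url)   # https://api.openai.com
print(api.chat_url)   # https://api.openai.com/v1/chat/completions
print(api.image_url)  # https://api.openai.com/v1/images/generations
```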
