Skip to content

Commit

Permalink
✨ Feature: Add support for embeddings model
Browse files Browse the repository at this point in the history
  • Loading branch information
yym68686 committed Oct 24, 2024
1 parent d91f3fa commit c50b8cc
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
17 changes: 14 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi.exceptions import RequestValidationError

from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
from request import get_payload
from response import fetch_response, fetch_response_stream
from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
Expand Down Expand Up @@ -478,7 +478,7 @@ async def ensure_config(request: Request, call_next):
return await call_next(request)

# Update the success and failure counters inside the process_request function
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
url = provider['base_url']
parsed_url = urlparse(url)
# print("parsed_url", parsed_url)
Expand Down Expand Up @@ -529,6 +529,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
engine = "moderation"
request.stream = False

if endpoint == "/v1/embeddings":
engine = "embedding"
request.stream = False

if provider.get("engine"):
engine = provider["engine"]

Expand Down Expand Up @@ -700,7 +704,7 @@ def get_matching_providers(self, model_name, token):
# print("provider_list", provider_list)
return provider_list

async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
config = app.state.config
api_list = app.state.api_list
api_index = api_list.index(token)
Expand Down Expand Up @@ -904,6 +908,13 @@ async def images_generations(
):
return await model_handler.request_model(request, token, endpoint="/v1/images/generations")

@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
async def embeddings(
    request: EmbeddingRequest,
    token: str = Depends(verify_api_key)
):
    """Handle ``POST /v1/embeddings``.

    The validated request body and the token produced by ``verify_api_key``
    are passed to ``model_handler.request_model`` together with the endpoint
    path, and whatever that call returns is returned unchanged.
    """
    # The endpoint string is what process_request compares against to pick
    # the "embedding" engine.
    result = await model_handler.request_model(
        request,
        token,
        endpoint="/v1/embeddings",
    )
    return result

@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
async def moderations(
request: ModerationRequest,
Expand Down
11 changes: 10 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ class ImageGenerationRequest(BaseRequest):
size: Optional[str] = "1024x1024"
stream: bool = False

class EmbeddingRequest(BaseRequest):
    """Request body accepted by the ``/v1/embeddings`` route.

    ``input`` is forwarded as-is into the upstream payload by
    ``get_embedding_payload``, so any shape accepted here reaches the
    provider unchanged.
    """

    # Content to embed. OpenAI-compatible clients may send a single string,
    # a batch of strings, one pre-tokenized input (list of ints) or a batch
    # of pre-tokenized inputs (list of lists of ints). Restricting this to
    # ``str`` rejected every batched or tokenized request at validation time.
    # NOTE(review): builtin generics (``list[...]``) need Python 3.9+ --
    # confirm the deployment's interpreter version.
    input: Union[str, list[Union[str, int, list[int]]]]
    # Model name as requested by the client.
    model: str
    # Only added to the upstream payload when truthy (see get_embedding_payload).
    encoding_format: Optional[str] = "float"
    # process_request also sets this to False for the embeddings endpoint.
    stream: bool = False

class AudioTranscriptionRequest(BaseRequest):
file: Tuple[str, IOBase, str]
model: str
Expand All @@ -129,7 +135,7 @@ class ModerationRequest(BaseRequest):
stream: bool = False

class UnifiedRequest(BaseModel):
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest]
data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]

@model_validator(mode='before')
@classmethod
Expand All @@ -147,6 +153,9 @@ def set_request_type(cls, values):
elif "input" in values:
values["data"] = ModerationRequest(**values)
values["data"].request_type = "moderation"
elif "input" in values:
values["data"] = EmbeddingRequest(**values)
values["data"].request_type = "embedding"
else:
raise ValueError("无法确定请求类型")
return values
23 changes: 23 additions & 0 deletions request.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,27 @@ async def get_moderation_payload(request, engine, provider):

return url, headers, payload

async def get_embedding_payload(request, engine, provider):
    """Assemble the URL, headers and JSON body for an embeddings request.

    Args:
        request: Object exposing ``model``, ``input`` and ``encoding_format``.
        engine: Not used by this builder; kept so the signature matches the
            call made from ``get_payload``.
        provider: Provider configuration mapping; ``base_url`` and
            ``provider`` are read by key, and ``api`` via ``get``.

    Returns:
        A ``(url, headers, payload)`` tuple.

    Raises:
        KeyError: If ``request.model`` is not a key of the mapping returned
            by ``get_model_dict(provider)``.
    """
    # presumably maps the client-facing model name to the provider's own
    # model identifier -- verify against get_model_dict
    upstream_model = get_model_dict(provider)[request.model]

    headers = {"Content-Type": "application/json"}
    if provider.get("api"):
        # Take the next key from this provider's circular key list.
        api_key = await provider_api_circular_list[provider['provider']].next()
        headers['Authorization'] = f"Bearer {api_key}"

    url = BaseAPI(provider['base_url']).embeddings

    payload = {"input": request.input, "model": upstream_model}
    # Leave the field out entirely when it is empty or None.
    if request.encoding_format:
        payload["encoding_format"] = request.encoding_format

    return url, headers, payload

async def get_payload(request: RequestModel, engine, provider):
if engine == "gemini":
return await get_gemini_payload(request, engine, provider)
Expand All @@ -1150,5 +1171,7 @@ async def get_payload(request: RequestModel, engine, provider):
return await get_whisper_payload(request, engine, provider)
elif engine == "moderation":
return await get_moderation_payload(request, engine, provider)
elif engine == "embedding":
return await get_embedding_payload(request, engine, provider)
else:
raise ValueError("Unknown payload")
1 change: 1 addition & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def __init__(
self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/embeddings",) + ("",) * 3)

def safe_get(data, *keys, default=None):
for key in keys:
Expand Down

0 comments on commit c50b8cc

Please sign in to comment.