diff --git a/main.py b/main.py
index 22ac3a7..eddbbf8 100644
--- a/main.py
+++ b/main.py
@@ -16,7 +16,7 @@
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.exceptions import RequestValidationError
 
-from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest
+from models import RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, UnifiedRequest, EmbeddingRequest
 from request import get_payload
 from response import fetch_response, fetch_response_stream
 from utils import error_handling_wrapper, post_all_models, load_config, safe_get, circular_list_encoder, get_model_dict, save_api_yaml
@@ -478,7 +478,7 @@ async def ensure_config(request: Request, call_next):
     return await call_next(request)
 
 # 在 process_request 函数中更新成功和失败计数
-async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], provider: Dict, endpoint=None, token=None):
+async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
     url = provider['base_url']
     parsed_url = urlparse(url)
     # print("parsed_url", parsed_url)
@@ -529,6 +529,10 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
         engine = "moderation"
         request.stream = False
 
+    if endpoint == "/v1/embeddings":
+        engine = "embedding"
+        request.stream = False
+
     if provider.get("engine"):
         engine = provider["engine"]
 
@@ -700,7 +704,7 @@ def get_matching_providers(self, model_name, token):
         # print("provider_list", provider_list)
         return provider_list
 
-    async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest], token: str, endpoint=None):
+    async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
         config = app.state.config
         api_list = app.state.api_list
         api_index = api_list.index(token)
@@ -904,6 +908,14 @@ async def images_generations(
 ):
     return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
 
+# Embeddings route: hands the parsed body to the shared model handler, tagged with its endpoint path.
+@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
+async def embeddings(
+    request: EmbeddingRequest,
+    token: str = Depends(verify_api_key)
+):
+    return await model_handler.request_model(request, token, endpoint="/v1/embeddings")
+
 @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
 async def moderations(
     request: ModerationRequest,
diff --git a/models.py b/models.py
index 26842ad..c8d1777 100644
--- a/models.py
+++ b/models.py
@@ -111,6 +111,16 @@ class ImageGenerationRequest(BaseRequest):
     size: Optional[str] = "1024x1024"
     stream: bool = False
 
+# Request body for the /v1/embeddings endpoint.
+class EmbeddingRequest(BaseRequest):
+    # NOTE(review): the OpenAI embeddings API also accepts arrays for `input`;
+    # only a single string is modelled here - confirm whether that is intended.
+    input: str
+    model: str
+    encoding_format: Optional[str] = "float"
+    # process_request always sets this to False for the embeddings endpoint.
+    stream: bool = False
+
 class AudioTranscriptionRequest(BaseRequest):
     file: Tuple[str, IOBase, str]
     model: str
@@ -129,7 +139,7 @@ class ModerationRequest(BaseRequest):
     stream: bool = False
 
 class UnifiedRequest(BaseModel):
-    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest]
+    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest]
 
     @model_validator(mode='before')
     @classmethod
@@ -147,6 +157,13 @@ def set_request_type(cls, values):
+            elif "input" in values and "encoding_format" in values:
+                # Embedding and moderation bodies both carry "input", so this
+                # check must come before the moderation branch to be reachable.
+                # NOTE(review): bodies without "encoding_format" still resolve
+                # to moderation - confirm this heuristic is acceptable.
+                values["data"] = EmbeddingRequest(**values)
+                values["data"].request_type = "embedding"
             elif "input" in values:
                 values["data"] = ModerationRequest(**values)
                 values["data"].request_type = "moderation"
             else:
                 raise ValueError("无法确定请求类型")
         return values
\ No newline at end of file
diff --git a/request.py b/request.py
index a2f371e..241c234 100644
--- a/request.py
+++ b/request.py
@@ -1125,6 +1125,34 @@ async def get_moderation_payload(request, engine, provider):
 
     return url, headers, payload
 
+async def get_embedding_payload(request, engine, provider):
+    """Build the upstream URL, headers and JSON payload for an embeddings request.
+
+    Returns a (url, headers, payload) tuple. ``engine`` is not used here; the
+    signature matches the other ``get_*_payload`` builders called by get_payload.
+    """
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
+    headers = {
+        "Content-Type": "application/json",
+    }
+    if provider.get("api"):
+        # The bearer token is the next value from this provider's circular list.
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+    url = provider['base_url']
+    url = BaseAPI(url).embeddings
+
+    payload = {
+        "input": request.input,
+        "model": model,
+    }
+
+    # Left out of the payload when falsy (e.g. explicitly null or empty).
+    if request.encoding_format:
+        payload["encoding_format"] = request.encoding_format
+
+    return url, headers, payload
+
 async def get_payload(request: RequestModel, engine, provider):
     if engine == "gemini":
         return await get_gemini_payload(request, engine, provider)
@@ -1150,5 +1178,7 @@ async def get_payload(request: RequestModel, engine, provider):
         return await get_whisper_payload(request, engine, provider)
     elif engine == "moderation":
         return await get_moderation_payload(request, engine, provider)
+    elif engine == "embedding":
+        return await get_embedding_payload(request, engine, provider)
     else:
         raise ValueError("Unknown payload")
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 6bcac08..371b8a0 100644
--- a/utils.py
+++ b/utils.py
@@ -377,6 +377,7 @@ def __init__(
         self.image_url: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/images/generations",) + ("",) * 3)
         self.audio_transcriptions: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/audio/transcriptions",) + ("",) * 3)
         self.moderations: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/moderations",) + ("",) * 3)
+        self.embeddings: str = urlunparse(parsed_url[:2] + (before_v1 + "/v1/embeddings",) + ("",) * 3)
 
 def safe_get(data, *keys, default=None):
     for key in keys: