From fc5eeace5cc526b6ddce6a0a7d336d865a11841c Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Fri, 13 Dec 2024 23:19:06 +0100 Subject: [PATCH] Send dimensions & user args to embedding models (#8090) If provided, send dimensions and user in embedding requests. --- edb/server/protocol/ai_ext.py | 47 +++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index badeabe85a6..fcc200eb906 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -524,6 +524,7 @@ class EmbeddingsParams(rs.Params[EmbeddingsData]): inputs: list[tuple[PendingEmbedding, str]] token_count: int shortening: Optional[int] + user: Optional[str] def costs(self) -> dict[str, int]: return { @@ -547,7 +548,8 @@ async def run(self) -> Optional[rs.Result[EmbeddingsData]]: self.params.model_name, [input[1] for input in self.params.inputs], self.params.shortening, - self.params.http_client + self.params.user, + self.params.http_client, ) result.pgconn = self.params.pgconn result.pending_entries = [ @@ -757,6 +759,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=batch_token_count, shortening=shortening, + user=None, http_client=http_client, )) @@ -774,6 +777,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=total_token_count, shortening=shortening, + user=None, http_client=http_client, )) @@ -983,6 +987,7 @@ async def _generate_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: task_name = _task_name.get() @@ -995,7 +1000,7 @@ async def _generate_embeddings( if provider.api_style == ApiStyle.OpenAI: return await _generate_openai_embeddings( - provider, model_name, inputs, shortening, http_client + provider, model_name, inputs, shortening, user, http_client ) else: raise RuntimeError( @@ -1009,6 +1014,7 @@ async def _generate_openai_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: @@ -1030,6 +1036,9 @@ async def _generate_openai_embeddings( if shortening is not None: params["dimensions"] = shortening + if user is not None: + params["user"] = user + result = await client.post( "/embeddings", json=params, @@ -2215,16 +2224,32 @@ async def _handle_embeddings_request( raise TypeError( 'the body of the request must be a JSON object') - inputs = body.get("input") + inputs = body.get("inputs") + input = body.get("input") + + if inputs is not None and input is not None: + raise TypeError( + "You cannot provide both 'inputs' and 'input'. " + "Please provide 'inputs'; 'input' has been deprecated." + ) + + if input is not None: + logger.warning("'input' is deprecated, use 'inputs' instead") + inputs = input + if not inputs: raise TypeError( - 'missing or empty required "input" value in request') + 'missing or empty required "inputs" value in request' + ) model_name = body.get("model") if not model_name: raise TypeError( 'missing or empty required "model" value in request') + shortening = body.get("dimensions") + user = body.get("user") + except Exception as ex: raise BadRequestError(str(ex)) from None @@ -2246,8 +2271,9 @@ async def _handle_embeddings_request( provider, model_name, inputs, - shortening=None, - http_client=tenant.get_http_client(originator="ai/embeddings") + shortening, + user, + http_client=tenant.get_http_client(originator="ai/embeddings"), ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) @@ -2507,8 +2533,13 @@ async def _generate_embeddings_for_type( else: shortening = None result = await _generate_embeddings( - provider, index["model"], [content], shortening=shortening, - http_client=http_client) + provider, + index["model"], + [content], + shortening, + None, + http_client, + ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) return result.data.embeddings