From 5de3f37ed02da1ddeb8a7c2f8bc1ac3a85e34ac3 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Tue, 10 Dec 2024 09:37:09 +0100 Subject: [PATCH 1/7] Send dimensions & user args to embedding models --- edb/server/protocol/ai_ext.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index badeabe85a6..8ec2d11adf0 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -524,6 +524,7 @@ class EmbeddingsParams(rs.Params[EmbeddingsData]): inputs: list[tuple[PendingEmbedding, str]] token_count: int shortening: Optional[int] + user: Optional[str] def costs(self) -> dict[str, int]: return { @@ -547,7 +548,8 @@ async def run(self) -> Optional[rs.Result[EmbeddingsData]]: self.params.model_name, [input[1] for input in self.params.inputs], self.params.shortening, - self.params.http_client + self.params.user, + self.params.http_client, ) result.pgconn = self.params.pgconn result.pending_entries = [ @@ -983,6 +985,7 @@ async def _generate_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: task_name = _task_name.get() @@ -995,7 +998,7 @@ async def _generate_embeddings( if provider.api_style == ApiStyle.OpenAI: return await _generate_openai_embeddings( - provider, model_name, inputs, shortening, http_client + provider, model_name, inputs, shortening, user, http_client ) else: raise RuntimeError( @@ -1009,6 +1012,7 @@ async def _generate_openai_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], + user: Optional[int], http_client: http.HttpClient, ) -> EmbeddingsResult: @@ -1030,6 +1034,9 @@ async def _generate_openai_embeddings( if shortening is not None: params["dimensions"] = shortening + if user is not None: + params["user"] = user + result = await client.post( "/embeddings", json=params, @@ -2225,6 +2232,9 @@ async def _handle_embeddings_request( raise TypeError( 'missing or empty required "model" value in request') + shortening = body.get("dimensions") + user = body.get("user") + except Exception as ex: raise BadRequestError(str(ex)) from None @@ -2246,8 +2256,9 @@ async def _handle_embeddings_request( provider, model_name, inputs, - shortening=None, - http_client=tenant.get_http_client(originator="ai/embeddings") + shortening, + user, + http_client=tenant.get_http_client(originator="ai/embeddings"), ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) @@ -2507,8 +2518,13 @@ async def _generate_embeddings_for_type( else: shortening = None result = await _generate_embeddings( - provider, index["model"], [content], shortening=shortening, - http_client=http_client) + provider, + index["model"], + [content], + shortening, + None, + http_client, + ) if isinstance(result.data, rs.Error): raise AIProviderError(result.data.message) return result.data.embeddings From 53f5fcd5ff1aaa236aec0322144fc735ba3ca491 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Tue, 10 Dec 2024 17:51:39 +0100 Subject: [PATCH 2/7] Fix user type in -generate_embeddings --- edb/server/protocol/ai_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index 8ec2d11adf0..c6201db03ac 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -1012,7 +1012,7 @@ async def _generate_openai_embeddings( model_name: str, inputs: list[str], shortening: Optional[int], - user: Optional[int], + user: Optional[str], http_client: http.HttpClient, ) -> EmbeddingsResult: From 07e22d58ad2b9ae0bec9816ab88242fde0ae5fd6 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 11 Dec 2024 09:48:12 +0100 Subject: [PATCH 3/7] Accept both input & inputs in the ai ext --- edb/server/protocol/ai_ext.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index c6201db03ac..c2269513d78 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -2222,10 +2222,26 @@ async def _handle_embeddings_request( raise TypeError( 'the body of the request must be a JSON object') - inputs = body.get("input") - if not inputs: + inputs = body.get("inputs") + input_value = body.get("input") + + if inputs is not None and input_value is not None: raise TypeError( - 'missing or empty required "input" value in request') + "You can not provide both 'inputs' and 'input'. Please provide 'inputs', 'input' has been deprecated. " + ) + + if inputs is not None: + return inputs + elif input_value is not None: + print("Warning: 'input' is deprecated. Use 'inputs' instead.") + return input_value + else: + raise TypeError("Neither 'inputs' nor 'input' is provided.") + + # inputs = body.get("input") + # if not inputs: + # raise TypeError( + # 'missing or empty required "input" value in request') model_name = body.get("model") if not model_name: From 300422587d1c9370e377f4f0738e87827a3241e8 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 11 Dec 2024 09:50:58 +0100 Subject: [PATCH 4/7] Fix formatting issues --- edb/server/protocol/ai_ext.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index c2269513d78..4463ace1811 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -2223,25 +2223,22 @@ async def _handle_embeddings_request( 'the body of the request must be a JSON object') inputs = body.get("inputs") - input_value = body.get("input") + input = body.get("input") - if inputs is not None and input_value is not None: + if inputs is not None and input is not None: raise TypeError( - "You can not provide both 'inputs' and 'input'. Please provide 'inputs', 'input' has been deprecated. " + "You cannot provide both 'inputs' and 'input'. " + "Please provide 'inputs'; 'input' has been deprecated." ) - if inputs is not None: - return inputs - elif input_value is not None: + if input is not None: print("Warning: 'input' is deprecated. Use 'inputs' instead.") - return input_value - else: - raise TypeError("Neither 'inputs' nor 'input' is provided.") + inputs = input - # inputs = body.get("input") - # if not inputs: - # raise TypeError( - # 'missing or empty required "input" value in request') + if not inputs: + raise TypeError( + 'missing or empty required "inputs" value in request' + ) model_name = body.get("model") if not model_name: From 639b82a466248eadeb0447884681d5d4184b25c1 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 11 Dec 2024 11:32:17 +0100 Subject: [PATCH 5/7] ai ext: add logger warning when client uses input --- edb/server/protocol/ai_ext.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index 4463ace1811..58e4ac2a008 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -781,7 +781,6 @@ async def _generate_embeddings_params( return embeddings_params - @dataclass(frozen=True, kw_only=True) class PendingEmbedding: id: uuid.UUID @@ -2232,7 +2231,7 @@ async def _handle_embeddings_request( ) if input is not None: - print("Warning: 'input' is deprecated. Use 'inputs' instead.") + logger.warning("'input' is deprecated, use 'inputs' instead") inputs = input if not inputs: From 92e12d53c3099129c4e3a504b30bc0d5056b9ece Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 11 Dec 2024 16:23:56 +0100 Subject: [PATCH 6/7] Fix formatting at the top l evel --- edb/server/protocol/ai_ext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index 58e4ac2a008..137b8c1aa96 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -781,6 +781,7 @@ async def _generate_embeddings_params( return embeddings_params + @dataclass(frozen=True, kw_only=True) class PendingEmbedding: id: uuid.UUID From 3de2838d225d320b8dcfa785bf55c793f067449b Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 11 Dec 2024 16:26:46 +0100 Subject: [PATCH 7/7] Add user to embedding params --- edb/server/protocol/ai_ext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py index 137b8c1aa96..fcc200eb906 100644 --- a/edb/server/protocol/ai_ext.py +++ b/edb/server/protocol/ai_ext.py @@ -759,6 +759,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=batch_token_count, shortening=shortening, + user=None, http_client=http_client, )) @@ -776,6 +777,7 @@ async def _generate_embeddings_params( inputs=inputs, token_count=total_token_count, shortening=shortening, + user=None, http_client=http_client, ))