Skip to content

Commit

Permalink
Send dimensions & user args to embedding models (#8090)
Browse files Browse the repository at this point in the history
If provided, forward the `dimensions` and `user` parameters in embedding requests to the provider.
  • Loading branch information
diksipav authored Dec 13, 2024
1 parent d64473b commit fc5eeac
Showing 1 changed file with 39 additions and 8 deletions.
47 changes: 39 additions & 8 deletions edb/server/protocol/ai_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ class EmbeddingsParams(rs.Params[EmbeddingsData]):
inputs: list[tuple[PendingEmbedding, str]]
token_count: int
shortening: Optional[int]
user: Optional[str]

def costs(self) -> dict[str, int]:
return {
Expand All @@ -547,7 +548,8 @@ async def run(self) -> Optional[rs.Result[EmbeddingsData]]:
self.params.model_name,
[input[1] for input in self.params.inputs],
self.params.shortening,
self.params.http_client
self.params.user,
self.params.http_client,
)
result.pgconn = self.params.pgconn
result.pending_entries = [
Expand Down Expand Up @@ -757,6 +759,7 @@ async def _generate_embeddings_params(
inputs=inputs,
token_count=batch_token_count,
shortening=shortening,
user=None,
http_client=http_client,
))

Expand All @@ -774,6 +777,7 @@ async def _generate_embeddings_params(
inputs=inputs,
token_count=total_token_count,
shortening=shortening,
user=None,
http_client=http_client,
))

Expand Down Expand Up @@ -983,6 +987,7 @@ async def _generate_embeddings(
model_name: str,
inputs: list[str],
shortening: Optional[int],
user: Optional[str],
http_client: http.HttpClient,
) -> EmbeddingsResult:
task_name = _task_name.get()
Expand All @@ -995,7 +1000,7 @@ async def _generate_embeddings(

if provider.api_style == ApiStyle.OpenAI:
return await _generate_openai_embeddings(
provider, model_name, inputs, shortening, http_client
provider, model_name, inputs, shortening, user, http_client
)
else:
raise RuntimeError(
Expand All @@ -1009,6 +1014,7 @@ async def _generate_openai_embeddings(
model_name: str,
inputs: list[str],
shortening: Optional[int],
user: Optional[str],
http_client: http.HttpClient,
) -> EmbeddingsResult:

Expand All @@ -1030,6 +1036,9 @@ async def _generate_openai_embeddings(
if shortening is not None:
params["dimensions"] = shortening

if user is not None:
params["user"] = user

result = await client.post(
"/embeddings",
json=params,
Expand Down Expand Up @@ -2215,16 +2224,32 @@ async def _handle_embeddings_request(
raise TypeError(
'the body of the request must be a JSON object')

inputs = body.get("input")
inputs = body.get("inputs")
input = body.get("input")

if inputs is not None and input is not None:
raise TypeError(
"You cannot provide both 'inputs' and 'input'. "
"Please provide 'inputs'; 'input' has been deprecated."
)

if input is not None:
logger.warning("'input' is deprecated, use 'inputs' instead")
inputs = input

if not inputs:
raise TypeError(
'missing or empty required "input" value in request')
'missing or empty required "inputs" value in request'
)

model_name = body.get("model")
if not model_name:
raise TypeError(
'missing or empty required "model" value in request')

shortening = body.get("dimensions")
user = body.get("user")

except Exception as ex:
raise BadRequestError(str(ex)) from None

Expand All @@ -2246,8 +2271,9 @@ async def _handle_embeddings_request(
provider,
model_name,
inputs,
shortening=None,
http_client=tenant.get_http_client(originator="ai/embeddings")
shortening,
user,
http_client=tenant.get_http_client(originator="ai/embeddings"),
)
if isinstance(result.data, rs.Error):
raise AIProviderError(result.data.message)
Expand Down Expand Up @@ -2507,8 +2533,13 @@ async def _generate_embeddings_for_type(
else:
shortening = None
result = await _generate_embeddings(
provider, index["model"], [content], shortening=shortening,
http_client=http_client)
provider,
index["model"],
[content],
shortening,
None,
http_client,
)
if isinstance(result.data, rs.Error):
raise AIProviderError(result.data.message)
return result.data.embeddings

0 comments on commit fc5eeac

Please sign in to comment.