Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send dimensions & user args to embedding models #8090

Merged
merged 7 commits into from
Dec 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions edb/server/protocol/ai_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ class EmbeddingsParams(rs.Params[EmbeddingsData]):
inputs: list[tuple[PendingEmbedding, str]]
token_count: int
shortening: Optional[int]
user: Optional[str]

def costs(self) -> dict[str, int]:
return {
Expand All @@ -547,7 +548,8 @@ async def run(self) -> Optional[rs.Result[EmbeddingsData]]:
self.params.model_name,
[input[1] for input in self.params.inputs],
self.params.shortening,
self.params.http_client
self.params.user,
self.params.http_client,
)
result.pgconn = self.params.pgconn
result.pending_entries = [
Expand Down Expand Up @@ -757,6 +759,7 @@ async def _generate_embeddings_params(
inputs=inputs,
token_count=batch_token_count,
shortening=shortening,
user=None,
http_client=http_client,
))

Expand All @@ -774,6 +777,7 @@ async def _generate_embeddings_params(
inputs=inputs,
token_count=total_token_count,
shortening=shortening,
user=None,
http_client=http_client,
))

Expand Down Expand Up @@ -983,6 +987,7 @@ async def _generate_embeddings(
model_name: str,
inputs: list[str],
shortening: Optional[int],
user: Optional[str],
http_client: http.HttpClient,
) -> EmbeddingsResult:
task_name = _task_name.get()
Expand All @@ -995,7 +1000,7 @@ async def _generate_embeddings(

if provider.api_style == ApiStyle.OpenAI:
return await _generate_openai_embeddings(
provider, model_name, inputs, shortening, http_client
provider, model_name, inputs, shortening, user, http_client
)
else:
raise RuntimeError(
Expand All @@ -1009,6 +1014,7 @@ async def _generate_openai_embeddings(
model_name: str,
inputs: list[str],
shortening: Optional[int],
user: Optional[str],
http_client: http.HttpClient,
) -> EmbeddingsResult:

Expand All @@ -1030,6 +1036,9 @@ async def _generate_openai_embeddings(
if shortening is not None:
params["dimensions"] = shortening

if user is not None:
params["user"] = user

result = await client.post(
"/embeddings",
json=params,
Expand Down Expand Up @@ -2215,16 +2224,32 @@ async def _handle_embeddings_request(
raise TypeError(
'the body of the request must be a JSON object')

inputs = body.get("input")
inputs = body.get("inputs")
input = body.get("input")

if inputs is not None and input is not None:
raise TypeError(
"You cannot provide both 'inputs' and 'input'. "
"Please provide 'inputs'; 'input' has been deprecated."
)

if input is not None:
logger.warning("'input' is deprecated, use 'inputs' instead")
inputs = input

if not inputs:
raise TypeError(
'missing or empty required "input" value in request')
'missing or empty required "inputs" value in request'
)

model_name = body.get("model")
if not model_name:
raise TypeError(
'missing or empty required "model" value in request')

shortening = body.get("dimensions")
Copy link
Contributor Author

@diksipav diksipav Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like the code above on line 2225 to use `inputs = body.get("inputs")` and to expect `inputs` from the clients, but this is a breaking change, so we probably want to leave it as `input`.

Is it OK to accept both `input` and `inputs`, and to emit a deprecation warning for `input`?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, accept both, but throw an error when both are specified.

user = body.get("user")

except Exception as ex:
raise BadRequestError(str(ex)) from None

Expand All @@ -2246,8 +2271,9 @@ async def _handle_embeddings_request(
provider,
model_name,
inputs,
shortening=None,
http_client=tenant.get_http_client(originator="ai/embeddings")
shortening,
user,
http_client=tenant.get_http_client(originator="ai/embeddings"),
)
if isinstance(result.data, rs.Error):
raise AIProviderError(result.data.message)
Expand Down Expand Up @@ -2507,8 +2533,13 @@ async def _generate_embeddings_for_type(
else:
shortening = None
result = await _generate_embeddings(
provider, index["model"], [content], shortening=shortening,
http_client=http_client)
provider,
index["model"],
[content],
shortening,
None,
http_client,
)
if isinstance(result.data, rs.Error):
raise AIProviderError(result.data.message)
return result.data.embeddings
Loading