Skip to content

Commit

Permalink
fix: Split payload content by smaller batches for embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Oct 15, 2024
1 parent 697d95e commit ec62ad4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 17 deletions.
52 changes: 35 additions & 17 deletions agents-api/agents_api/activities/embed_docs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
from itertools import batched

from beartype import beartype
from temporalio import activity

Expand All @@ -8,26 +11,41 @@


@beartype
async def embed_docs(
    payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
    """Embed the snippets of ``payload.content`` in batches and store the results.

    The content is split into consecutive batches of at most
    ``max_batch_size`` snippets. Each batch is sent to ``litellm.aembedding``
    and the result is handed to ``embed_snippets_query`` together with the
    original positions of the snippets inside ``payload.content``. Batches are
    processed concurrently.

    Args:
        payload: Document data; ``developer_id``, ``doc_id``, ``title``,
            ``content`` and ``embed_instruction`` are read from it.
        cozo_client: Client passed to ``embed_snippets_query``. When falsy,
            ``cozo.get_cozo_client()`` is used instead.
        max_batch_size: Maximum number of snippets per embedding request.

    Raises:
        Exception: Any error raised while embedding or storing a batch is
            propagated to the caller instead of being silently dropped.
    """
    snippets = list(payload.content)
    if not snippets:
        # Nothing to embed; unpacking zip(*enumerate([])) would raise.
        return

    embed_instruction: str = payload.embed_instruction or ""
    title: str = payload.title or ""
    client = cozo_client or cozo.get_cozo_client()

    async def embed_batch(batch) -> None:
        # `batch` is a tuple of (index, snippet) pairs; split it back into
        # two parallel tuples so indices stay aligned with their embeddings.
        indices, texts = zip(*batch)

        # NOTE(review): the conditional binds looser than `+`, so
        # `embed_instruction` is only prepended when a title is set.
        # Behaviour is kept as-is here -- confirm whether the instruction
        # should also apply to untitled documents.
        embeddings = await litellm.aembedding(
            inputs=[
                (
                    (embed_instruction + title + "\n\n" + text) if title else text
                ).strip()
                for text in texts
            ]
        )

        embed_snippets_query(
            developer_id=payload.developer_id,
            doc_id=payload.doc_id,
            snippet_indices=indices,
            embeddings=embeddings,
            client=client,
        )

    # gather() accepts coroutine objects directly and re-raises failures.
    # asyncio.wait() rejects bare coroutines on Python 3.11+ and never
    # surfaces task exceptions on its own.
    await asyncio.gather(
        *(
            embed_batch(batch)
            for batch in batched(enumerate(snippets), max_batch_size)
        )
    )


Expand Down
29 changes: 29 additions & 0 deletions agents-api/tests/test_activities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from uuid import uuid4
from unittest.mock import patch

from ward import test

Expand Down Expand Up @@ -40,6 +41,34 @@ async def _(
)


@test("activity: call direct embed_docs with batching")
async def _(
    cozo_client=cozo_client,
    developer_id=test_developer_id,
    doc=test_doc,
):
    """Check that embed_docs stores embeddings once per batch.

    Five snippets with ``max_batch_size=2`` give batches of 2, 2 and 1, so the
    patched ``embed_snippets_query`` is expected to be called three times.
    """
    title = "title"
    content = ["content 1", "content 2", "content 3", "content 4", "content 5"]
    include_title = True

    # Patch the name as looked up inside the embed_docs module so no
    # database write happens and the calls can be counted.
    with patch("agents_api.activities.embed_docs.embed_snippets_query") as embed_query:
        embed_query.return_value = None

        await embed_docs(
            EmbedDocsPayload(
                developer_id=developer_id,
                doc_id=doc.id,
                title=title,
                content=content,
                include_title=include_title,
                embed_instruction=None,
            ),
            cozo_client,
            max_batch_size=2,
        )

        # A bare `==` expression is discarded; it must be asserted to be checked.
        assert embed_query.call_count == 3

@test("activity: call demo workflow via temporal client")
async def _():
async with patch_testing_temporal() as (_, mock_get_client):
Expand Down

0 comments on commit ec62ad4

Please sign in to comment.