Skip to content

Commit

Permalink
fix: Fix search by embedding
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 22, 2024
1 parent 4fc4f0e commit e2181fb
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 28 deletions.
65 changes: 38 additions & 27 deletions agents-api/agents_api/queries/docs/search_docs_by_embedding.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,23 @@
from typing import List, Literal
from typing import Any, List, Literal
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import DocReference
from ..utils import pg_query, wrap_in_class

# If you're doing approximate ANN (DiskANN) or IVF, you might use a special function or hint.
# For a basic vector distance search, you can do something like:
search_docs_by_embedding_query = parse_one("""
SELECT d.*,
(d.embedding <-> $3) AS distance
FROM docs d
LEFT JOIN doc_owners do
ON d.developer_id = do.developer_id
AND d.doc_id = do.doc_id
WHERE d.developer_id = $1
AND (
($4::text IS NULL AND $5::uuid IS NULL)
OR (do.owner_type = $4 AND do.owner_id = $5)
)
AND d.embedding IS NOT NULL
ORDER BY d.embedding <-> $3
LIMIT $2;
""").sql(pretty=True)
search_docs_by_embedding_query = """
SELECT * FROM search_by_vector(
$1, -- developer_id
$2::vector(1024), -- query_embedding
$3::text[], -- owner_types
$UUID_LIST::uuid[], -- owner_ids
$4, -- k
$5, -- confidence
$6 -- metadata_filter
)
"""


@wrap_in_class(
Expand All @@ -46,8 +38,9 @@ async def search_docs_by_embedding(
developer_id: UUID,
query_embedding: List[float],
k: int = 10,
owner_type: Literal["user", "agent", "org"] | None = None,
owner_id: UUID | None = None,
owners: list[tuple[Literal["user", "agent"], UUID]],
confidence: float = 0.5,
metadata_filter: dict[str, Any] = {},
) -> tuple[str, list]:
"""
Vector-based doc search:
Expand All @@ -56,20 +49,38 @@ async def search_docs_by_embedding(
developer_id (UUID): The ID of the developer.
query_embedding (List[float]): The vector to query.
k (int): The number of results to return.
owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents.
owner_id (UUID): The ID of the owner of the documents.
owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples.
confidence (float): The confidence threshold for the search.
metadata_filter (dict): Metadata filter criteria.
Returns:
tuple[str, list]: SQL query and parameters for searching the documents.
"""
if k < 1:
raise HTTPException(status_code=400, detail="k must be >= 1")

# Validate embedding length if needed; e.g. 1024 floats
if not query_embedding:
raise HTTPException(status_code=400, detail="Empty embedding provided")

# Convert query_embedding to a string
query_embedding_str = f"[{', '.join(map(str, query_embedding))}]"

# Extract owner types and IDs
owner_types: list[str] = [owner[0] for owner in owners]
owner_ids: list[str] = [str(owner[1]) for owner in owners]

# NOTE: Manually replace uuids list coz asyncpg isnt sending it correctly
owner_ids_pg_str = f"ARRAY['{'\', \''.join(owner_ids)}']"
query = search_docs_by_embedding_query.replace("$UUID_LIST", owner_ids_pg_str)

return (
search_docs_by_embedding_query,
[developer_id, k, query_embedding, owner_type, owner_id],
query,
[
developer_id,
query_embedding_str,
owner_types,
k,
confidence,
metadata_filter,
],
)
34 changes: 33 additions & 1 deletion agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from agents_api.queries.docs.delete_doc import delete_doc
from agents_api.queries.docs.get_doc import get_doc
from agents_api.queries.docs.list_docs import list_docs
from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding
from agents_api.queries.docs.search_docs_by_text import search_docs_by_text

# from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding
# from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid
from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user

Expand Down Expand Up @@ -243,3 +243,35 @@ async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):

assert len(result) >= 1
assert result[0].metadata is not None


@test("query: search docs by embedding")
async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer):
pool = await create_db_pool(dsn=dsn)

# Create a test document
await create_doc(
developer_id=developer.id,
owner_type="agent",
owner_id=agent.id,
data=CreateDocRequest(
title="Hello",
content="The world is a funny little thing",
metadata={"test": "test"},
embed_instruction="Embed the document",
),
connection_pool=pool,
)

# Search using the correct parameter types
result = await search_docs_by_embedding(
developer_id=developer.id,
owners=[("agent", agent.id)],
query_embedding=[1.0]*1024,
k=3, # Add k parameter
metadata_filter={"test": "test"}, # Add metadata filter
connection_pool=pool,
)

assert len(result) >= 1
assert result[0].metadata is not None

0 comments on commit e2181fb

Please sign in to comment.