From 249513d6c944f77ff579cb4cd7e51b362483178f Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Sat, 21 Dec 2024 03:12:06 -0500 Subject: [PATCH] chore: updated migrations + added indices support --- .../queries/developers/get_developer.py | 9 +- .../agents_api/queries/docs/__init__.py | 6 +- .../agents_api/queries/docs/create_doc.py | 141 +++++++++++++----- .../agents_api/queries/docs/delete_doc.py | 24 ++- .../agents_api/queries/docs/embed_snippets.py | 37 ----- agents-api/agents_api/queries/docs/get_doc.py | 68 +++++---- .../agents_api/queries/docs/list_docs.py | 96 ++++++++---- .../queries/docs/search_docs_by_embedding.py | 4 - .../queries/docs/search_docs_by_text.py | 76 ++++++---- .../queries/docs/search_docs_hybrid.py | 5 - agents-api/tests/fixtures.py | 23 +-- agents-api/tests/test_docs_queries.py | 72 ++++----- agents-api/tests/test_files_queries.py | 2 +- memory-store/migrations/000006_docs.up.sql | 9 +- .../migrations/000018_doc_search.up.sql | 57 +++---- 15 files changed, 349 insertions(+), 280 deletions(-) delete mode 100644 agents-api/agents_api/queries/docs/embed_snippets.py diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 373a2fb36..79b6e6067 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -24,9 +24,6 @@ SELECT * FROM developers WHERE developer_id = $1 -- developer_id """).sql(pretty=True) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - @rewrap_exceptions( { @@ -37,7 +34,11 @@ ) } ) -@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]}) +@wrap_in_class( + Developer, + one=True, + transform=lambda d: {**d, "id": d["developer_id"]}, +) @pg_query @beartype async def get_developer( diff --git a/agents-api/agents_api/queries/docs/__init__.py b/agents-api/agents_api/queries/docs/__init__.py index 75f9516a6..51bab2555 100644 --- a/agents-api/agents_api/queries/docs/__init__.py +++ b/agents-api/agents_api/queries/docs/__init__.py @@ -8,6 +8,7 @@ - Listing documents based on various criteria, including ownership and metadata filters. - Deleting documents by their unique identifiers. - Embedding document snippets for retrieval purposes. +- Searching documents by text. The module interacts with other parts of the application, such as the agents and users modules, to provide a comprehensive document management system. Its role is crucial in enabling document search, retrieval, and management features within the context of agents and users. @@ -22,12 +23,13 @@ from .list_docs import list_docs # from .search_docs_by_embedding import search_docs_by_embedding -# from .search_docs_by_text import search_docs_by_text +from .search_docs_by_text import search_docs_by_text __all__ = [ "create_doc", "delete_doc", "get_doc", "list_docs", - # "search_docs_by_embct", + # "search_docs_by_embedding", + "search_docs_by_text", ] diff --git a/agents-api/agents_api/queries/docs/create_doc.py b/agents-api/agents_api/queries/docs/create_doc.py index 59fd40004..d8bcce7d3 100644 --- a/agents-api/agents_api/queries/docs/create_doc.py +++ b/agents-api/agents_api/queries/docs/create_doc.py @@ -47,15 +47,38 @@ INSERT INTO doc_owners ( developer_id, doc_id, + index, owner_type, owner_id ) - VALUES ($1, $2, $3, $4) + VALUES ($1, $2, $3, $4, $5) RETURNING doc_id ) -SELECT d.* +SELECT DISTINCT ON (docs.doc_id) + docs.doc_id, + docs.developer_id, + docs.title, + array_agg(docs.content ORDER BY docs.index) as content, + array_agg(docs.index ORDER BY docs.index) as indices, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at + FROM inserted_owner io -JOIN docs d ON d.doc_id = io.doc_id; +JOIN docs ON docs.doc_id = io.doc_id +GROUP BY + docs.doc_id, + docs.developer_id, + docs.title, + docs.modality, + docs.embedding_model, + docs.embedding_dimensions, + docs.language, + docs.metadata, + docs.created_at; """).sql(pretty=True) @@ -82,11 +105,10 @@ Doc, one=True, transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + **d, }, ) @increase_counter("create_doc") @@ -97,56 +119,99 @@ async def create_doc( developer_id: UUID, doc_id: UUID | None = None, data: CreateDocRequest, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, + owner_type: Literal["user", "agent"], + owner_id: UUID, modality: Literal["text", "image", "mixed"] | None = "text", embedding_model: str | None = "voyage-3", embedding_dimensions: int | None = 1024, language: str | None = "english", index: int | None = 0, -) -> list[tuple[str, list] | tuple[str, list, str]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: """ - Insert a new doc record into Timescale and optionally associate it with an owner. + Insert a new doc record into Timescale and associate it with an owner. Parameters: - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. + developer_id (UUID): The ID of the developer. + doc_id (UUID | None): Optional custom UUID for the document. If not provided, one will be generated. + data (CreateDocRequest): The data for the document. + owner_type (Literal["user", "agent"]): The type of the owner (required). + owner_id (UUID): The ID of the owner (required). modality (Literal["text", "image", "mixed"]): The modality of the documents. embedding_model (str): The model used for embedding. embedding_dimensions (int): The dimensions of the embedding. language (str): The language of the documents. index (int): The index of the documents. - data (CreateDocRequest): The data for the document. Returns: list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document. """ + queries = [] # Generate a UUID if not provided - doc_id = doc_id or uuid7() + current_doc_id = uuid7() if doc_id is None else doc_id - # check if content is a string - if isinstance(data.content, str): - data.content = [data.content] + # Check if content is a list + if isinstance(data.content, list): + final_params_doc = [] + final_params_owner = [] + + for idx, content in enumerate(data.content): + doc_params = [ + developer_id, + current_doc_id, + data.title, + content, + idx, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + final_params_doc.append(doc_params) - # Create the doc record - doc_params = [ - developer_id, - doc_id, - data.title, - str(data.content), - index, - modality, - embedding_model, - embedding_dimensions, - language, - data.metadata or {}, - ] - - queries = [(doc_query, doc_params)] - - # If an owner is specified, associate it: - if owner_type and owner_id: - owner_params = [developer_id, doc_id, owner_type, owner_id] - queries.append((doc_owner_query, owner_params)) + owner_params = [ + developer_id, + current_doc_id, + idx, + owner_type, + owner_id, + ] + final_params_owner.append(owner_params) + + # Add the doc query for each content + queries.append((doc_query, final_params_doc, "fetchmany")) + + # Add the owner query + queries.append((doc_owner_query, final_params_owner, "fetchmany")) + + else: + + # Create the doc record + doc_params = [ + developer_id, + current_doc_id, + data.title, + data.content, + index, + modality, + embedding_model, + embedding_dimensions, + language, + data.metadata or {}, + ] + + owner_params = [ + developer_id, + current_doc_id, + index, + owner_type, + owner_id, + ] + + # Add the doc query for single content + queries.append((doc_query, doc_params, "fetch")) + + # Add the owner query + queries.append((doc_owner_query, owner_params, "fetch")) return queries diff --git a/agents-api/agents_api/queries/docs/delete_doc.py b/agents-api/agents_api/queries/docs/delete_doc.py index 5697ca8d6..b0a9ea1a1 100644 --- a/agents-api/agents_api/queries/docs/delete_doc.py +++ b/agents-api/agents_api/queries/docs/delete_doc.py @@ -16,22 +16,18 @@ DELETE FROM doc_owners WHERE developer_id = $1 AND doc_id = $2 - AND ( - ($3::text IS NULL AND $4::uuid IS NULL) - OR (owner_type = $3 AND owner_id = $4) - ) + AND owner_type = $3 + AND owner_id = $4 ) DELETE FROM docs WHERE developer_id = $1 AND doc_id = $2 - AND ( - $3::text IS NULL OR EXISTS ( - SELECT 1 FROM doc_owners - WHERE developer_id = $1 - AND doc_id = $2 - AND owner_type = $3 - AND owner_id = $4 - ) + AND EXISTS ( + SELECT 1 FROM doc_owners + WHERE developer_id = $1 + AND doc_id = $2 + AND owner_type = $3 + AND owner_id = $4 ) RETURNING doc_id; """).sql(pretty=True) @@ -61,8 +57,8 @@ async def delete_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, + owner_type: Literal["user", "agent"], + owner_id: UUID, ) -> tuple[str, list]: """ Deletes a doc (and associated doc_owners) for the given developer and doc_id. diff --git a/agents-api/agents_api/queries/docs/embed_snippets.py b/agents-api/agents_api/queries/docs/embed_snippets.py deleted file mode 100644 index 1a20d6a34..000000000 --- a/agents-api/agents_api/queries/docs/embed_snippets.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Literal -from uuid import UUID - -from beartype import beartype -from sqlglot import parse_one - -from ..utils import pg_query - -# TODO: This is a placeholder for the actual query -vectorizer_query = None - - -@pg_query -@beartype -async def embed_snippets( - *, - developer_id: UUID, - doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, -) -> tuple[str, list]: - """ - Trigger the vectorizer to generate embeddings for documents. - - Parameters: - developer_id (UUID): The ID of the developer. - doc_id (UUID): The ID of the document. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. - - Returns: - tuple[str, list]: SQL query and parameters for embedding the snippets. - """ - return ( - vectorizer_query, - [developer_id, doc_id, owner_type, owner_id], - ) diff --git a/agents-api/agents_api/queries/docs/get_doc.py b/agents-api/agents_api/queries/docs/get_doc.py index 8575f77b0..3f071cf87 100644 --- a/agents-api/agents_api/queries/docs/get_doc.py +++ b/agents-api/agents_api/queries/docs/get_doc.py @@ -8,35 +8,51 @@ from ...autogen.openapi_model import Doc from ..utils import pg_query, wrap_in_class -# Combined query to fetch document details and embedding +# Update the query to use DISTINCT ON to prevent duplicates doc_with_embedding_query = parse_one(""" -SELECT d.*, e.embedding -FROM docs d -LEFT JOIN doc_owners doc_own - ON d.developer_id = doc_own.developer_id - AND d.doc_id = doc_own.doc_id -LEFT JOIN docs_embeddings e - ON d.doc_id = e.doc_id -WHERE d.developer_id = $1 - AND d.doc_id = $2 - AND ( - ($3::text IS NULL AND $4::uuid IS NULL) - OR (doc_own.owner_type = $3 AND doc_own.owner_id = $4) - ) -LIMIT 1; +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(e.embedding ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND d.doc_id = $2 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data; """).sql(pretty=True) @wrap_in_class( Doc, - one=True, + one=True, # Changed to True since we're now returning one grouped record transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), - "embedding": d["embedding"], # Add embedding to the transformation + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embeddings": d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"], + **d, }, ) @pg_query @@ -45,22 +61,18 @@ async def get_doc( *, developer_id: UUID, doc_id: UUID, - owner_type: Literal["user", "agent"] | None = None, - owner_id: UUID | None = None, ) -> tuple[str, list]: """ - Fetch a single doc with its embedding, optionally constrained to a given owner. - + Fetch a single doc with its embedding, grouping all content chunks and embeddings. + Parameters: developer_id (UUID): The ID of the developer. doc_id (UUID): The ID of the document. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. Returns: tuple[str, list]: SQL query and parameters for fetching the document. """ return ( doc_with_embedding_query, - [developer_id, doc_id, owner_type, owner_id], + [developer_id, doc_id], ) diff --git a/agents-api/agents_api/queries/docs/list_docs.py b/agents-api/agents_api/queries/docs/list_docs.py index bfbc2971e..2b31df250 100644 --- a/agents-api/agents_api/queries/docs/list_docs.py +++ b/agents-api/agents_api/queries/docs/list_docs.py @@ -1,34 +1,82 @@ -import ast +""" +This module contains the functionality for listing documents from the PostgreSQL database. +It constructs and executes SQL queries to fetch document details based on various filters. +""" + from typing import Any, Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import Doc -from ..utils import pg_query, wrap_in_class +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Base query for listing docs with optional embeddings +# Base query for listing docs with aggregated content and embeddings base_docs_query = parse_one(""" -SELECT d.*, CASE WHEN $2 THEN NULL ELSE e.embedding END AS embedding -FROM docs d -LEFT JOIN doc_owners doc_own ON d.developer_id = doc_own.developer_id AND d.doc_id = doc_own.doc_id -LEFT JOIN docs_embeddings e ON d.doc_id = e.doc_id -WHERE d.developer_id = $1 +WITH doc_data AS ( + SELECT DISTINCT ON (d.doc_id) + d.doc_id, + d.developer_id, + d.title, + array_agg(d.content ORDER BY d.index) as content, + array_agg(d.index ORDER BY d.index) as indices, + array_agg(CASE WHEN $2 THEN NULL ELSE e.embedding END ORDER BY d.index) as embeddings, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at + FROM docs d + JOIN doc_owners doc_own + ON d.developer_id = doc_own.developer_id + AND d.doc_id = doc_own.doc_id + LEFT JOIN docs_embeddings e + ON d.doc_id = e.doc_id + WHERE d.developer_id = $1 + AND doc_own.owner_type = $3 + AND doc_own.owner_id = $4 + GROUP BY + d.doc_id, + d.developer_id, + d.title, + d.modality, + d.embedding_model, + d.embedding_dimensions, + d.language, + d.metadata, + d.created_at +) +SELECT * FROM doc_data """).sql(pretty=True) +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No documents found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or owner does not exist", + ), + } +) @wrap_in_class( Doc, one=False, transform=lambda d: { - **d, "id": d["doc_id"], - "content": ast.literal_eval(d["content"])[0] - if len(ast.literal_eval(d["content"])) == 1 - else ast.literal_eval(d["content"]), - "embedding": d.get("embedding"), # Add embedding to the transformation + "index": d["indices"][0], + "content": d["content"][0] if len(d["content"]) == 1 else d["content"], + "embedding": d["embeddings"][0] if d.get("embeddings") and len(d["embeddings"]) == 1 else d.get("embeddings"), + **d, }, ) @pg_query @@ -36,8 +84,8 @@ async def list_docs( *, developer_id: UUID, - owner_id: UUID | None = None, - owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID, + owner_type: Literal["user", "agent"], limit: int = 100, offset: int = 0, sort_by: Literal["created_at", "updated_at"] = "created_at", @@ -46,12 +94,12 @@ async def list_docs( include_without_embeddings: bool = False, ) -> tuple[str, list]: """ - Lists docs with optional owner filtering, pagination, and sorting. + Lists docs with pagination and sorting, aggregating content chunks and embeddings. Parameters: developer_id (UUID): The ID of the developer. - owner_id (UUID): The ID of the owner of the documents. - owner_type (Literal["user", "agent"]): The type of the owner of the documents. + owner_id (UUID): The ID of the owner of the documents (required). + owner_type (Literal["user", "agent"]): The type of the owner of the documents (required). limit (int): The number of documents to return. offset (int): The number of documents to skip. sort_by (Literal["created_at", "updated_at"]): The field to sort by. @@ -61,6 +109,9 @@ async def list_docs( Returns: tuple[str, list]: SQL query and parameters for listing the documents. + + Raises: + HTTPException: If invalid parameters are provided. """ if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") @@ -76,17 +127,12 @@ async def list_docs( # Start with the base query query = base_docs_query - params = [developer_id, include_without_embeddings] - - # Add owner filtering - if owner_type and owner_id: - query += " AND doc_own.owner_type = $3 AND doc_own.owner_id = $4" - params.extend([owner_type, owner_id]) + params = [developer_id, include_without_embeddings, owner_type, owner_id] # Add metadata filtering if metadata_filter: for key, value in metadata_filter.items(): - query += f" AND d.metadata->>'{key}' = ${len(params) + 1}" + query += f" AND metadata->>'{key}' = ${len(params) + 1}" params.append(value) # Add sorting and pagination diff --git a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py index c7b15ee64..5a89803ee 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_embedding.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_embedding.py @@ -1,7 +1,3 @@ -""" -Timescale-based doc embedding search using the `embedding` column. -""" - from typing import List, Literal from uuid import UUID diff --git a/agents-api/agents_api/queries/docs/search_docs_by_text.py b/agents-api/agents_api/queries/docs/search_docs_by_text.py index 0ab309ee8..79f9ac305 100644 --- a/agents-api/agents_api/queries/docs/search_docs_by_text.py +++ b/agents-api/agents_api/queries/docs/search_docs_by_text.py @@ -1,35 +1,36 @@ -""" -Timescale-based doc text search using the `search_tsv` column. -""" - -from typing import Literal +from typing import Any, Literal, List from uuid import UUID from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one +import asyncpg +import json from ...autogen.openapi_model import DocReference -from ..utils import pg_query, wrap_in_class +from ..utils import pg_query, wrap_in_class, rewrap_exceptions, partialclass -search_docs_text_query = parse_one(""" -SELECT d.*, - ts_rank_cd(d.search_tsv, websearch_to_tsquery($3)) AS rank -FROM docs d -LEFT JOIN doc_owners do - ON d.developer_id = do.developer_id - AND d.doc_id = do.doc_id -WHERE d.developer_id = $1 - AND ( - ($4 IS NULL AND $5 IS NULL) - OR (do.owner_type = $4 AND do.owner_id = $5) - ) - AND d.search_tsv @@ websearch_to_tsquery($3) -ORDER BY rank DESC -LIMIT $2; -""").sql(pretty=True) +search_docs_text_query = ( + """ + SELECT * FROM search_by_text( + $1, -- developer_id + $2, -- query + $3, -- owner_types + ( SELECT array_agg(*)::UUID[] FROM jsonb_array_elements($4) ) + ) + """ +) +@rewrap_exceptions( + { + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer does not exist.", + ) + } +) @wrap_in_class( DocReference, transform=lambda d: { @@ -41,15 +42,16 @@ **d, }, ) -@pg_query +@pg_query(debug=True) @beartype async def search_docs_by_text( *, developer_id: UUID, + owners: list[tuple[Literal["user", "agent"], UUID]], query: str, - k: int = 10, - owner_type: Literal["user", "agent", "org"] | None = None, - owner_id: UUID | None = None, + k: int = 3, + metadata_filter: dict[str, Any] = {}, + search_language: str | None = "english", ) -> tuple[str, list]: """ Full-text search on docs using the search_tsv column. @@ -57,9 +59,11 @@ async def search_docs_by_text( Parameters: developer_id (UUID): The ID of the developer. query (str): The text to search for. - k (int): The number of results to return. - owner_type (Literal["user", "agent", "org"]): The type of the owner of the documents. - owner_id (UUID): The ID of the owner of the documents. + owners (list[tuple[Literal["user", "agent"], UUID]]): List of (owner_type, owner_id) tuples. + k (int): Maximum number of results to return. + search_language (str): Language for text search (default: "english"). + metadata_filter (dict): Metadata filter criteria. + connection_pool (asyncpg.Pool): Database connection pool. Returns: tuple[str, list]: SQL query and parameters for searching the documents. @@ -67,7 +71,19 @@ async def search_docs_by_text( if k < 1: raise HTTPException(status_code=400, detail="k must be >= 1") + # Extract owner types and IDs + owner_types = [owner[0] for owner in owners] + owner_ids = [owner[1] for owner in owners] + return ( search_docs_text_query, - [developer_id, k, query, owner_type, owner_id], + [ + developer_id, + query, + owner_types, + owner_ids, + search_language, + k, + metadata_filter, + ], ) diff --git a/agents-api/agents_api/queries/docs/search_docs_hybrid.py b/agents-api/agents_api/queries/docs/search_docs_hybrid.py index a879e3b6b..184ba7e8e 100644 --- a/agents-api/agents_api/queries/docs/search_docs_hybrid.py +++ b/agents-api/agents_api/queries/docs/search_docs_hybrid.py @@ -1,8 +1,3 @@ -""" -Hybrid doc search that merges text search and embedding search results -via a simple distribution-based score fusion or direct weighting in Python. -""" - from typing import List, Literal from uuid import UUID diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 2f7de580e..a34c7e2aa 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -63,23 +63,6 @@ def test_developer_id(): developer_id = uuid7() return developer_id - -# @fixture(scope="global") -# async def test_file(dsn=pg_dsn, developer_id=test_developer_id): -# async with get_pg_client(dsn=dsn) as client: -# file = await create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# client=client, -# ) -# yield file - - @fixture(scope="global") async def test_developer(dsn=pg_dsn, developer_id=test_developer_id): pool = await create_db_pool(dsn=dsn) @@ -150,16 +133,18 @@ async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): @fixture(scope="test") -async def test_doc(dsn=pg_dsn, developer=test_developer): +async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) doc = await create_doc( developer_id=developer.id, data=CreateDocRequest( title="Hello", - content=["World"], + content=["World", "World2", "World3"], metadata={"test": "test"}, embed_instruction="Embed the document", ), + owner_type="agent", + owner_id=agent.id, connection_pool=pool, ) return doc diff --git a/agents-api/tests/test_docs_queries.py b/agents-api/tests/test_docs_queries.py index 1410c88c9..71553ee83 100644 --- a/agents-api/tests/test_docs_queries.py +++ b/agents-api/tests/test_docs_queries.py @@ -8,36 +8,13 @@ from agents_api.queries.docs.list_docs import list_docs # If you wish to test text/embedding/hybrid search, import them: -# from agents_api.queries.docs.search_docs_by_text import search_docs_by_text +from agents_api.queries.docs.search_docs_by_text import search_docs_by_text # from agents_api.queries.docs.search_docs_by_embedding import search_docs_by_embedding # from agents_api.queries.docs.search_docs_hybrid import search_docs_hybrid # You can rename or remove these imports to match your actual fixtures from tests.fixtures import pg_dsn, test_agent, test_developer, test_doc, test_user -@test("query: create doc") -async def _(dsn=pg_dsn, developer=test_developer): - pool = await create_db_pool(dsn=dsn) - doc = await create_doc( - developer_id=developer.id, - data=CreateDocRequest( - title="Hello Doc", - content="This is sample doc content", - embed_instruction="Embed the document", - metadata={"test": "test"}, - ), - connection_pool=pool, - ) - - assert doc.title == "Hello Doc" - assert doc.content == "This is sample doc content" - assert doc.modality == "text" - assert doc.embedding_model == "voyage-3" - assert doc.embedding_dimensions == 1024 - assert doc.language == "english" - assert doc.index == 0 - - @test("query: create user doc") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) @@ -92,7 +69,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert any(d.id == doc.id for d in docs_list) -@test("model: get doc") +@test("query: get doc") async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): pool = await create_db_pool(dsn=dsn) doc_test = await get_doc( @@ -102,18 +79,7 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): ) assert doc_test.id == doc.id assert doc_test.title == doc.title - - -@test("query: list docs") -async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): - pool = await create_db_pool(dsn=dsn) - docs_list = await list_docs( - developer_id=developer.id, - connection_pool=pool, - ) - assert len(docs_list) >= 1 - assert any(d.id == doc.id for d in docs_list) - + assert doc_test.content == doc.content @test("query: list user docs") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): @@ -246,12 +212,34 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): ) assert not any(d.id == doc_agent.id for d in docs_list) - -@test("query: delete doc") -async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc): +@test("query: search docs by text") +async def _(dsn=pg_dsn, agent=test_agent, developer=test_developer): pool = await create_db_pool(dsn=dsn) - await delete_doc( + + # Create a test document + await create_doc( developer_id=developer.id, - doc_id=doc.id, + owner_type="agent", + owner_id=agent.id, + data=CreateDocRequest( + title="Hello", + content="The world is a funny little thing", + metadata={"test": "test"}, + embed_instruction="Embed the document", + ), connection_pool=pool, ) + + # Search using the correct parameter types + result = await search_docs_by_text( + developer_id=developer.id, + owners=[("agent", agent.id)], + query="funny", + k=3, # Add k parameter + search_language="english", # Add language parameter + metadata_filter={}, # Add metadata filter + connection_pool=pool, + ) + + assert len(result) >= 1 + assert result[0].metadata is not None \ No newline at end of file diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index c83c7a6f6..68409ef5c 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -82,7 +82,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): assert any(f.id == file.id for f in files) -@test("model: get file") +@test("query: get file") async def _(dsn=pg_dsn, file=test_file, developer=test_developer): pool = await create_db_pool(dsn=dsn) file_test = await get_file( diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql index 193fae122..97bdad43c 100644 --- a/memory-store/migrations/000006_docs.up.sql +++ b/memory-store/migrations/000006_docs.up.sql @@ -24,8 +24,7 @@ CREATE TABLE IF NOT EXISTS docs ( created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB, - CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id), - CONSTRAINT uq_docs_doc_id_index UNIQUE (doc_id, index), + CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id, index), CONSTRAINT ct_docs_embedding_dimensions_positive CHECK (embedding_dimensions > 0), CONSTRAINT ct_docs_valid_modality CHECK (modality IN ('text', 'image', 'mixed')), CONSTRAINT ct_docs_index_positive CHECK (index >= 0), @@ -67,10 +66,12 @@ END $$; CREATE TABLE IF NOT EXISTS doc_owners ( developer_id UUID NOT NULL, doc_id UUID NOT NULL, + index INTEGER NOT NULL, owner_type TEXT NOT NULL, -- 'user' or 'agent' owner_id UUID NOT NULL, - CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), - CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id, index), + -- TODO: Add foreign key constraint + -- CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); diff --git a/memory-store/migrations/000018_doc_search.up.sql b/memory-store/migrations/000018_doc_search.up.sql index 5293cc81a..2f5b2baf1 100644 --- a/memory-store/migrations/000018_doc_search.up.sql +++ b/memory-store/migrations/000018_doc_search.up.sql @@ -101,6 +101,7 @@ END $$; -- Create the search function CREATE OR REPLACE FUNCTION search_by_vector ( + developer_id UUID, query_embedding vector (1024), owner_types TEXT[], owner_ids UUID [], @@ -134,9 +135,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -153,6 +152,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -160,15 +160,12 @@ BEGIN (1 - (d.embedding <=> $1)) as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE 1 - (d.embedding <=> $1) >= $2 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $7 + AND 1 - (d.embedding <=> $1) >= $2 %s %s ) @@ -185,7 +182,9 @@ BEGIN k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; + END; $$; @@ -238,6 +237,7 @@ COMMENT ON FUNCTION embed_and_search_by_vector IS 'Convenience function that com -- Create the text search function CREATE OR REPLACE FUNCTION search_by_text ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -267,9 +267,7 @@ BEGIN IF owner_types IS NOT NULL AND owner_ids IS NOT NULL THEN owner_filter_sql := ' AND ( - (ud.user_id = ANY($5) AND ''user'' = ANY($4)) - OR - (ad.agent_id = ANY($5) AND ''agent'' = ANY($4)) + doc_owners.owner_id = ANY($5::uuid[]) AND doc_owners.owner_type = ANY($4::text[]) )'; ELSE owner_filter_sql := ''; @@ -286,6 +284,7 @@ BEGIN RETURN QUERY EXECUTE format( 'WITH ranked_docs AS ( SELECT + d.developer_id, d.doc_id, d.index, d.title, @@ -293,15 +292,12 @@ BEGIN ts_rank_cd(d.search_tsv, $1, 32)::double precision as distance, d.embedding, d.metadata, - CASE - WHEN ud.user_id IS NOT NULL THEN ''user'' - WHEN ad.agent_id IS NOT NULL THEN ''agent'' - END as owner_type, - COALESCE(ud.user_id, ad.agent_id) as owner_id + doc_owners.owner_type, + doc_owners.owner_id FROM docs_embeddings d - LEFT JOIN user_docs ud ON d.doc_id = ud.doc_id - LEFT JOIN agent_docs ad ON d.doc_id = ad.doc_id - WHERE d.search_tsv @@ $1 + LEFT JOIN doc_owners ON d.doc_id = doc_owners.doc_id + WHERE d.developer_id = $6 + AND d.search_tsv @@ $1 %s %s ) @@ -314,11 +310,11 @@ BEGIN ) USING ts_query, - search_language, k, owner_types, owner_ids, - metadata_filter; + metadata_filter, + developer_id; END; $$; @@ -372,6 +368,7 @@ $$ LANGUAGE plpgsql; -- Hybrid search function combining text and vector search CREATE OR REPLACE FUNCTION search_hybrid ( + developer_id UUID, query_text text, query_embedding vector (1024), owner_types TEXT[], @@ -397,6 +394,7 @@ BEGIN RETURN QUERY WITH text_results AS ( SELECT * FROM search_by_text( + developer_id, query_text, owner_types, owner_ids, @@ -407,6 +405,7 @@ BEGIN ), embedding_results AS ( SELECT * FROM search_by_vector( + developer_id, query_embedding, owner_types, owner_ids, @@ -426,6 +425,7 @@ BEGIN ), scores AS ( SELECT + r.developer_id, r.doc_id, r.title, r.content, @@ -437,8 +437,8 @@ BEGIN COALESCE(t.distance, 0.0) as text_score, COALESCE(e.distance, 0.0) as embedding_score FROM all_results r - LEFT JOIN text_results t ON r.doc_id = t.doc_id - LEFT JOIN embedding_results e ON r.doc_id = e.doc_id + LEFT JOIN text_results t ON r.doc_id = t.doc_id AND r.developer_id = t.developer_id + LEFT JOIN embedding_results e ON r.doc_id = e.doc_id AND r.developer_id = e.developer_id ), normalized_scores AS ( SELECT @@ -448,6 +448,7 @@ BEGIN FROM scores ) SELECT + developer_id, doc_id, index, title, @@ -468,6 +469,7 @@ COMMENT ON FUNCTION search_hybrid IS 'Hybrid search combining text and vector se -- Convenience function that handles embedding generation CREATE OR REPLACE FUNCTION embed_and_search_hybrid ( + developer_id UUID, query_text text, owner_types TEXT[], owner_ids UUID [], @@ -497,6 +499,7 @@ BEGIN -- Perform hybrid search RETURN QUERY SELECT * FROM search_hybrid( + developer_id, query_text, query_embedding, owner_types,