diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 3527f3611..c0ca3919f 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -17,19 +17,39 @@ # Define the raw SQL query agent_query = parse_one(""" -WITH deleted_docs AS ( +WITH deleted_file_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_doc_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_files AS ( + DELETE FROM files + WHERE developer_id = $1 + AND file_id IN ( + SELECT file_id FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 + ) +), +deleted_docs AS ( DELETE FROM docs WHERE developer_id = $1 AND doc_id IN ( - SELECT ad.doc_id - FROM agent_docs ad - WHERE ad.agent_id = $2 - AND ad.developer_id = $1 + SELECT doc_id FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 ) -), deleted_agent_docs AS ( - DELETE FROM agent_docs - WHERE agent_id = $2 AND developer_id = $1 -), deleted_tools AS ( +), +deleted_tools AS ( DELETE FROM tools WHERE agent_id = $2 AND developer_id = $1 ) diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 33dcda984..95973ad0b 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -1,14 +1,17 @@ from typing import Literal from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException +from litellm.utils import _select_tokenizer as select_tokenizer from uuid_extensions import uuid7 from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation from ...common.utils.datetime import utcnow from ...common.utils.messages import content_to_json from ...metrics.counters import increase_counter -from ..utils import pg_query, wrap_in_class +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Query for checking if the session exists session_exists_query = """ @@ -22,7 +25,7 @@ entry_query = """ INSERT INTO entries ( session_id, - entry_id, + entry_id, source, role, event_type, @@ -32,9 +35,10 @@ tool_calls, model, token_count, + tokenizer, created_at, timestamp -) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) +) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) RETURNING *; """ @@ -50,34 +54,34 @@ """ -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="Entry already exists", -# ), -# asyncpg.NotNullViolationError: partialclass( -# HTTPException, -# status_code=400, -# detail="Not null violation", -# ), -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="Entry already exists", + ), + asyncpg.NotNullViolationError: partialclass( + HTTPException, + status_code=400, + detail="Not null violation", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class( Entry, transform=lambda d: { - "id": UUID(d.pop("entry_id")), + "id": d.pop("entry_id"), **d, }, ) @@ -89,7 +93,7 @@ async def create_entries( developer_id: UUID, session_id: UUID, data: list[CreateEntryRequest], -) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] @@ -100,7 +104,7 @@ async def create_entries( params.append( [ session_id, # $1 - item.pop("id", None) or str(uuid7()), # $2 + item.pop("id", None) or uuid7(), # $2 item.get("source"), # $3 item.get("role"), # $4 item.get("event_type") or "message.create", # $5 @@ -110,8 +114,9 @@ async def create_entries( content_to_json(item.get("tool_calls") or {}), # $9 item.get("model"), # $10 item.get("token_count"), # $11 - item.get("created_at") or utcnow(), # $12 - utcnow(), # $13 + select_tokenizer(item.get("model"))["type"], # $12 + item.get("created_at") or utcnow(), # $13 + utcnow().timestamp(), # $14 ] ) @@ -119,7 +124,7 @@ async def create_entries( ( session_exists_query, [session_id, developer_id], - "fetch", + "fetchrow", ), ( entry_query, @@ -129,20 +134,25 @@ async def create_entries( ] -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="Entry already exists", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="Entry already exists", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class(Relation) @increase_counter("add_entry_relations") @pg_query @@ -152,7 +162,7 @@ async def add_entry_relations( developer_id: UUID, session_id: UUID, data: list[Relation], -) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: # Convert the data to a list of dictionaries data_dicts = [item.model_dump(mode="json") for item in data] diff --git a/agents-api/agents_api/queries/entries/delete_entries.py b/agents-api/agents_api/queries/entries/delete_entries.py index 628ef9011..47b7379a4 100644 --- a/agents-api/agents_api/queries/entries/delete_entries.py +++ b/agents-api/agents_api/queries/entries/delete_entries.py @@ -1,13 +1,15 @@ from typing import Literal from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import pg_query, wrap_in_class +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query for deleting entries with a developer check delete_entry_query = parse_one(""" @@ -55,20 +57,25 @@ """ -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified session or developer does not exist.", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="The specified session has already been deleted.", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified session or developer does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="The specified session has already been deleted.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class( ResourceDeletedResponse, one=True, @@ -85,29 +92,34 @@ async def delete_entries_for_session( *, developer_id: UUID, session_id: UUID, -) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: """Delete all entries for a given session.""" return [ - (session_exists_query, [session_id, developer_id], "fetch"), + (session_exists_query, [session_id, developer_id], "fetchrow"), (delete_entry_relations_query, [session_id], "fetchmany"), (delete_entry_query, [session_id, developer_id], "fetchmany"), ] -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="The specified entries, session, or developer does not exist.", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="One or more specified entries have already been deleted.", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified entries, session, or developer does not exist.", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="One or more specified entries have already been deleted.", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class( ResourceDeletedResponse, transform=lambda d: { @@ -121,10 +133,18 @@ async def delete_entries_for_session( @beartype async def delete_entries( *, developer_id: UUID, session_id: UUID, entry_ids: list[UUID] -) -> list[tuple[str, list, Literal["fetch", "fetchmany"]]]: +) -> list[tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]]: """Delete specific entries by their IDs.""" return [ - (session_exists_query, [session_id, developer_id], "fetch"), - (delete_entry_relations_by_ids_query, [session_id, entry_ids], "fetchmany"), - (delete_entry_by_ids_query, [entry_ids, developer_id, session_id], "fetchmany"), + ( + session_exists_query, + [session_id, developer_id], + "fetchrow", + ), + (delete_entry_relations_by_ids_query, [session_id, entry_ids], "fetch"), + ( + delete_entry_by_ids_query, + [entry_ids, developer_id, session_id], + "fetch", + ), ] diff --git a/agents-api/agents_api/queries/entries/get_history.py b/agents-api/agents_api/queries/entries/get_history.py index b0b767c08..e6967a6cc 100644 --- a/agents-api/agents_api/queries/entries/get_history.py +++ b/agents-api/agents_api/queries/entries/get_history.py @@ -1,61 +1,91 @@ +import json +from typing import Any, List, Tuple from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import History -from ..utils import pg_query, wrap_in_class +from ...common.utils.datetime import utcnow +from ..utils import ( + partialclass, + pg_query, + rewrap_exceptions, + wrap_in_class, +) -# Define the raw SQL query for getting history with a developer check +# Define the raw SQL query for getting history with a developer check and relations history_query = parse_one(""" +WITH entries AS ( + SELECT + e.entry_id AS id, + e.session_id, + e.role, + e.name, + e.content, + e.source, + e.token_count, + e.created_at, + e.timestamp, + e.tool_calls, + e.tool_call_id, + e.tokenizer + FROM entries e + JOIN developers d ON d.developer_id = $3 + WHERE e.session_id = $1 + AND e.source = ANY($2) +), +relations AS ( + SELECT + er.head, + er.relation, + er.tail + FROM entry_relations er + WHERE er.session_id = $1 +) SELECT - e.entry_id as id, -- entry_id - e.session_id, -- session_id - e.role, -- role - e.name, -- name - e.content, -- content - e.source, -- source - e.token_count, -- token_count - e.created_at, -- created_at - e.timestamp, -- timestamp - e.tool_calls, -- tool_calls - e.tool_call_id -- tool_call_id -FROM entries e -JOIN developers d ON d.developer_id = $3 -WHERE e.session_id = $1 -AND e.source = ANY($2) -ORDER BY e.created_at; + (SELECT json_agg(e) FROM entries e) AS entries, + (SELECT json_agg(r) FROM relations r) AS relations, + $1::uuid AS session_id, """).sql(pretty=True) -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="Entry already exists", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class( History, one=True, transform=lambda d: { - **d, + "entries": json.loads(d.get("entries") or "[]"), "relations": [ { "head": r["head"], "relation": r["relation"], "tail": r["tail"], } - for r in d.pop("relations") + for r in (d.get("relations") or []) ], - "entries": d.pop("entries"), + "session_id": d.get("session_id"), + "created_at": utcnow(), }, ) @pg_query @@ -65,7 +95,7 @@ async def get_history( developer_id: UUID, session_id: UUID, allowed_sources: list[str] = ["api_request", "api_response"], -) -> tuple[str, list]: +) -> tuple[str, list] | tuple[str, list, str]: return ( history_query, [session_id, allowed_sources, developer_id], diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index a6c355f53..89f432734 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -1,12 +1,13 @@ from typing import Literal from uuid import UUID +import asyncpg from beartype import beartype from fastapi import HTTPException from ...autogen.openapi_model import Entry from ...metrics.counters import increase_counter -from ..utils import pg_query, wrap_in_class +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Query for checking if the session exists session_exists_query = """ @@ -34,7 +35,8 @@ e.event_type, e.tool_call_id, e.tool_calls, - e.model + e.model, + e.tokenizer FROM entries e JOIN developers d ON d.developer_id = $5 LEFT JOIN entry_relations er ON er.head = e.entry_id AND er.session_id = e.session_id @@ -47,30 +49,30 @@ """ -# @rewrap_exceptions( -# { -# asyncpg.ForeignKeyViolationError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# asyncpg.UniqueViolationError: partialclass( -# HTTPException, -# status_code=409, -# detail="Entry already exists", -# ), -# asyncpg.NotNullViolationError: partialclass( -# HTTPException, -# status_code=400, -# detail="Entry is required", -# ), -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="Session not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + asyncpg.UniqueViolationError: partialclass( + HTTPException, + status_code=409, + detail="Entry already exists", + ), + asyncpg.NotNullViolationError: partialclass( + HTTPException, + status_code=400, + detail="Entry is required", + ), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="Session not found", + ), + } +) @wrap_in_class(Entry) @increase_counter("list_entries") @pg_query @@ -114,5 +116,6 @@ async def list_entries( ( query, entry_params, + "fetch", ), ] diff --git a/agents-api/agents_api/queries/files/__init__.py b/agents-api/agents_api/queries/files/__init__.py new file mode 100644 index 000000000..99670a8fc --- /dev/null +++ b/agents-api/agents_api/queries/files/__init__.py @@ -0,0 +1,16 @@ +""" +The `files` module within the `queries` package provides SQL query functions for managing files +in the PostgreSQL database. This includes operations for: + +- Creating new files +- Retrieving file details +- Listing files with filtering and pagination +- Deleting files and their associations +""" + +from .create_file import create_file +from .delete_file import delete_file +from .get_file import get_file +from .list_files import list_files + +__all__ = ["create_file", "delete_file", "get_file", "list_files"] diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py new file mode 100644 index 000000000..48251fa5e --- /dev/null +++ b/agents-api/agents_api/queries/files/create_file.py @@ -0,0 +1,144 @@ +""" +This module contains the functionality for creating files in the PostgreSQL database. +It includes functions to construct and execute SQL queries for inserting new file records. +""" + +import base64 +import hashlib +from typing import Any, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one +from uuid_extensions import uuid7 + +from ...autogen.openapi_model import CreateFileRequest, File +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Create file +file_query = parse_one(""" +INSERT INTO files ( + developer_id, + file_id, + name, + description, + mime_type, + size, + hash +) +VALUES ( + $1, -- developer_id + $2, -- file_id + $3, -- name + $4, -- description + $5, -- mime_type + $6, -- size + $7 -- hash +) +RETURNING *; +""").sql(pretty=True) + +# Replace both user_file and agent_file queries with a single file_owner query +file_owner_query = parse_one(""" +WITH inserted_owner AS ( + INSERT INTO file_owners ( + developer_id, + file_id, + owner_type, + owner_id + ) + VALUES ($1, $2, $3, $4) + RETURNING file_id +) +SELECT f.* +FROM inserted_owner io +JOIN files f ON f.file_id = io.file_id; +""").sql(pretty=True) + + +# Add error handling decorator +# @rewrap_exceptions( +# { +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="A file with this name already exists for this developer", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified owner does not exist", +# ), +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist", +# ), +# } +# ) +@wrap_in_class( + File, + one=True, + transform=lambda d: { + **d, + "id": d["file_id"], + "hash": d["hash"].hex(), + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", + }, +) +@increase_counter("create_file") +@pg_query +@beartype +async def create_file( + *, + developer_id: UUID, + file_id: UUID | None = None, + data: CreateFileRequest, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> list[tuple[str, list] | tuple[str, list, str]]: + """ + Constructs and executes SQL queries to create a new file and optionally associate it with an owner. + + Parameters: + developer_id (UUID): The unique identifier for the developer. + file_id (UUID | None): Optional unique identifier for the file. + data (CreateFileRequest): The file data to insert. + owner_type (Literal["user", "agent"] | None): Optional type of owner + owner_id (UUID | None): Optional ID of the owner + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: List of SQL queries, their parameters, and fetch type + """ + file_id = file_id or uuid7() + + # Calculate size and hash + content_bytes = base64.b64decode(data.content) + size = len(content_bytes) + hash_bytes = hashlib.sha256(content_bytes).digest() + + # Base file parameters + file_params = [ + developer_id, + file_id, + data.name, + data.description, + data.mime_type, + size, + hash_bytes, + ] + + queries = [] + + # Create the file first + queries.append((file_query, file_params)) + + # Then create the association if owner info provided + if owner_type and owner_id: + assoc_params = [developer_id, file_id, owner_type, owner_id] + queries.append((file_owner_query, assoc_params)) + + return queries diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py new file mode 100644 index 000000000..31cb43404 --- /dev/null +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -0,0 +1,87 @@ +""" +This module contains the functionality for deleting files from the PostgreSQL database. +It constructs and executes SQL queries to remove file records and associated data. +""" + +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Delete file query with ownership check +delete_file_query = parse_one(""" +WITH deleted_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND file_id = $2 + AND ( + ($3::text IS NULL AND $4::uuid IS NULL) OR + (owner_type = $3 AND owner_id = $4) + ) +) +DELETE FROM files +WHERE developer_id = $1 +AND file_id = $2 +AND ($3::text IS NULL OR EXISTS ( + SELECT 1 FROM file_owners + WHERE developer_id = $1 + AND file_id = $2 + AND owner_type = $3 + AND owner_id = $4 +)) +RETURNING file_id; +""").sql(pretty=True) + + +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="File not found", + ), + } +) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["file_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_file") +@pg_query +@beartype +async def delete_file( + *, + developer_id: UUID, + file_id: UUID, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Deletes a file and its ownership records. + + Args: + developer_id: The developer's UUID + file_id: The file's UUID + owner_type: Optional type of owner ("user" or "agent") + owner_id: Optional UUID of the owner + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + delete_file_query, + [developer_id, file_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py new file mode 100644 index 000000000..4d5dca4c0 --- /dev/null +++ b/agents-api/agents_api/queries/files/get_file.py @@ -0,0 +1,81 @@ +""" +This module contains the functionality for retrieving a single file from the PostgreSQL database. +It constructs and executes SQL queries to fetch file details based on file ID and developer ID. +""" + +from typing import Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import File +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Define the raw SQL query +file_query = parse_one(""" +SELECT f.* +FROM files f +LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE f.developer_id = $1 +AND f.file_id = $2 +AND ( + ($3::text IS NULL AND $4::uuid IS NULL) OR + (fo.owner_type = $3 AND fo.owner_id = $4) +) +LIMIT 1; +""").sql(pretty=True) + + +# @rewrap_exceptions( +# { +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="File not found", +# ), +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Developer not found", +# ), +# } +# ) +@wrap_in_class( + File, + one=True, + transform=lambda d: { + "id": d["file_id"], + **d, + "hash": d["hash"].hex(), + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", + }, +) +@pg_query +@beartype +async def get_file( + *, + file_id: UUID, + developer_id: UUID, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: + """ + Constructs the SQL query to retrieve a file's details. + Uses composite index on (developer_id, file_id) for efficient lookup. + + Args: + file_id: The UUID of the file to retrieve + developer_id: The UUID of the developer owning the file + owner_type: Optional type of owner ("user" or "agent") + owner_id: Optional UUID of the owner + + Returns: + tuple[str, list]: SQL query and parameters + """ + return ( + file_query, + [developer_id, file_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py new file mode 100644 index 000000000..2bc42f842 --- /dev/null +++ b/agents-api/agents_api/queries/files/list_files.py @@ -0,0 +1,122 @@ +""" +This module contains the functionality for listing files from the PostgreSQL database. +It constructs and executes SQL queries to fetch a list of files based on developer ID with pagination. +""" + +from typing import Any, Literal +from uuid import UUID + +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import File +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class + +# Query to list all files for a developer (uses developer_id index) +developer_files_query = parse_one(""" +SELECT f.* +FROM files f +LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE f.developer_id = $1 +ORDER BY + CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at + END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + +# Query to list files for a specific owner (uses composite indexes) +owner_files_query = parse_one(""" +SELECT f.* +FROM files f +JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE fo.developer_id = $1 +AND fo.owner_id = $6 +AND fo.owner_type = $7 +ORDER BY + CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at + END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + + +@wrap_in_class( + File, + one=False, + transform=lambda d: { + **d, + "id": d["file_id"], + "hash": d["hash"].hex(), + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", + }, +) +@pg_query +@beartype +async def list_files( + *, + developer_id: UUID, + owner_id: UUID | None = None, + owner_type: Literal["user", "agent"] | None = None, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", +) -> tuple[str, list]: + """ + Lists files with optimized queries for two cases: + 1. Owner specified: Returns files associated with that owner + 2. No owner: Returns all files for the developer + + Args: + developer_id: UUID of the developer + owner_id: Optional UUID of the owner (user or agent) + owner_type: Optional type of owner ("user" or "agent") + limit: Maximum number of records to return (1-100) + offset: Number of records to skip + sort_by: Field to sort by + direction: Sort direction ('asc' or 'desc') + + Returns: + Tuple of (query, params) + + Raises: + HTTPException: If parameters are invalid + """ + # Validate parameters + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + if limit > 100 or limit < 1: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + # Base parameters used in all queries + params = [ + developer_id, + limit, + offset, + sort_by, + direction, + ] + + # Choose appropriate query based on owner details + if owner_id and owner_type: + params.extend([owner_id, owner_type]) # Add owner_id as $6 and owner_type as $7 + query = owner_files_query # Use single query with owner_type parameter + else: + query = developer_files_query + + return (query, params) diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index 86bcc0b26..ad5befd73 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -11,14 +11,37 @@ # Define the raw SQL query outside the function delete_query = parse_one(""" -WITH deleted_data AS ( - DELETE FROM user_files -- user_files - WHERE developer_id = $1 -- developer_id - AND user_id = $2 -- user_id +WITH deleted_file_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 +), +deleted_doc_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 +), +deleted_files AS ( + DELETE FROM files + WHERE developer_id = $1 + AND file_id IN ( + SELECT file_id FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 + ) ), deleted_docs AS ( - DELETE FROM user_docs - WHERE developer_id = $1 AND user_id = $2 + DELETE FROM docs + WHERE developer_id = $1 + AND doc_id IN ( + SELECT doc_id FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 + ) ) DELETE FROM users WHERE developer_id = $1 AND user_id = $2 diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 0c20ca59e..0d139cb91 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -175,9 +175,9 @@ async def wrapper( all_results.append(results) if method_name == "fetchrow" and ( - len(results) == 0 or results.get("bool") is None + len(results) == 0 or results.get("bool", True) is None ): - raise asyncpg.NoDataFoundError + raise asyncpg.NoDataFoundError("No data found") end = timeit and time.perf_counter() diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index e1d286c9c..430a2e3c5 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -8,6 +8,7 @@ from agents_api.autogen.openapi_model import ( CreateAgentRequest, + CreateFileRequest, CreateSessionRequest, CreateUserRequest, ) @@ -24,7 +25,8 @@ # from agents_api.queries.execution.create_execution import create_execution # from agents_api.queries.execution.create_execution_transition import create_execution_transition # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup -# from agents_api.queries.files.create_file import create_file +from agents_api.queries.files.create_file import create_file + # from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session @@ -129,6 +131,23 @@ async def test_user(dsn=pg_dsn, developer=test_developer): return user +@fixture(scope="test") +async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Hello", + description="World", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + connection_pool=pool, + ) + + return file + + @fixture(scope="test") async def random_email(): return f"{"".join([random.choice(string.ascii_lowercase) for _ in range(10)])}@mail.com" diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 85d10f6ea..693f409a6 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -143,12 +143,21 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: delete agent sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully deleted.""" pool = await create_db_pool(dsn=dsn) + create_result = await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + ), + connection_pool=pool, + ) delete_result = await delete_agent( - agent_id=agent.id, developer_id=developer_id, connection_pool=pool + agent_id=create_result.id, developer_id=developer_id, connection_pool=pool ) assert delete_result is not None @@ -157,6 +166,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): with raises(Exception): await get_agent( developer_id=developer_id, - agent_id=agent.id, + agent_id=create_result.id, connection_pool=pool, ) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index f5b9d8d56..706185c7b 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,14 +3,25 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ +from uuid import UUID + from fastapi import HTTPException from uuid_extensions import uuid7 from ward import raises, test -from agents_api.autogen.openapi_model import CreateEntryRequest +from agents_api.autogen.openapi_model import ( + CreateEntryRequest, + Entry, + History, +) from agents_api.clients.pg import create_db_pool -from agents_api.queries.entries import create_entries, list_entries -from tests.fixtures import pg_dsn, test_developer # , test_session +from agents_api.queries.entries import ( + create_entries, + delete_entries, + get_history, + list_entries, +) +from tests.fixtures import pg_dsn, test_developer, test_developer_id, test_session MODEL = "gpt-4o-mini" @@ -52,126 +63,125 @@ async def _(dsn=pg_dsn, developer=test_developer): assert exc_info.raised.status_code == 404 -# @test("query: get entries") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the retrieval of entries from the database.""" - -# pool = await create_db_pool(dsn=dsn) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="api_request", -# content="test entry content", -# ) - -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="test entry content", -# source="internal", -# ) - -# await create_entries( -# developer_id=TEST_DEVELOPER_ID, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) - -# result = await list_entries( -# developer_id=TEST_DEVELOPER_ID, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) - - -# # Assert that only one entry is retrieved, matching the session_id. -# assert len(result) == 1 -# assert isinstance(result[0], Entry) -# assert result is not None - - -# @test("query: get history") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the retrieval of entry history from the database.""" - -# pool = await create_db_pool(dsn=dsn) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="api_request", -# content="test entry content", -# ) - -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="test entry content", -# source="internal", -# ) - -# await create_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) - -# result = await get_history( -# developer_id=developer_id, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) - -# # Assert that entries are retrieved and have valid IDs. -# assert result is not None -# assert isinstance(result, History) -# assert len(result.entries) > 0 -# assert result.entries[0].id - - -# @test("query: delete entries") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the deletion of entries from the database.""" - -# pool = await create_db_pool(dsn=dsn) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="api_request", -# content="test entry content", -# ) - -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="internal entry content", -# source="internal", -# ) - -# created_entries = await create_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) - -# entry_ids = [entry.id for entry in created_entries] - -# await delete_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], -# connection_pool=pool, -# ) - -# result = await list_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) - -# Assert that no entries are retrieved after deletion. -# assert all(id not in [entry.id for entry in result] for id in entry_ids) -# assert len(result) == 0 -# assert result is not None +@test("query: get entries") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the retrieval of entries from the database.""" + + pool = await create_db_pool(dsn=dsn) + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + source="api_request", + content="test entry content", + ) + + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + content="test entry content", + source="internal", + ) + + await create_entries( + developer_id=developer_id, + session_id=session.id, + data=[test_entry, internal_entry], + connection_pool=pool, + ) + + result = await list_entries( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + # Assert that only one entry is retrieved, matching the session_id. + assert len(result) == 1 + assert isinstance(result[0], Entry) + assert result is not None + + +@test("query: get history") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the retrieval of entry history from the database.""" + + pool = await create_db_pool(dsn=dsn) + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + source="api_request", + content="test entry content", + ) + + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + content="test entry content", + source="internal", + ) + + await create_entries( + developer_id=developer_id, + session_id=session.id, + data=[test_entry, internal_entry], + connection_pool=pool, + ) + + result = await get_history( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + # Assert that entries are retrieved and have valid IDs. + assert result is not None + assert isinstance(result, History) + assert len(result.entries) > 0 + assert result.entries[0].id + + +@test("query: delete entries") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test the deletion of entries from the database.""" + + pool = await create_db_pool(dsn=dsn) + test_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + source="api_request", + content="test entry content", + ) + + internal_entry = CreateEntryRequest.from_model_input( + model=MODEL, + role="user", + content="internal entry content", + source="internal", + ) + + created_entries = await create_entries( + developer_id=developer_id, + session_id=session.id, + data=[test_entry, internal_entry], + connection_pool=pool, + ) + + entry_ids = [entry.id for entry in created_entries] + + await delete_entries( + developer_id=developer_id, + session_id=session.id, + entry_ids=entry_ids, + connection_pool=pool, + ) + + result = await list_entries( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + # Assert that no entries are retrieved after deletion. + assert all(id not in [entry.id for entry in result] for id in entry_ids) + assert len(result) == 0 + assert result is not None diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 367fcccd4..92b52d733 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,57 +1,253 @@ # # Tests for entry queries -# from ward import test - -# from agents_api.autogen.openapi_model import CreateFileRequest -# from agents_api.queries.files.create_file import create_file -# from agents_api.queries.files.delete_file import delete_file -# from agents_api.queries.files.get_file import get_file -# from tests.fixtures import ( -# cozo_client, -# test_developer_id, -# test_file, -# ) - - -# @test("query: create file") -# def _(client=cozo_client, developer_id=test_developer_id): -# create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# client=client, -# ) - - -# @test("query: get file") -# def _(client=cozo_client, file=test_file, developer_id=test_developer_id): -# get_file( -# developer_id=developer_id, -# file_id=file.id, -# client=client, -# ) - - -# @test("query: delete file") -# def _(client=cozo_client, developer_id=test_developer_id): -# file = create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# client=client, -# ) - -# delete_file( -# developer_id=developer_id, -# file_id=file.id, -# client=client, -# ) +from fastapi import HTTPException +from uuid_extensions import uuid7 +from ward import raises, test + +from agents_api.autogen.openapi_model import CreateFileRequest +from agents_api.clients.pg import create_db_pool +from agents_api.queries.files.create_file import create_file +from agents_api.queries.files.delete_file import delete_file +from agents_api.queries.files.get_file import get_file +from agents_api.queries.files.list_files import list_files +from tests.fixtures import pg_dsn, test_agent, test_developer, test_file, test_user + + +@test("query: create file") +async def _(dsn=pg_dsn, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Hello", + description="World", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + connection_pool=pool, + ) + + +@test("query: create user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User File", + description="Test user file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert file.name == "User File" + + # Verify file appears in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + + +@test("query: create agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent File", + description="Test agent file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert file.name == "Agent File" + + # Verify file appears in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + + +@test("model: get file") +async def _(dsn=pg_dsn, file=test_file, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + file_test = await get_file( + developer_id=developer.id, + file_id=file.id, + connection_pool=pool, + ) + assert file_test.id == file.id + assert file_test.name == "Hello" + assert file_test.description == "World" + assert file_test.mime_type == "text/plain" + assert file_test.hash == file.hash + + +@test("query: list files") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) + files = await list_files( + developer_id=developer.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: list user files") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User List Test", + description="Test file for user listing", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: list agent files") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent List Test", + description="Test file for agent listing", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # List agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: delete user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User Delete Test", + description="Test file for user deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify file is no longer in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: delete agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent Delete Test", + description="Test file for agent deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify file is no longer in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: delete file") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) + + await delete_file( + developer_id=developer.id, + file_id=file.id, + connection_pool=pool, + ) diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 7926a391f..4673d6fc5 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -1,7 +1,7 @@ -""" -This module contains tests for SQL query generation functions in the sessions module. -Tests verify the SQL queries without actually executing them against a database. -""" +# """ +# This module contains tests for SQL query generation functions in the sessions module. +# Tests verify the SQL queries without actually executing them against a database. +# """ from uuid_extensions import uuid7 from ward import raises, test diff --git a/memory-store/migrations/000005_files.down.sql b/memory-store/migrations/000005_files.down.sql index 80bf6fecd..c582f7b67 100644 --- a/memory-store/migrations/000005_files.down.sql +++ b/memory-store/migrations/000005_files.down.sql @@ -1,14 +1,12 @@ BEGIN; --- Drop agent_files table and its dependencies -DROP TABLE IF EXISTS agent_files; - --- Drop user_files table and its dependencies -DROP TABLE IF EXISTS user_files; +-- Drop file_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_file_owner ON file_owners; +DROP FUNCTION IF EXISTS validate_file_owner(); +DROP TABLE IF EXISTS file_owners; -- Drop files table and its dependencies DROP TRIGGER IF EXISTS trg_files_updated_at ON files; - DROP TABLE IF EXISTS files; COMMIT; diff --git a/memory-store/migrations/000005_files.up.sql b/memory-store/migrations/000005_files.up.sql index ef4c22b3d..40a2cbccf 100644 --- a/memory-store/migrations/000005_files.up.sql +++ b/memory-store/migrations/000005_files.up.sql @@ -56,30 +56,48 @@ DO $$ BEGIN END IF; END $$; --- Create the user_files table -CREATE TABLE IF NOT EXISTS user_files ( +-- Create the file_owners table +CREATE TABLE IF NOT EXISTS file_owners ( developer_id UUID NOT NULL, - user_id UUID NOT NULL, file_id UUID NOT NULL, - CONSTRAINT pk_user_files PRIMARY KEY (developer_id, user_id, file_id), - CONSTRAINT fk_user_files_user FOREIGN KEY (developer_id, user_id) REFERENCES users (developer_id, user_id), - CONSTRAINT fk_user_files_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id) + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_file_owners PRIMARY KEY (developer_id, file_id), + CONSTRAINT fk_file_owners_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id), + CONSTRAINT ct_file_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); --- Create index if it doesn't exist -CREATE INDEX IF NOT EXISTS idx_user_files_user ON user_files (developer_id, user_id); +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_file_owners_owner + ON file_owners (developer_id, owner_type, owner_id); --- Create the agent_files table -CREATE TABLE IF NOT EXISTS agent_files ( - developer_id UUID NOT NULL, - agent_id UUID NOT NULL, - file_id UUID NOT NULL, - CONSTRAINT pk_agent_files PRIMARY KEY (developer_id, agent_id, file_id), - CONSTRAINT fk_agent_files_agent FOREIGN KEY (developer_id, agent_id) REFERENCES agents (developer_id, agent_id), - CONSTRAINT fk_agent_files_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id) -); +-- Create function to validate owner reference +CREATE OR REPLACE FUNCTION validate_file_owner() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; --- Create index if it doesn't exist -CREATE INDEX IF NOT EXISTS idx_agent_files_agent ON agent_files (developer_id, agent_id); +-- Create trigger for validation +CREATE TRIGGER trg_validate_file_owner +BEFORE INSERT OR UPDATE ON file_owners +FOR EACH ROW +EXECUTE FUNCTION validate_file_owner(); COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000006_docs.down.sql b/memory-store/migrations/000006_docs.down.sql index 468b1b483..ea67b0005 100644 --- a/memory-store/migrations/000006_docs.down.sql +++ b/memory-store/migrations/000006_docs.down.sql @@ -1,41 +1,27 @@ BEGIN; +-- Drop doc_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_doc_owner ON doc_owners; +DROP FUNCTION IF EXISTS validate_doc_owner(); +DROP TABLE IF EXISTS doc_owners; + +-- Drop docs table and its dependencies +DROP TRIGGER IF EXISTS trg_docs_search_tsv ON docs; +DROP TRIGGER IF EXISTS trg_docs_updated_at ON docs; +DROP FUNCTION IF EXISTS docs_update_search_tsv(); + -- Drop indexes DROP INDEX IF EXISTS idx_docs_content_trgm; - DROP INDEX IF EXISTS idx_docs_title_trgm; - DROP INDEX IF EXISTS idx_docs_search_tsv; - DROP INDEX IF EXISTS idx_docs_metadata; - -DROP INDEX IF EXISTS idx_agent_docs_agent; - -DROP INDEX IF EXISTS idx_user_docs_user; - DROP INDEX IF EXISTS idx_docs_developer; - DROP INDEX IF EXISTS idx_docs_id_sorted; --- Drop triggers -DROP TRIGGER IF EXISTS trg_docs_search_tsv ON docs; - -DROP TRIGGER IF EXISTS trg_docs_updated_at ON docs; - --- Drop the constraint that depends on is_valid_language function -ALTER TABLE IF EXISTS docs -DROP CONSTRAINT IF EXISTS ct_docs_valid_language; - --- Drop functions -DROP FUNCTION IF EXISTS docs_update_search_tsv (); - -DROP FUNCTION IF EXISTS is_valid_language (text); - --- Drop tables (in correct order due to foreign key constraints) -DROP TABLE IF EXISTS agent_docs; - -DROP TABLE IF EXISTS user_docs; - +-- Drop docs table DROP TABLE IF EXISTS docs; +-- Drop language validation function +DROP FUNCTION IF EXISTS is_valid_language(text); + COMMIT; diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql index 5b532bbef..193fae122 100644 --- a/memory-store/migrations/000006_docs.up.sql +++ b/memory-store/migrations/000006_docs.up.sql @@ -63,31 +63,51 @@ BEGIN END IF; END $$; --- Create the user_docs table -CREATE TABLE IF NOT EXISTS user_docs ( +-- Create the doc_owners table +CREATE TABLE IF NOT EXISTS doc_owners ( developer_id UUID NOT NULL, - user_id UUID NOT NULL, doc_id UUID NOT NULL, - CONSTRAINT pk_user_docs PRIMARY KEY (developer_id, user_id, doc_id), - CONSTRAINT fk_user_docs_user FOREIGN KEY (developer_id, user_id) REFERENCES users (developer_id, user_id), - CONSTRAINT fk_user_docs_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id) + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), + CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); --- Create the agent_docs table -CREATE TABLE IF NOT EXISTS agent_docs ( - developer_id UUID NOT NULL, - agent_id UUID NOT NULL, - doc_id UUID NOT NULL, - CONSTRAINT pk_agent_docs PRIMARY KEY (developer_id, agent_id, doc_id), - CONSTRAINT fk_agent_docs_agent FOREIGN KEY (developer_id, agent_id) REFERENCES agents (developer_id, agent_id), - CONSTRAINT fk_agent_docs_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id) -); +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_doc_owners_owner + ON doc_owners (developer_id, owner_type, owner_id); --- Create indexes if not exists -CREATE INDEX IF NOT EXISTS idx_user_docs_user ON user_docs (developer_id, user_id); +-- Create function to validate owner reference +CREATE OR REPLACE FUNCTION validate_doc_owner() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; -CREATE INDEX IF NOT EXISTS idx_agent_docs_agent ON agent_docs (developer_id, agent_id); +-- Create trigger for validation +CREATE TRIGGER trg_validate_doc_owner +BEFORE INSERT OR UPDATE ON doc_owners +FOR EACH ROW +EXECUTE FUNCTION validate_doc_owner(); +-- Create indexes if not exists CREATE INDEX IF NOT EXISTS idx_docs_metadata ON docs USING GIN (metadata); -- Enable necessary PostgreSQL extensions diff --git a/memory-store/migrations/000015_entries.up.sql b/memory-store/migrations/000015_entries.up.sql index c104091a2..73723a8bc 100644 --- a/memory-store/migrations/000015_entries.up.sql +++ b/memory-store/migrations/000015_entries.up.sql @@ -16,8 +16,9 @@ CREATE TABLE IF NOT EXISTS entries ( tool_calls JSONB[] NOT NULL DEFAULT '{}', model TEXT NOT NULL, token_count INTEGER DEFAULT NULL, + tokenizer TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + timestamp DOUBLE PRECISION NOT NULL, CONSTRAINT pk_entries PRIMARY KEY (session_id, entry_id, created_at) ); @@ -58,10 +59,10 @@ END $$; CREATE OR REPLACE FUNCTION optimized_update_token_count_after () RETURNS TRIGGER AS $$ DECLARE - token_count INTEGER; + calc_token_count INTEGER; BEGIN -- Compute token_count outside the UPDATE statement for clarity and potential optimization - token_count := cardinality( + calc_token_count := cardinality( ai.openai_tokenize( 'gpt-4o', -- FIXME: Use `NEW.model` array_to_string(NEW.content::TEXT[], ' ') @@ -69,9 +70,9 @@ BEGIN ); -- Perform the update only if token_count differs - IF token_count <> NEW.token_count THEN + IF calc_token_count <> NEW.token_count THEN UPDATE entries - SET token_count = token_count + SET token_count = calc_token_count WHERE entry_id = NEW.entry_id; END IF;