From 30b57633aafd9b5152fe88cfe104ba60c03fe6bc Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Thu, 19 Dec 2024 10:11:48 +0530 Subject: [PATCH 1/7] fix(memory-store): Change association structure of files and docs Signed-off-by: Diwank Singh Tomer --- memory-store/migrations/000005_files.down.sql | 10 ++-- memory-store/migrations/000005_files.up.sql | 56 ++++++++++++------- memory-store/migrations/000006_docs.down.sql | 42 +++++--------- memory-store/migrations/000006_docs.up.sql | 56 +++++++++++++------ 4 files changed, 93 insertions(+), 71 deletions(-) diff --git a/memory-store/migrations/000005_files.down.sql b/memory-store/migrations/000005_files.down.sql index 80bf6fecd..c582f7b67 100644 --- a/memory-store/migrations/000005_files.down.sql +++ b/memory-store/migrations/000005_files.down.sql @@ -1,14 +1,12 @@ BEGIN; --- Drop agent_files table and its dependencies -DROP TABLE IF EXISTS agent_files; - --- Drop user_files table and its dependencies -DROP TABLE IF EXISTS user_files; +-- Drop file_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_file_owner ON file_owners; +DROP FUNCTION IF EXISTS validate_file_owner(); +DROP TABLE IF EXISTS file_owners; -- Drop files table and its dependencies DROP TRIGGER IF EXISTS trg_files_updated_at ON files; - DROP TABLE IF EXISTS files; COMMIT; diff --git a/memory-store/migrations/000005_files.up.sql b/memory-store/migrations/000005_files.up.sql index ef4c22b3d..40a2cbccf 100644 --- a/memory-store/migrations/000005_files.up.sql +++ b/memory-store/migrations/000005_files.up.sql @@ -56,30 +56,48 @@ DO $$ BEGIN END IF; END $$; --- Create the user_files table -CREATE TABLE IF NOT EXISTS user_files ( +-- Create the file_owners table +CREATE TABLE IF NOT EXISTS file_owners ( developer_id UUID NOT NULL, - user_id UUID NOT NULL, file_id UUID NOT NULL, - CONSTRAINT pk_user_files PRIMARY KEY (developer_id, user_id, file_id), - CONSTRAINT fk_user_files_user FOREIGN KEY (developer_id, user_id) REFERENCES users (developer_id, user_id), - CONSTRAINT fk_user_files_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id) + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_file_owners PRIMARY KEY (developer_id, file_id), + CONSTRAINT fk_file_owners_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id), + CONSTRAINT ct_file_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); --- Create index if it doesn't exist -CREATE INDEX IF NOT EXISTS idx_user_files_user ON user_files (developer_id, user_id); +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_file_owners_owner + ON file_owners (developer_id, owner_type, owner_id); --- Create the agent_files table -CREATE TABLE IF NOT EXISTS agent_files ( - developer_id UUID NOT NULL, - agent_id UUID NOT NULL, - file_id UUID NOT NULL, - CONSTRAINT pk_agent_files PRIMARY KEY (developer_id, agent_id, file_id), - CONSTRAINT fk_agent_files_agent FOREIGN KEY (developer_id, agent_id) REFERENCES agents (developer_id, agent_id), - CONSTRAINT fk_agent_files_file FOREIGN KEY (developer_id, file_id) REFERENCES files (developer_id, file_id) -); +-- Create function to validate owner reference +CREATE OR REPLACE FUNCTION validate_file_owner() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; --- Create index if it doesn't exist -CREATE INDEX IF NOT EXISTS idx_agent_files_agent ON agent_files (developer_id, agent_id); +-- Create trigger for validation +CREATE TRIGGER trg_validate_file_owner +BEFORE INSERT OR UPDATE ON file_owners +FOR EACH ROW +EXECUTE FUNCTION validate_file_owner(); COMMIT; \ No newline at end of file diff --git a/memory-store/migrations/000006_docs.down.sql b/memory-store/migrations/000006_docs.down.sql index 468b1b483..ea67b0005 100644 --- a/memory-store/migrations/000006_docs.down.sql +++ b/memory-store/migrations/000006_docs.down.sql @@ -1,41 +1,27 @@ BEGIN; +-- Drop doc_owners table and its dependencies +DROP TRIGGER IF EXISTS trg_validate_doc_owner ON doc_owners; +DROP FUNCTION IF EXISTS validate_doc_owner(); +DROP TABLE IF EXISTS doc_owners; + +-- Drop docs table and its dependencies +DROP TRIGGER IF EXISTS trg_docs_search_tsv ON docs; +DROP TRIGGER IF EXISTS trg_docs_updated_at ON docs; +DROP FUNCTION IF EXISTS docs_update_search_tsv(); + -- Drop indexes DROP INDEX IF EXISTS idx_docs_content_trgm; - DROP INDEX IF EXISTS idx_docs_title_trgm; - DROP INDEX IF EXISTS idx_docs_search_tsv; - DROP INDEX IF EXISTS idx_docs_metadata; - -DROP INDEX IF EXISTS idx_agent_docs_agent; - -DROP INDEX IF EXISTS idx_user_docs_user; - DROP INDEX IF EXISTS idx_docs_developer; - DROP INDEX IF EXISTS idx_docs_id_sorted; --- Drop triggers -DROP TRIGGER IF EXISTS trg_docs_search_tsv ON docs; - -DROP TRIGGER IF EXISTS trg_docs_updated_at ON docs; - --- Drop the constraint that depends on is_valid_language function -ALTER TABLE IF EXISTS docs -DROP CONSTRAINT IF EXISTS ct_docs_valid_language; - --- Drop functions -DROP FUNCTION IF EXISTS docs_update_search_tsv (); - -DROP FUNCTION IF EXISTS is_valid_language (text); - --- Drop tables (in correct order due to foreign key constraints) -DROP TABLE IF EXISTS agent_docs; - -DROP TABLE IF EXISTS user_docs; - +-- Drop docs table DROP TABLE IF EXISTS docs; +-- Drop language validation function +DROP FUNCTION IF EXISTS is_valid_language(text); + COMMIT; diff --git a/memory-store/migrations/000006_docs.up.sql b/memory-store/migrations/000006_docs.up.sql index 5b532bbef..193fae122 100644 --- a/memory-store/migrations/000006_docs.up.sql +++ b/memory-store/migrations/000006_docs.up.sql @@ -63,31 +63,51 @@ BEGIN END IF; END $$; --- Create the user_docs table -CREATE TABLE IF NOT EXISTS user_docs ( +-- Create the doc_owners table +CREATE TABLE IF NOT EXISTS doc_owners ( developer_id UUID NOT NULL, - user_id UUID NOT NULL, doc_id UUID NOT NULL, - CONSTRAINT pk_user_docs PRIMARY KEY (developer_id, user_id, doc_id), - CONSTRAINT fk_user_docs_user FOREIGN KEY (developer_id, user_id) REFERENCES users (developer_id, user_id), - CONSTRAINT fk_user_docs_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id) + owner_type TEXT NOT NULL, -- 'user' or 'agent' + owner_id UUID NOT NULL, + CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id), + CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id), + CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent')) ); --- Create the agent_docs table -CREATE TABLE IF NOT EXISTS agent_docs ( - developer_id UUID NOT NULL, - agent_id UUID NOT NULL, - doc_id UUID NOT NULL, - CONSTRAINT pk_agent_docs PRIMARY KEY (developer_id, agent_id, doc_id), - CONSTRAINT fk_agent_docs_agent FOREIGN KEY (developer_id, agent_id) REFERENCES agents (developer_id, agent_id), - CONSTRAINT fk_agent_docs_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id) -); +-- Create indexes +CREATE INDEX IF NOT EXISTS idx_doc_owners_owner + ON doc_owners (developer_id, owner_type, owner_id); --- Create indexes if not exists -CREATE INDEX IF NOT EXISTS idx_user_docs_user ON user_docs (developer_id, user_id); +-- Create function to validate owner reference +CREATE OR REPLACE FUNCTION validate_doc_owner() +RETURNS TRIGGER AS $$ +BEGIN + IF NEW.owner_type = 'user' THEN + IF NOT EXISTS ( + SELECT 1 FROM users + WHERE developer_id = NEW.developer_id AND user_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid user reference'; + END IF; + ELSIF NEW.owner_type = 'agent' THEN + IF NOT EXISTS ( + SELECT 1 FROM agents + WHERE developer_id = NEW.developer_id AND agent_id = NEW.owner_id + ) THEN + RAISE EXCEPTION 'Invalid agent reference'; + END IF; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; -CREATE INDEX IF NOT EXISTS idx_agent_docs_agent ON agent_docs (developer_id, agent_id); +-- Create trigger for validation +CREATE TRIGGER trg_validate_doc_owner +BEFORE INSERT OR UPDATE ON doc_owners +FOR EACH ROW +EXECUTE FUNCTION validate_doc_owner(); +-- Create indexes if not exists CREATE INDEX IF NOT EXISTS idx_docs_metadata ON docs USING GIN (metadata); -- Enable necessary PostgreSQL extensions From 116edf8d3c57558ea57409521996f018b163712a Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Wed, 18 Dec 2024 23:43:07 -0500 Subject: [PATCH 2/7] wip(agents-api): Add file sql queries --- .../agents_api/queries/files/__init__.py | 21 +++ .../agents_api/queries/files/create_file.py | 150 ++++++++++++++++ .../agents_api/queries/files/delete_file.py | 118 +++++++++++++ .../agents_api/queries/files/get_file.py | 69 ++++++++ .../agents_api/queries/files/list_files.py | 161 ++++++++++++++++++ agents-api/tests/test_files_queries.py | 73 +++++--- 6 files changed, 567 insertions(+), 25 deletions(-) create mode 100644 agents-api/agents_api/queries/files/__init__.py create mode 100644 agents-api/agents_api/queries/files/create_file.py create mode 100644 agents-api/agents_api/queries/files/delete_file.py create mode 100644 agents-api/agents_api/queries/files/get_file.py create mode 100644 agents-api/agents_api/queries/files/list_files.py diff --git a/agents-api/agents_api/queries/files/__init__.py b/agents-api/agents_api/queries/files/__init__.py new file mode 100644 index 000000000..1da09114a --- /dev/null +++ b/agents-api/agents_api/queries/files/__init__.py @@ -0,0 +1,21 @@ +""" +The `files` module within the `queries` package provides SQL query functions for managing files +in the PostgreSQL database. This includes operations for: + +- Creating new files +- Retrieving file details +- Listing files with filtering and pagination +- Deleting files and their associations +""" + +from .create_file import create_file +from .delete_file import delete_file +from .get_file import get_file +from .list_files import list_files + +__all__ = [ + "create_file", + "delete_file", + "get_file", + "list_files" +] \ No newline at end of file diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py new file mode 100644 index 000000000..77e065433 --- /dev/null +++ b/agents-api/agents_api/queries/files/create_file.py @@ -0,0 +1,150 @@ +""" +This module contains the functionality for creating files in the PostgreSQL database. +It includes functions to construct and execute SQL queries for inserting new file records. +""" + +from typing import Any, Literal +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one +from uuid_extensions import uuid7 +import asyncpg +from fastapi import HTTPException +import base64 +import hashlib + +from ...autogen.openapi_model import CreateFileRequest, File +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass + +# Create file +file_query = parse_one(""" +INSERT INTO files ( + developer_id, + file_id, + name, + description, + mime_type, + size, + hash, +) +VALUES ( + $1, -- developer_id + $2, -- file_id + $3, -- name + $4, -- description + $5, -- mime_type + $6, -- size + $7, -- hash +) +RETURNING *; +""").sql(pretty=True) + +# Create user file association +user_file_query = parse_one(""" +INSERT INTO user_files ( + developer_id, + user_id, + file_id +) +VALUES ($1, $2, $3) +ON CONFLICT (developer_id, user_id, file_id) DO NOTHING; -- Uses primary key index +""").sql(pretty=True) + +# Create agent file association +agent_file_query = parse_one(""" +INSERT INTO agent_files ( + developer_id, + agent_id, + file_id +) +VALUES ($1, $2, $3) +ON CONFLICT (developer_id, agent_id, file_id) DO NOTHING; -- Uses primary key index +""").sql(pretty=True) + +# Add error handling decorator +# @rewrap_exceptions( +# { +# asyncpg.UniqueViolationError: partialclass( +# HTTPException, +# status_code=409, +# detail="A file with this name already exists for this developer", +# ), +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified owner does not exist", +# ), +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="The specified developer does not exist", +# ), +# } +# ) +@wrap_in_class( + File, + one=True, + transform=lambda d: { + **d, + "id": d["file_id"], + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", + }, +) +@increase_counter("create_file") +@pg_query +@beartype +async def create_file( + *, + developer_id: UUID, + file_id: UUID | None = None, + data: CreateFileRequest, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> list[tuple[str, list] | tuple[str, list, str]]: + """ + Constructs and executes SQL queries to create a new file and optionally associate it with an owner. + + Parameters: + developer_id (UUID): The unique identifier for the developer. + file_id (UUID | None): Optional unique identifier for the file. + data (CreateFileRequest): The file data to insert. + owner_type (Literal["user", "agent"] | None): Optional type of owner + owner_id (UUID | None): Optional ID of the owner + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: List of SQL queries, their parameters, and fetch type + """ + file_id = file_id or uuid7() + + # Calculate size and hash + content_bytes = base64.b64decode(data.content) + data.size = len(content_bytes) + data.hash = hashlib.sha256(content_bytes).digest() + + # Base file parameters + file_params = [ + developer_id, + file_id, + data.name, + data.description, + data.mime_type, + data.size, + data.hash, + ] + + queries = [] + + # Create the file + queries.append((file_query, file_params)) + + # Create the association only if both owner_type and owner_id are provided + if owner_type and owner_id: + assoc_params = [developer_id, owner_id, file_id] + if owner_type == "user": + queries.append((user_file_query, assoc_params)) + else: # agent + queries.append((agent_file_query, assoc_params)) + + return queries \ No newline at end of file diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py new file mode 100644 index 000000000..d37e6f3e8 --- /dev/null +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -0,0 +1,118 @@ +""" +This module contains the functionality for deleting files from the PostgreSQL database. +It constructs and executes SQL queries to remove file records and associated data. +""" + +from typing import Literal +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one +import asyncpg +from fastapi import HTTPException + +from ...autogen.openapi_model import ResourceDeletedResponse +from ...common.utils.datetime import utcnow +from ...metrics.counters import increase_counter +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass + +# Simple query to delete file (when no associations exist) +delete_file_query = parse_one(""" +DELETE FROM files +WHERE developer_id = $1 +AND file_id = $2 +AND NOT EXISTS ( + SELECT 1 + FROM user_files uf + WHERE uf.file_id = $2 + LIMIT 1 +) +AND NOT EXISTS ( + SELECT 1 + FROM agent_files af + WHERE af.file_id = $2 + LIMIT 1 +) +RETURNING file_id; +""").sql(pretty=True) + +# Query to delete owner's association +delete_user_assoc_query = parse_one(""" +DELETE FROM user_files +WHERE developer_id = $1 +AND file_id = $2 +AND user_id = $3 +RETURNING file_id; +""").sql(pretty=True) + +delete_agent_assoc_query = parse_one(""" +DELETE FROM agent_files +WHERE developer_id = $1 +AND file_id = $2 +AND agent_id = $3 +RETURNING file_id; +""").sql(pretty=True) + + +# @rewrap_exceptions( +# { +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="File not found", +# ), +# } +# ) +@wrap_in_class( + ResourceDeletedResponse, + one=True, + transform=lambda d: { + "id": d["file_id"], + "deleted_at": utcnow(), + "jobs": [], + }, +) +@increase_counter("delete_file") +@pg_query +@beartype +async def delete_file( + *, + file_id: UUID, + developer_id: UUID, + owner_id: UUID | None = None, + owner_type: Literal["user", "agent"] | None = None, +) -> list[tuple[str, list] | tuple[str, list, str]]: + """ + Deletes a file and/or its association using simple, efficient queries. + + If owner details provided: + 1. Deletes the owner's association + 2. Checks for remaining associations + 3. Deletes file if no associations remain + If no owner details: + - Deletes file only if it has no associations + + Args: + file_id (UUID): The UUID of the file to be deleted. + developer_id (UUID): The UUID of the developer owning the file. + owner_id (UUID | None): Optional owner ID to verify ownership + owner_type (str | None): Optional owner type to verify ownership + + Returns: + list[tuple[str, list] | tuple[str, list, str]]: List of SQL queries, their parameters, and fetch type + """ + queries = [] + + if owner_id and owner_type: + # Delete specific association + assoc_params = [developer_id, file_id, owner_id] + assoc_query = delete_user_assoc_query if owner_type == "user" else delete_agent_assoc_query + queries.append((assoc_query, assoc_params)) + + # If no associations, delete file + queries.append((delete_file_query, [developer_id, file_id])) + else: + # Try to delete file if it has no associations + queries.append((delete_file_query, [developer_id, file_id])) + + return queries diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py new file mode 100644 index 000000000..3143b8ff0 --- /dev/null +++ b/agents-api/agents_api/queries/files/get_file.py @@ -0,0 +1,69 @@ +""" +This module contains the functionality for retrieving a single file from the PostgreSQL database. +It constructs and executes SQL queries to fetch file details based on file ID and developer ID. +""" + +from uuid import UUID + +from beartype import beartype +from sqlglot import parse_one +from fastapi import HTTPException +import asyncpg + +from ...autogen.openapi_model import File +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass + +# Define the raw SQL query +file_query = parse_one(""" +SELECT + file_id, -- Only select needed columns + developer_id, + name, + description, + mime_type, + size, + hash, + created_at, + updated_at +FROM files +WHERE developer_id = $1 -- Order matches composite index (developer_id, file_id) + AND file_id = $2 -- Using both parts of the index +LIMIT 1; -- Early termination once found +""").sql(pretty=True) + +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="File not found", + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="Developer not found", + ), + } +) +@wrap_in_class(File, one=True, transform=lambda d: {"id": d["file_id"], **d}) +@pg_query +@beartype +async def get_file(*, file_id: UUID, developer_id: UUID) -> tuple[str, list]: + """ + Constructs the SQL query to retrieve a file's details. + Uses composite index on (developer_id, file_id) for efficient lookup. + + Args: + file_id (UUID): The UUID of the file to retrieve. + developer_id (UUID): The UUID of the developer owning the file. + + Returns: + tuple[str, list]: A tuple containing the SQL query and its parameters. + + Raises: + HTTPException: If file or developer not found (404) + """ + return ( + file_query, + [developer_id, file_id], # Order matches index columns + ) \ No newline at end of file diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py new file mode 100644 index 000000000..a01f74214 --- /dev/null +++ b/agents-api/agents_api/queries/files/list_files.py @@ -0,0 +1,161 @@ +""" +This module contains the functionality for listing files from the PostgreSQL database. +It constructs and executes SQL queries to fetch a list of files based on developer ID with pagination. +""" + +from typing import Any, Literal +from uuid import UUID +import asyncpg +from beartype import beartype +from fastapi import HTTPException +from sqlglot import parse_one + +from ...autogen.openapi_model import File +from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass + +# Query to list all files for a developer (uses developer_id index) +developer_files_query = parse_one(""" +SELECT + file_id, + developer_id, + name, + description, + mime_type, + size, + hash, + created_at, + updated_at +FROM files +WHERE developer_id = $1 +ORDER BY + CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at + END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + +# Query to list files for a specific user (uses composite indexes) +user_files_query = parse_one(""" +SELECT + f.file_id, + f.developer_id, + f.name, + f.description, + f.mime_type, + f.size, + f.hash, + f.created_at, + f.updated_at +FROM user_files uf +JOIN files f USING (developer_id, file_id) +WHERE uf.developer_id = $1 +AND uf.user_id = $6 +ORDER BY + CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN f.created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN f.created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN f.updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN f.updated_at + END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + +# Query to list files for a specific agent (uses composite indexes) +agent_files_query = parse_one(""" +SELECT + f.file_id, + f.developer_id, + f.name, + f.description, + f.mime_type, + f.size, + f.hash, + f.created_at, + f.updated_at +FROM agent_files af +JOIN files f USING (developer_id, file_id) +WHERE af.developer_id = $1 +AND af.agent_id = $6 +ORDER BY + CASE + WHEN $4 = 'created_at' AND $5 = 'asc' THEN f.created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN f.created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN f.updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN f.updated_at + END DESC NULLS LAST +LIMIT $2 +OFFSET $3; +""").sql(pretty=True) + +@wrap_in_class( + File, + one=True, + transform=lambda d: { + **d, + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", + }, +) +@pg_query +@beartype +async def list_files( + *, + developer_id: UUID, + owner_id: UUID | None = None, + owner_type: Literal["user", "agent"] | None = None, + limit: int = 100, + offset: int = 0, + sort_by: Literal["created_at", "updated_at"] = "created_at", + direction: Literal["asc", "desc"] = "desc", +) -> tuple[str, list]: + """ + Lists files with optimized queries for two cases: + 1. Owner specified: Returns files associated with that owner + 2. No owner: Returns all files for the developer + + Args: + developer_id: UUID of the developer + owner_id: Optional UUID of the owner (user or agent) + owner_type: Optional type of owner ("user" or "agent") + limit: Maximum number of records to return (1-100) + offset: Number of records to skip + sort_by: Field to sort by + direction: Sort direction ('asc' or 'desc') + + Returns: + Tuple of (query, params) + + Raises: + HTTPException: If parameters are invalid + """ + # Validate parameters + if direction.lower() not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Invalid sort direction") + + if limit > 100 or limit < 1: + raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") + + if offset < 0: + raise HTTPException(status_code=400, detail="Offset must be non-negative") + + # Base parameters used in all queries + params = [ + developer_id, + limit, + offset, + sort_by, + direction, + ] + + # Choose appropriate query based on owner details + if owner_id and owner_type: + params.append(owner_id) # Add owner_id as $6 + query = user_files_query if owner_type == "user" else agent_files_query + else: + query = developer_files_query + + return (query, params) diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 367fcccd4..5565d4059 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,22 +1,36 @@ # # Tests for entry queries -# from ward import test - -# from agents_api.autogen.openapi_model import CreateFileRequest -# from agents_api.queries.files.create_file import create_file -# from agents_api.queries.files.delete_file import delete_file -# from agents_api.queries.files.get_file import get_file -# from tests.fixtures import ( -# cozo_client, -# test_developer_id, -# test_file, -# ) - - -# @test("query: create file") -# def _(client=cozo_client, developer_id=test_developer_id): -# create_file( +from uuid_extensions import uuid7 +from ward import raises, test +from fastapi import HTTPException +from agents_api.autogen.openapi_model import CreateFileRequest +from agents_api.queries.files.create_file import create_file +from agents_api.queries.files.delete_file import delete_file +from agents_api.queries.files.get_file import get_file +from tests.fixtures import pg_dsn, test_agent, test_developer_id +from agents_api.clients.pg import create_db_pool + + +@test("query: create file") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + pool = await create_db_pool(dsn=dsn) + await create_file( + developer_id=developer_id, + data=CreateFileRequest( + name="Hello", + description="World", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + connection_pool=pool, + ) + + +# @test("query: get file") +# async def _(dsn=pg_dsn, developer_id=test_developer_id): +# pool = await create_db_pool(dsn=dsn) +# file = create_file( # developer_id=developer_id, # data=CreateFileRequest( # name="Hello", @@ -24,21 +38,20 @@ # mime_type="text/plain", # content="eyJzYW1wbGUiOiAidGVzdCJ9", # ), -# client=client, +# connection_pool=pool, # ) - -# @test("query: get file") -# def _(client=cozo_client, file=test_file, developer_id=test_developer_id): -# get_file( +# get_file_result = get_file( # developer_id=developer_id, # file_id=file.id, -# client=client, +# connection_pool=pool, # ) +# assert file == get_file_result # @test("query: delete file") -# def _(client=cozo_client, developer_id=test_developer_id): +# async def _(dsn=pg_dsn, developer_id=test_developer_id): +# pool = await create_db_pool(dsn=dsn) # file = create_file( # developer_id=developer_id, # data=CreateFileRequest( @@ -47,11 +60,21 @@ # mime_type="text/plain", # content="eyJzYW1wbGUiOiAidGVzdCJ9", # ), -# client=client, +# connection_pool=pool, # ) # delete_file( # developer_id=developer_id, # file_id=file.id, -# client=client, +# connection_pool=pool, # ) + +# with raises(HTTPException) as e: +# get_file( +# developer_id=developer_id, +# file_id=file.id, +# connection_pool=pool, +# ) + +# assert e.value.status_code == 404 +# assert e.value.detail == "The specified file does not exist" \ No newline at end of file From 47c3fc936349ebbc8b09850da14460d3fa6d2e2d Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Thu, 19 Dec 2024 04:44:24 +0000 Subject: [PATCH 3/7] refactor: Lint agents-api (CI) --- .../agents_api/queries/files/__init__.py | 7 +----- .../agents_api/queries/files/create_file.py | 13 ++++++----- .../agents_api/queries/files/delete_file.py | 22 +++++++++++-------- .../agents_api/queries/files/get_file.py | 9 ++++---- .../agents_api/queries/files/list_files.py | 10 +++++---- agents-api/tests/test_files_queries.py | 7 +++--- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/agents-api/agents_api/queries/files/__init__.py b/agents-api/agents_api/queries/files/__init__.py index 1da09114a..99670a8fc 100644 --- a/agents-api/agents_api/queries/files/__init__.py +++ b/agents-api/agents_api/queries/files/__init__.py @@ -13,9 +13,4 @@ from .get_file import get_file from .list_files import list_files -__all__ = [ - "create_file", - "delete_file", - "get_file", - "list_files" -] \ No newline at end of file +__all__ = ["create_file", "delete_file", "get_file", "list_files"] diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 77e065433..64527bc31 100644 --- a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -3,20 +3,20 @@ It includes functions to construct and execute SQL queries for inserting new file records. """ +import base64 +import hashlib from typing import Any, Literal from uuid import UUID +import asyncpg from beartype import beartype +from fastapi import HTTPException from sqlglot import parse_one from uuid_extensions import uuid7 -import asyncpg -from fastapi import HTTPException -import base64 -import hashlib from ...autogen.openapi_model import CreateFileRequest, File from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Create file file_query = parse_one(""" @@ -63,6 +63,7 @@ ON CONFLICT (developer_id, agent_id, file_id) DO NOTHING; -- Uses primary key index """).sql(pretty=True) + # Add error handling decorator # @rewrap_exceptions( # { @@ -147,4 +148,4 @@ async def create_file( else: # agent queries.append((agent_file_query, assoc_params)) - return queries \ No newline at end of file + return queries diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py index d37e6f3e8..99f57f5e0 100644 --- a/agents-api/agents_api/queries/files/delete_file.py +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -6,15 +6,15 @@ from typing import Literal from uuid import UUID -from beartype import beartype -from sqlglot import parse_one import asyncpg +from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Simple query to delete file (when no associations exist) delete_file_query = parse_one(""" @@ -67,7 +67,7 @@ ResourceDeletedResponse, one=True, transform=lambda d: { - "id": d["file_id"], + "id": d["file_id"], "deleted_at": utcnow(), "jobs": [], }, @@ -76,15 +76,15 @@ @pg_query @beartype async def delete_file( - *, - file_id: UUID, + *, + file_id: UUID, developer_id: UUID, owner_id: UUID | None = None, owner_type: Literal["user", "agent"] | None = None, ) -> list[tuple[str, list] | tuple[str, list, str]]: """ Deletes a file and/or its association using simple, efficient queries. - + If owner details provided: 1. Deletes the owner's association 2. Checks for remaining associations @@ -106,9 +106,13 @@ async def delete_file( if owner_id and owner_type: # Delete specific association assoc_params = [developer_id, file_id, owner_id] - assoc_query = delete_user_assoc_query if owner_type == "user" else delete_agent_assoc_query + assoc_query = ( + delete_user_assoc_query + if owner_type == "user" + else delete_agent_assoc_query + ) queries.append((assoc_query, assoc_params)) - + # If no associations, delete file queries.append((delete_file_query, [developer_id, file_id])) else: diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 3143b8ff0..8f04f8029 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -5,13 +5,13 @@ from uuid import UUID +import asyncpg from beartype import beartype -from sqlglot import parse_one from fastapi import HTTPException -import asyncpg +from sqlglot import parse_one from ...autogen.openapi_model import File -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Define the raw SQL query file_query = parse_one(""" @@ -31,6 +31,7 @@ LIMIT 1; -- Early termination once found """).sql(pretty=True) + @rewrap_exceptions( { asyncpg.NoDataFoundError: partialclass( @@ -66,4 +67,4 @@ async def get_file(*, file_id: UUID, developer_id: UUID) -> tuple[str, list]: return ( file_query, [developer_id, file_id], # Order matches index columns - ) \ No newline at end of file + ) diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index a01f74214..e6f65d88d 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -5,13 +5,14 @@ from typing import Any, Literal from uuid import UUID + import asyncpg from beartype import beartype from fastapi import HTTPException from sqlglot import parse_one from ...autogen.openapi_model import File -from ..utils import pg_query, rewrap_exceptions, wrap_in_class, partialclass +from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class # Query to list all files for a developer (uses developer_id index) developer_files_query = parse_one(""" @@ -92,8 +93,9 @@ OFFSET $3; """).sql(pretty=True) + @wrap_in_class( - File, + File, one=True, transform=lambda d: { **d, @@ -135,10 +137,10 @@ async def list_files( # Validate parameters if direction.lower() not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Invalid sort direction") - + if limit > 100 or limit < 1: raise HTTPException(status_code=400, detail="Limit must be between 1 and 100") - + if offset < 0: raise HTTPException(status_code=400, detail="Offset must be non-negative") diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 5565d4059..02ad888f5 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -1,15 +1,16 @@ # # Tests for entry queries +from fastapi import HTTPException from uuid_extensions import uuid7 from ward import raises, test -from fastapi import HTTPException + from agents_api.autogen.openapi_model import CreateFileRequest +from agents_api.clients.pg import create_db_pool from agents_api.queries.files.create_file import create_file from agents_api.queries.files.delete_file import delete_file from agents_api.queries.files.get_file import get_file from tests.fixtures import pg_dsn, test_agent, test_developer_id -from agents_api.clients.pg import create_db_pool @test("query: create file") @@ -77,4 +78,4 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): # ) # assert e.value.status_code == 404 -# assert e.value.detail == "The specified file does not exist" \ No newline at end of file +# assert e.value.detail == "The specified file does not exist" From cc2a5bf8aeda56016b647148a7155f4361f8f51f Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Thu, 19 Dec 2024 01:55:25 -0500 Subject: [PATCH 4/7] chore: bug fixes for file queries + added tests --- .../agents_api/queries/agents/delete_agent.py | 39 +- .../agents_api/queries/files/create_file.py | 58 +- .../agents_api/queries/files/delete_file.py | 113 ++-- .../agents_api/queries/files/get_file.py | 81 +-- .../agents_api/queries/files/list_files.py | 83 +-- .../agents_api/queries/users/delete_user.py | 35 +- agents-api/agents_api/queries/utils.py | 5 - agents-api/tests/fixtures.py | 20 +- agents-api/tests/test_agent_queries.py | 15 +- agents-api/tests/test_entry_queries.py | 318 +++++------ agents-api/tests/test_files_queries.py | 282 ++++++++-- agents-api/tests/test_session_queries.py | 522 +++++++++--------- 12 files changed, 868 insertions(+), 703 deletions(-) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index 6738374db..a957ab2c5 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -19,19 +19,39 @@ # Define the raw SQL query agent_query = parse_one(""" -WITH deleted_docs AS ( +WITH deleted_file_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_doc_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 +), +deleted_files AS ( + DELETE FROM files + WHERE developer_id = $1 + AND file_id IN ( + SELECT file_id FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 + ) +), +deleted_docs AS ( DELETE FROM docs WHERE developer_id = $1 AND doc_id IN ( - SELECT ad.doc_id - FROM agent_docs ad - WHERE ad.agent_id = $2 - AND ad.developer_id = $1 + SELECT doc_id FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'agent' + AND owner_id = $2 ) -), deleted_agent_docs AS ( - DELETE FROM agent_docs - WHERE agent_id = $2 AND developer_id = $1 -), deleted_tools AS ( +), +deleted_tools AS ( DELETE FROM tools WHERE agent_id = $2 AND developer_id = $1 ) @@ -40,7 +60,6 @@ RETURNING developer_id, agent_id; """).sql(pretty=True) - # @rewrap_exceptions( # @rewrap_exceptions( # { diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 64527bc31..8438978e6 100644 --- a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -27,7 +27,7 @@ description, mime_type, size, - hash, + hash ) VALUES ( $1, -- developer_id @@ -36,34 +36,28 @@ $4, -- description $5, -- mime_type $6, -- size - $7, -- hash + $7 -- hash ) RETURNING *; """).sql(pretty=True) -# Create user file association -user_file_query = parse_one(""" -INSERT INTO user_files ( - developer_id, - user_id, - file_id -) -VALUES ($1, $2, $3) -ON CONFLICT (developer_id, user_id, file_id) DO NOTHING; -- Uses primary key index -""").sql(pretty=True) - -# Create agent file association -agent_file_query = parse_one(""" -INSERT INTO agent_files ( - developer_id, - agent_id, - file_id +# Replace both user_file and agent_file queries with a single file_owner query +file_owner_query = parse_one(""" +WITH inserted_owner AS ( + INSERT INTO file_owners ( + developer_id, + file_id, + owner_type, + owner_id + ) + VALUES ($1, $2, $3, $4) + RETURNING file_id ) -VALUES ($1, $2, $3) -ON CONFLICT (developer_id, agent_id, file_id) DO NOTHING; -- Uses primary key index +SELECT f.* +FROM inserted_owner io +JOIN files f ON f.file_id = io.file_id; """).sql(pretty=True) - # Add error handling decorator # @rewrap_exceptions( # { @@ -90,6 +84,7 @@ transform=lambda d: { **d, "id": d["file_id"], + "hash": d["hash"].hex(), "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", }, ) @@ -121,8 +116,8 @@ async def create_file( # Calculate size and hash content_bytes = base64.b64decode(data.content) - data.size = len(content_bytes) - data.hash = hashlib.sha256(content_bytes).digest() + size = len(content_bytes) + hash_bytes = hashlib.sha256(content_bytes).digest() # Base file parameters file_params = [ @@ -131,21 +126,18 @@ async def create_file( data.name, data.description, data.mime_type, - data.size, - data.hash, + size, + hash_bytes, ] queries = [] - # Create the file + # Create the file first queries.append((file_query, file_params)) - # Create the association only if both owner_type and owner_id are provided + # Then create the association if owner info provided if owner_type and owner_id: - assoc_params = [developer_id, owner_id, file_id] - if owner_type == "user": - queries.append((user_file_query, assoc_params)) - else: # agent - queries.append((agent_file_query, assoc_params)) + assoc_params = [developer_id, file_id, owner_type, owner_id] + queries.append((file_owner_query, assoc_params)) return queries diff --git a/agents-api/agents_api/queries/files/delete_file.py b/agents-api/agents_api/queries/files/delete_file.py index 99f57f5e0..31cb43404 100644 --- a/agents-api/agents_api/queries/files/delete_file.py +++ b/agents-api/agents_api/queries/files/delete_file.py @@ -16,53 +16,40 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# Simple query to delete file (when no associations exist) +# Delete file query with ownership check delete_file_query = parse_one(""" +WITH deleted_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND file_id = $2 + AND ( + ($3::text IS NULL AND $4::uuid IS NULL) OR + (owner_type = $3 AND owner_id = $4) + ) +) DELETE FROM files WHERE developer_id = $1 AND file_id = $2 -AND NOT EXISTS ( - SELECT 1 - FROM user_files uf - WHERE uf.file_id = $2 - LIMIT 1 -) -AND NOT EXISTS ( - SELECT 1 - FROM agent_files af - WHERE af.file_id = $2 - LIMIT 1 -) -RETURNING file_id; -""").sql(pretty=True) - -# Query to delete owner's association -delete_user_assoc_query = parse_one(""" -DELETE FROM user_files -WHERE developer_id = $1 -AND file_id = $2 -AND user_id = $3 -RETURNING file_id; -""").sql(pretty=True) - -delete_agent_assoc_query = parse_one(""" -DELETE FROM agent_files -WHERE developer_id = $1 -AND file_id = $2 -AND agent_id = $3 +AND ($3::text IS NULL OR EXISTS ( + SELECT 1 FROM file_owners + WHERE developer_id = $1 + AND file_id = $2 + AND owner_type = $3 + AND owner_id = $4 +)) RETURNING file_id; """).sql(pretty=True) -# @rewrap_exceptions( -# { -# asyncpg.NoDataFoundError: partialclass( -# HTTPException, -# status_code=404, -# detail="File not found", -# ), -# } -# ) +@rewrap_exceptions( + { + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="File not found", + ), + } +) @wrap_in_class( ResourceDeletedResponse, one=True, @@ -77,46 +64,24 @@ @beartype async def delete_file( *, - file_id: UUID, developer_id: UUID, - owner_id: UUID | None = None, + file_id: UUID, owner_type: Literal["user", "agent"] | None = None, -) -> list[tuple[str, list] | tuple[str, list, str]]: + owner_id: UUID | None = None, +) -> tuple[str, list]: """ - Deletes a file and/or its association using simple, efficient queries. - - If owner details provided: - 1. Deletes the owner's association - 2. Checks for remaining associations - 3. Deletes file if no associations remain - If no owner details: - - Deletes file only if it has no associations + Deletes a file and its ownership records. Args: - file_id (UUID): The UUID of the file to be deleted. - developer_id (UUID): The UUID of the developer owning the file. - owner_id (UUID | None): Optional owner ID to verify ownership - owner_type (str | None): Optional owner type to verify ownership + developer_id: The developer's UUID + file_id: The file's UUID + owner_type: Optional type of owner ("user" or "agent") + owner_id: Optional UUID of the owner Returns: - list[tuple[str, list] | tuple[str, list, str]]: List of SQL queries, their parameters, and fetch type + tuple[str, list]: SQL query and parameters """ - queries = [] - - if owner_id and owner_type: - # Delete specific association - assoc_params = [developer_id, file_id, owner_id] - assoc_query = ( - delete_user_assoc_query - if owner_type == "user" - else delete_agent_assoc_query - ) - queries.append((assoc_query, assoc_params)) - - # If no associations, delete file - queries.append((delete_file_query, [developer_id, file_id])) - else: - # Try to delete file if it has no associations - queries.append((delete_file_query, [developer_id, file_id])) - - return queries + return ( + delete_file_query, + [developer_id, file_id, owner_type, owner_id], + ) diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index 8f04f8029..ace417d5d 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -4,6 +4,7 @@ """ from uuid import UUID +from typing import Literal import asyncpg from beartype import beartype @@ -15,56 +16,66 @@ # Define the raw SQL query file_query = parse_one(""" -SELECT - file_id, -- Only select needed columns - developer_id, - name, - description, - mime_type, - size, - hash, - created_at, - updated_at -FROM files -WHERE developer_id = $1 -- Order matches composite index (developer_id, file_id) - AND file_id = $2 -- Using both parts of the index -LIMIT 1; -- Early termination once found +SELECT f.* +FROM files f +LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE f.developer_id = $1 +AND f.file_id = $2 +AND ( + ($3::text IS NULL AND $4::uuid IS NULL) OR + (fo.owner_type = $3 AND fo.owner_id = $4) +) +LIMIT 1; """).sql(pretty=True) -@rewrap_exceptions( - { - asyncpg.NoDataFoundError: partialclass( - HTTPException, - status_code=404, - detail="File not found", - ), - asyncpg.ForeignKeyViolationError: partialclass( - HTTPException, - status_code=404, - detail="Developer not found", - ), +# @rewrap_exceptions( +# { +# asyncpg.NoDataFoundError: partialclass( +# HTTPException, +# status_code=404, +# detail="File not found", +# ), +# asyncpg.ForeignKeyViolationError: partialclass( +# HTTPException, +# status_code=404, +# detail="Developer not found", +# ), +# } +# ) +@wrap_in_class( + File, + one=True, + transform=lambda d: { + "id": d["file_id"], + **d, + "hash": d["hash"].hex(), + "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", } ) -@wrap_in_class(File, one=True, transform=lambda d: {"id": d["file_id"], **d}) @pg_query @beartype -async def get_file(*, file_id: UUID, developer_id: UUID) -> tuple[str, list]: +async def get_file( + *, + file_id: UUID, + developer_id: UUID, + owner_type: Literal["user", "agent"] | None = None, + owner_id: UUID | None = None, +) -> tuple[str, list]: """ Constructs the SQL query to retrieve a file's details. Uses composite index on (developer_id, file_id) for efficient lookup. Args: - file_id (UUID): The UUID of the file to retrieve. - developer_id (UUID): The UUID of the developer owning the file. + file_id: The UUID of the file to retrieve + developer_id: The UUID of the developer owning the file + owner_type: Optional type of owner ("user" or "agent") + owner_id: Optional UUID of the owner Returns: - tuple[str, list]: A tuple containing the SQL query and its parameters. - - Raises: - HTTPException: If file or developer not found (404) + tuple[str, list]: SQL query and parameters """ return ( file_query, - [developer_id, file_id], # Order matches index columns + [developer_id, file_id, owner_type, owner_id], ) diff --git a/agents-api/agents_api/queries/files/list_files.py b/agents-api/agents_api/queries/files/list_files.py index e6f65d88d..2bc42f842 100644 --- a/agents-api/agents_api/queries/files/list_files.py +++ b/agents-api/agents_api/queries/files/list_files.py @@ -16,18 +16,10 @@ # Query to list all files for a developer (uses developer_id index) developer_files_query = parse_one(""" -SELECT - file_id, - developer_id, - name, - description, - mime_type, - size, - hash, - created_at, - updated_at -FROM files -WHERE developer_id = $1 +SELECT f.* +FROM files f +LEFT JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE f.developer_id = $1 ORDER BY CASE WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at @@ -39,55 +31,20 @@ OFFSET $3; """).sql(pretty=True) -# Query to list files for a specific user (uses composite indexes) -user_files_query = parse_one(""" -SELECT - f.file_id, - f.developer_id, - f.name, - f.description, - f.mime_type, - f.size, - f.hash, - f.created_at, - f.updated_at -FROM user_files uf -JOIN files f USING (developer_id, file_id) -WHERE uf.developer_id = $1 -AND uf.user_id = $6 +# Query to list files for a specific owner (uses composite indexes) +owner_files_query = parse_one(""" +SELECT f.* +FROM files f +JOIN file_owners fo ON f.developer_id = fo.developer_id AND f.file_id = fo.file_id +WHERE fo.developer_id = $1 +AND fo.owner_id = $6 +AND fo.owner_type = $7 ORDER BY CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN f.created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN f.created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN f.updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN f.updated_at - END DESC NULLS LAST -LIMIT $2 -OFFSET $3; -""").sql(pretty=True) - -# Query to list files for a specific agent (uses composite indexes) -agent_files_query = parse_one(""" -SELECT - f.file_id, - f.developer_id, - f.name, - f.description, - f.mime_type, - f.size, - f.hash, - f.created_at, - f.updated_at -FROM agent_files af -JOIN files f USING (developer_id, file_id) -WHERE af.developer_id = $1 -AND af.agent_id = $6 -ORDER BY - CASE - WHEN $4 = 'created_at' AND $5 = 'asc' THEN f.created_at - WHEN $4 = 'created_at' AND $5 = 'desc' THEN f.created_at - WHEN $4 = 'updated_at' AND $5 = 'asc' THEN f.updated_at - WHEN $4 = 'updated_at' AND $5 = 'desc' THEN f.updated_at + WHEN $4 = 'created_at' AND $5 = 'asc' THEN created_at + WHEN $4 = 'created_at' AND $5 = 'desc' THEN created_at + WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at + WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at END DESC NULLS LAST LIMIT $2 OFFSET $3; @@ -96,9 +53,11 @@ @wrap_in_class( File, - one=True, + one=False, transform=lambda d: { **d, + "id": d["file_id"], + "hash": d["hash"].hex(), "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", }, ) @@ -155,8 +114,8 @@ async def list_files( # Choose appropriate query based on owner details if owner_id and owner_type: - params.append(owner_id) # Add owner_id as $6 - query = user_files_query if owner_type == "user" else agent_files_query + params.extend([owner_id, owner_type]) # Add owner_id as $6 and owner_type as $7 + query = owner_files_query # Use single query with owner_type parameter else: query = developer_files_query diff --git a/agents-api/agents_api/queries/users/delete_user.py b/agents-api/agents_api/queries/users/delete_user.py index 86bcc0b26..ad5befd73 100644 --- a/agents-api/agents_api/queries/users/delete_user.py +++ b/agents-api/agents_api/queries/users/delete_user.py @@ -11,14 +11,37 @@ # Define the raw SQL query outside the function delete_query = parse_one(""" -WITH deleted_data AS ( - DELETE FROM user_files -- user_files - WHERE developer_id = $1 -- developer_id - AND user_id = $2 -- user_id +WITH deleted_file_owners AS ( + DELETE FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 +), +deleted_doc_owners AS ( + DELETE FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 +), +deleted_files AS ( + DELETE FROM files + WHERE developer_id = $1 + AND file_id IN ( + SELECT file_id FROM file_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 + ) ), deleted_docs AS ( - DELETE FROM user_docs - WHERE developer_id = $1 AND user_id = $2 + DELETE FROM docs + WHERE developer_id = $1 + AND doc_id IN ( + SELECT doc_id FROM doc_owners + WHERE developer_id = $1 + AND owner_type = 'user' + AND owner_id = $2 + ) ) DELETE FROM users WHERE developer_id = $1 AND user_id = $2 diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index 73113580d..e9cca6e95 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -170,11 +170,6 @@ async def wrapper( query, *args, timeout=timeout ) - print("%" * 100) - print(results) - print(*args) - print("%" * 100) - if method_name == "fetchrow" and ( len(results) == 0 or results.get("bool") is None ): diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 9153785a4..2cad999e8 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -11,6 +11,7 @@ CreateAgentRequest, CreateSessionRequest, CreateUserRequest, + CreateFileRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode @@ -25,7 +26,7 @@ # from agents_api.queries.execution.create_execution import create_execution # from agents_api.queries.execution.create_execution_transition import create_execution_transition # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup -# from agents_api.queries.files.create_file import create_file +from agents_api.queries.files.create_file import create_file # from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session @@ -132,6 +133,23 @@ async def test_user(dsn=pg_dsn, developer=test_developer): return user +@fixture(scope="test") +async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Hello", + description="World", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + connection_pool=pool, + ) + + return file + + @fixture(scope="test") async def random_email(): return f"{"".join([random.choice(string.ascii_lowercase) for _ in range(10)])}@mail.com" diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index b6cb7aedc..9192773ab 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -143,12 +143,21 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): @test("query: delete agent sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): +async def _(dsn=pg_dsn, developer_id=test_developer_id): """Test that an agent can be successfully deleted.""" pool = await create_db_pool(dsn=dsn) + create_result = await create_agent( + developer_id=developer_id, + data=CreateAgentRequest( + name="test agent", + about="test agent about", + model="gpt-4o-mini", + ), + connection_pool=pool, + ) delete_result = await delete_agent( - agent_id=agent.id, developer_id=developer_id, connection_pool=pool + agent_id=create_result.id, developer_id=developer_id, connection_pool=pool ) assert delete_result is not None @@ -157,6 +166,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent): with raises(Exception): await get_agent( developer_id=developer_id, - agent_id=agent.id, + agent_id=create_result.id, connection_pool=pool, ) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 60a387591..eab6bb718 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -1,177 +1,177 @@ -""" -This module contains tests for entry queries against the CozoDB database. -It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. -""" - -from fastapi import HTTPException -from uuid_extensions import uuid7 -from ward import raises, test - -from agents_api.autogen.openapi_model import CreateEntryRequest -from agents_api.clients.pg import create_db_pool -from agents_api.queries.entries import create_entries, list_entries -from tests.fixtures import pg_dsn, test_developer, test_session # , test_session - -MODEL = "gpt-4o-mini" - - -@test("query: create entry no session") -async def _(dsn=pg_dsn, developer=test_developer): - """Test the addition of a new entry to the database.""" - - pool = await create_db_pool(dsn=dsn) - test_entry = CreateEntryRequest.from_model_input( - model=MODEL, - role="user", - source="internal", - content="test entry content", - ) - - with raises(HTTPException) as exc_info: - await create_entries( - developer_id=developer.id, - session_id=uuid7(), - data=[test_entry], - connection_pool=pool, - ) - assert exc_info.raised.status_code == 404 - - -@test("query: list entries no session") -async def _(dsn=pg_dsn, developer=test_developer): - """Test the retrieval of entries from the database.""" - - pool = await create_db_pool(dsn=dsn) - - with raises(HTTPException) as exc_info: - await list_entries( - developer_id=developer.id, - session_id=uuid7(), - connection_pool=pool, - ) - assert exc_info.raised.status_code == 404 - - -# @test("query: get entries") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the retrieval of entries from the database.""" +# """ +# This module contains tests for entry queries against the CozoDB database. +# It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. +# """ -# pool = await create_db_pool(dsn=dsn) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="api_request", -# content="test entry content", -# ) +# from fastapi import HTTPException +# from uuid_extensions import uuid7 +# from ward import raises, test -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="test entry content", -# source="internal", -# ) - -# await create_entries( -# developer_id=TEST_DEVELOPER_ID, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) - -# result = await list_entries( -# developer_id=TEST_DEVELOPER_ID, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) +# from agents_api.autogen.openapi_model import CreateEntryRequest +# from agents_api.clients.pg import create_db_pool +# from agents_api.queries.entries import create_entries, list_entries +# from tests.fixtures import pg_dsn, test_developer, test_session # , test_session +# MODEL = "gpt-4o-mini" -# # Assert that only one entry is retrieved, matching the session_id. -# assert len(result) == 1 -# assert isinstance(result[0], Entry) -# assert result is not None - -# @test("query: get history") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the retrieval of entry history from the database.""" +# @test("query: create entry no session") +# async def _(dsn=pg_dsn, developer=test_developer): +# """Test the addition of a new entry to the database.""" # pool = await create_db_pool(dsn=dsn) # test_entry = CreateEntryRequest.from_model_input( # model=MODEL, # role="user", -# source="api_request", -# content="test entry content", -# ) - -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="test entry content", # source="internal", +# content="test entry content", # ) -# await create_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) - -# result = await get_history( -# developer_id=developer_id, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) - -# # Assert that entries are retrieved and have valid IDs. -# assert result is not None -# assert isinstance(result, History) -# assert len(result.entries) > 0 -# assert result.entries[0].id +# with raises(HTTPException) as exc_info: +# await create_entries( +# developer_id=developer.id, +# session_id=uuid7(), +# data=[test_entry], +# connection_pool=pool, +# ) +# assert exc_info.raised.status_code == 404 -# @test("query: delete entries") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session -# """Test the deletion of entries from the database.""" +# @test("query: list entries no session") +# async def _(dsn=pg_dsn, developer=test_developer): +# """Test the retrieval of entries from the database.""" # pool = await create_db_pool(dsn=dsn) -# test_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# source="api_request", -# content="test entry content", -# ) - -# internal_entry = CreateEntryRequest.from_model_input( -# model=MODEL, -# role="user", -# content="internal entry content", -# source="internal", -# ) - -# created_entries = await create_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# data=[test_entry, internal_entry], -# connection_pool=pool, -# ) -# entry_ids = [entry.id for entry in created_entries] - -# await delete_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], -# connection_pool=pool, -# ) - -# result = await list_entries( -# developer_id=developer_id, -# session_id=SESSION_ID, -# connection_pool=pool, -# ) - -# Assert that no entries are retrieved after deletion. -# assert all(id not in [entry.id for entry in result] for id in entry_ids) -# assert len(result) == 0 -# assert result is not None +# with raises(HTTPException) as exc_info: +# await list_entries( +# developer_id=developer.id, +# session_id=uuid7(), +# connection_pool=pool, +# ) +# assert exc_info.raised.status_code == 404 + + +# # @test("query: get entries") +# # async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# # """Test the retrieval of entries from the database.""" + +# # pool = await create_db_pool(dsn=dsn) +# # test_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # source="api_request", +# # content="test entry content", +# # ) + +# # internal_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # content="test entry content", +# # source="internal", +# # ) + +# # await create_entries( +# # developer_id=TEST_DEVELOPER_ID, +# # session_id=SESSION_ID, +# # data=[test_entry, internal_entry], +# # connection_pool=pool, +# # ) + +# # result = await list_entries( +# # developer_id=TEST_DEVELOPER_ID, +# # session_id=SESSION_ID, +# # connection_pool=pool, +# # ) + + +# # # Assert that only one entry is retrieved, matching the session_id. +# # assert len(result) == 1 +# # assert isinstance(result[0], Entry) +# # assert result is not None + + +# # @test("query: get history") +# # async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# # """Test the retrieval of entry history from the database.""" + +# # pool = await create_db_pool(dsn=dsn) +# # test_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # source="api_request", +# # content="test entry content", +# # ) + +# # internal_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # content="test entry content", +# # source="internal", +# # ) + +# # await create_entries( +# # developer_id=developer_id, +# # session_id=SESSION_ID, +# # data=[test_entry, internal_entry], +# # connection_pool=pool, +# # ) + +# # result = await get_history( +# # developer_id=developer_id, +# # session_id=SESSION_ID, +# # connection_pool=pool, +# # ) + +# # # Assert that entries are retrieved and have valid IDs. +# # assert result is not None +# # assert isinstance(result, History) +# # assert len(result.entries) > 0 +# # assert result.entries[0].id + + +# # @test("query: delete entries") +# # async def _(dsn=pg_dsn, developer_id=test_developer_id): # , session=test_session +# # """Test the deletion of entries from the database.""" + +# # pool = await create_db_pool(dsn=dsn) +# # test_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # source="api_request", +# # content="test entry content", +# # ) + +# # internal_entry = CreateEntryRequest.from_model_input( +# # model=MODEL, +# # role="user", +# # content="internal entry content", +# # source="internal", +# # ) + +# # created_entries = await create_entries( +# # developer_id=developer_id, +# # session_id=SESSION_ID, +# # data=[test_entry, internal_entry], +# # connection_pool=pool, +# # ) + +# # entry_ids = [entry.id for entry in created_entries] + +# # await delete_entries( +# # developer_id=developer_id, +# # session_id=SESSION_ID, +# # entry_ids=[UUID("123e4567-e89b-12d3-a456-426614174002")], +# # connection_pool=pool, +# # ) + +# # result = await list_entries( +# # developer_id=developer_id, +# # session_id=SESSION_ID, +# # connection_pool=pool, +# # ) + +# # Assert that no entries are retrieved after deletion. +# # assert all(id not in [entry.id for entry in result] for id in entry_ids) +# # assert len(result) == 0 +# # assert result is not None diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index 02ad888f5..dd21be82b 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -10,14 +10,15 @@ from agents_api.queries.files.create_file import create_file from agents_api.queries.files.delete_file import delete_file from agents_api.queries.files.get_file import get_file -from tests.fixtures import pg_dsn, test_agent, test_developer_id +from agents_api.queries.files.list_files import list_files +from tests.fixtures import pg_dsn, test_developer, test_file, test_agent, test_user @test("query: create file") -async def _(dsn=pg_dsn, developer_id=test_developer_id): +async def _(dsn=pg_dsn, developer=test_developer): pool = await create_db_pool(dsn=dsn) await create_file( - developer_id=developer_id, + developer_id=developer.id, data=CreateFileRequest( name="Hello", description="World", @@ -28,54 +29,227 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -# @test("query: get file") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): -# pool = await create_db_pool(dsn=dsn) -# file = create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# connection_pool=pool, -# ) - -# get_file_result = get_file( -# developer_id=developer_id, -# file_id=file.id, -# connection_pool=pool, -# ) - -# assert file == get_file_result - -# @test("query: delete file") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): -# pool = await create_db_pool(dsn=dsn) -# file = create_file( -# developer_id=developer_id, -# data=CreateFileRequest( -# name="Hello", -# description="World", -# mime_type="text/plain", -# content="eyJzYW1wbGUiOiAidGVzdCJ9", -# ), -# connection_pool=pool, -# ) - -# delete_file( -# developer_id=developer_id, -# file_id=file.id, -# connection_pool=pool, -# ) - -# with raises(HTTPException) as e: -# get_file( -# developer_id=developer_id, -# file_id=file.id, -# connection_pool=pool, -# ) - -# assert e.value.status_code == 404 -# assert e.value.detail == "The specified file does not exist" +@test("query: create user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User File", + description="Test user file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert file.name == "User File" + + # Verify file appears in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + + +@test("query: create agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent File", + description="Test agent file", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert file.name == "Agent File" + + # Verify file appears in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert any(f.id == file.id for f in files) + + +@test("model: get file") +async def _(dsn=pg_dsn, file=test_file, developer=test_developer): + pool = await create_db_pool(dsn=dsn) + file_test = await get_file( + developer_id=developer.id, + file_id=file.id, + connection_pool=pool, + ) + assert file_test.id == file.id + assert file_test.name == "Hello" + assert file_test.description == "World" + assert file_test.mime_type == "text/plain" + assert file_test.hash == file.hash + + +@test("query: list files") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) + files = await list_files( + developer_id=developer.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: list user files") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User List Test", + description="Test file for user listing", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # List user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: list agent files") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent List Test", + description="Test file for agent listing", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # List agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert len(files) >= 1 + assert any(f.id == file.id for f in files) + + +@test("query: delete user file") +async def _(dsn=pg_dsn, developer=test_developer, user=test_user): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the user + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="User Delete Test", + description="Test file for user deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + + # Verify file is no longer in user's files + files = await list_files( + developer_id=developer.id, + owner_type="user", + owner_id=user.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: delete agent file") +async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): + pool = await create_db_pool(dsn=dsn) + + # Create a file owned by the agent + file = await create_file( + developer_id=developer.id, + data=CreateFileRequest( + name="Agent Delete Test", + description="Test file for agent deletion", + mime_type="text/plain", + content="eyJzYW1wbGUiOiAidGVzdCJ9", + ), + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Delete the file + await delete_file( + developer_id=developer.id, + file_id=file.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + + # Verify file is no longer in agent's files + files = await list_files( + developer_id=developer.id, + owner_type="agent", + owner_id=agent.id, + connection_pool=pool, + ) + assert not any(f.id == file.id for f in files) + + +@test("query: delete file") +async def _(dsn=pg_dsn, developer=test_developer, file=test_file): + pool = await create_db_pool(dsn=dsn) + + await delete_file( + developer_id=developer.id, + file_id=file.id, + connection_pool=pool, + ) + + diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 8e512379f..199382775 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -1,261 +1,261 @@ -""" -This module contains tests for SQL query generation functions in the sessions module. -Tests verify the SQL queries without actually executing them against a database. -""" - -from uuid_extensions import uuid7 -from ward import raises, test - -from agents_api.autogen.openapi_model import ( - CreateOrUpdateSessionRequest, - CreateSessionRequest, - PatchSessionRequest, - ResourceDeletedResponse, - ResourceUpdatedResponse, - Session, - UpdateSessionRequest, -) -from agents_api.clients.pg import create_db_pool -from agents_api.queries.sessions import ( - count_sessions, - create_or_update_session, - create_session, - delete_session, - get_session, - list_sessions, - patch_session, - update_session, -) -from tests.fixtures import ( - pg_dsn, - test_agent, - test_developer, - test_developer_id, - test_session, - test_user, -) - - -@test("query: create session sql") -async def _( - dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user -): - """Test that a session can be successfully created.""" - - pool = await create_db_pool(dsn=dsn) - session_id = uuid7() - data = CreateSessionRequest( - users=[user.id], - agents=[agent.id], - situation="test session", - system_template="test system template", - ) - result = await create_session( - developer_id=developer_id, - session_id=session_id, - data=data, - connection_pool=pool, - ) - - assert result is not None - assert isinstance(result, Session), f"Result is not a Session, {result}" - assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} - - -@test("query: create or update session sql") -async def _( - dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user -): - """Test that a session can be successfully created or updated.""" - - pool = await create_db_pool(dsn=dsn) - session_id = uuid7() - data = CreateOrUpdateSessionRequest( - users=[user.id], - agents=[agent.id], - situation="test session", - ) - result = await create_or_update_session( - developer_id=developer_id, - session_id=session_id, - data=data, - connection_pool=pool, - ) - - assert result is not None - assert isinstance(result, Session) - assert result.id == session_id - assert result.developer_id == developer_id - assert result.situation == "test session" - assert set(result.users) == {user.id} - assert set(result.agents) == {agent.id} - - -@test("query: get session exists") -async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): - """Test retrieving an existing session.""" - - pool = await create_db_pool(dsn=dsn) - result = await get_session( - developer_id=developer_id, - session_id=session.id, - connection_pool=pool, - ) - - assert result is not None - assert isinstance(result, Session) - assert result.id == session.id - assert result.developer_id == developer_id - - -@test("query: get session does not exist") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test retrieving a non-existent session.""" - - session_id = uuid7() - pool = await create_db_pool(dsn=dsn) - with raises(Exception): - await get_session( - session_id=session_id, - developer_id=developer_id, - connection_pool=pool, - ) - - -@test("query: list sessions") -async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): - """Test listing sessions with default pagination.""" - - pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( - developer_id=developer_id, - limit=10, - offset=0, - connection_pool=pool, - ) - - assert isinstance(result, list) - assert len(result) >= 1 - assert any(s.id == session.id for s in result) - - -@test("query: list sessions with filters") -async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): - """Test listing sessions with specific filters.""" - - pool = await create_db_pool(dsn=dsn) - result, _ = await list_sessions( - developer_id=developer_id, - limit=10, - offset=0, - filters={"situation": "test session"}, - connection_pool=pool, - ) - - assert isinstance(result, list) - assert len(result) >= 1 - assert all(s.situation == "test session" for s in result) - - -@test("query: count sessions") -async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): - """Test counting the number of sessions for a developer.""" - - pool = await create_db_pool(dsn=dsn) - count = await count_sessions( - developer_id=developer_id, - connection_pool=pool, - ) - - assert isinstance(count, int) - assert count >= 1 - - -@test("query: update session sql") -async def _( - dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent -): - """Test that an existing session's information can be successfully updated.""" - - pool = await create_db_pool(dsn=dsn) - data = UpdateSessionRequest( - agents=[agent.id], - situation="updated session", - ) - result = await update_session( - session_id=session.id, - developer_id=developer_id, - data=data, - connection_pool=pool, - ) - - assert result is not None - assert isinstance(result, ResourceUpdatedResponse) - assert result.updated_at > session.created_at - - updated_session = await get_session( - developer_id=developer_id, - session_id=session.id, - connection_pool=pool, - ) - assert updated_session.situation == "updated session" - assert set(updated_session.agents) == {agent.id} - - -@test("query: patch session sql") -async def _( - dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent -): - """Test that a session can be successfully patched.""" - - pool = await create_db_pool(dsn=dsn) - data = PatchSessionRequest( - agents=[agent.id], - situation="patched session", - metadata={"test": "metadata"}, - ) - result = await patch_session( - developer_id=developer_id, - session_id=session.id, - data=data, - connection_pool=pool, - ) - - assert result is not None - assert isinstance(result, ResourceUpdatedResponse) - assert result.updated_at > session.created_at - - patched_session = await get_session( - developer_id=developer_id, - session_id=session.id, - connection_pool=pool, - ) - assert patched_session.situation == "patched session" - assert set(patched_session.agents) == {agent.id} - assert patched_session.metadata == {"test": "metadata"} - - -@test("query: delete session sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): - """Test that a session can be successfully deleted.""" - - pool = await create_db_pool(dsn=dsn) - delete_result = await delete_session( - developer_id=developer_id, - session_id=session.id, - connection_pool=pool, - ) - - assert delete_result is not None - assert isinstance(delete_result, ResourceDeletedResponse) - - with raises(Exception): - await get_session( - developer_id=developer_id, - session_id=session.id, - connection_pool=pool, - ) +# """ +# This module contains tests for SQL query generation functions in the sessions module. +# Tests verify the SQL queries without actually executing them against a database. +# """ + +# from uuid_extensions import uuid7 +# from ward import raises, test + +# from agents_api.autogen.openapi_model import ( +# CreateOrUpdateSessionRequest, +# CreateSessionRequest, +# PatchSessionRequest, +# ResourceDeletedResponse, +# ResourceUpdatedResponse, +# Session, +# UpdateSessionRequest, +# ) +# from agents_api.clients.pg import create_db_pool +# from agents_api.queries.sessions import ( +# count_sessions, +# create_or_update_session, +# create_session, +# delete_session, +# get_session, +# list_sessions, +# patch_session, +# update_session, +# ) +# from tests.fixtures import ( +# pg_dsn, +# test_agent, +# test_developer, +# test_developer_id, +# test_session, +# test_user, +# ) + + +# @test("query: create session sql") +# async def _( +# dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +# ): +# """Test that a session can be successfully created.""" + +# pool = await create_db_pool(dsn=dsn) +# session_id = uuid7() +# data = CreateSessionRequest( +# users=[user.id], +# agents=[agent.id], +# situation="test session", +# system_template="test system template", +# ) +# result = await create_session( +# developer_id=developer_id, +# session_id=session_id, +# data=data, +# connection_pool=pool, +# ) + +# assert result is not None +# assert isinstance(result, Session), f"Result is not a Session, {result}" +# assert result.id == session_id +# assert result.developer_id == developer_id +# assert result.situation == "test session" +# assert set(result.users) == {user.id} +# assert set(result.agents) == {agent.id} + + +# @test("query: create or update session sql") +# async def _( +# dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +# ): +# """Test that a session can be successfully created or updated.""" + +# pool = await create_db_pool(dsn=dsn) +# session_id = uuid7() +# data = CreateOrUpdateSessionRequest( +# users=[user.id], +# agents=[agent.id], +# situation="test session", +# ) +# result = await create_or_update_session( +# developer_id=developer_id, +# session_id=session_id, +# data=data, +# connection_pool=pool, +# ) + +# assert result is not None +# assert isinstance(result, Session) +# assert result.id == session_id +# assert result.developer_id == developer_id +# assert result.situation == "test session" +# assert set(result.users) == {user.id} +# assert set(result.agents) == {agent.id} + + +# @test("query: get session exists") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test retrieving an existing session.""" + +# pool = await create_db_pool(dsn=dsn) +# result = await get_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) + +# assert result is not None +# assert isinstance(result, Session) +# assert result.id == session.id +# assert result.developer_id == developer_id + + +# @test("query: get session does not exist") +# async def _(dsn=pg_dsn, developer_id=test_developer_id): +# """Test retrieving a non-existent session.""" + +# session_id = uuid7() +# pool = await create_db_pool(dsn=dsn) +# with raises(Exception): +# await get_session( +# session_id=session_id, +# developer_id=developer_id, +# connection_pool=pool, +# ) + + +# @test("query: list sessions") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test listing sessions with default pagination.""" + +# pool = await create_db_pool(dsn=dsn) +# result, _ = await list_sessions( +# developer_id=developer_id, +# limit=10, +# offset=0, +# connection_pool=pool, +# ) + +# assert isinstance(result, list) +# assert len(result) >= 1 +# assert any(s.id == session.id for s in result) + + +# @test("query: list sessions with filters") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test listing sessions with specific filters.""" + +# pool = await create_db_pool(dsn=dsn) +# result, _ = await list_sessions( +# developer_id=developer_id, +# limit=10, +# offset=0, +# filters={"situation": "test session"}, +# connection_pool=pool, +# ) + +# assert isinstance(result, list) +# assert len(result) >= 1 +# assert all(s.situation == "test session" for s in result) + + +# @test("query: count sessions") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test counting the number of sessions for a developer.""" + +# pool = await create_db_pool(dsn=dsn) +# count = await count_sessions( +# developer_id=developer_id, +# connection_pool=pool, +# ) + +# assert isinstance(count, int) +# assert count >= 1 + + +# @test("query: update session sql") +# async def _( +# dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +# ): +# """Test that an existing session's information can be successfully updated.""" + +# pool = await create_db_pool(dsn=dsn) +# data = UpdateSessionRequest( +# agents=[agent.id], +# situation="updated session", +# ) +# result = await update_session( +# session_id=session.id, +# developer_id=developer_id, +# data=data, +# connection_pool=pool, +# ) + +# assert result is not None +# assert isinstance(result, ResourceUpdatedResponse) +# assert result.updated_at > session.created_at + +# updated_session = await get_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) +# assert updated_session.situation == "updated session" +# assert set(updated_session.agents) == {agent.id} + + +# @test("query: patch session sql") +# async def _( +# dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +# ): +# """Test that a session can be successfully patched.""" + +# pool = await create_db_pool(dsn=dsn) +# data = PatchSessionRequest( +# agents=[agent.id], +# situation="patched session", +# metadata={"test": "metadata"}, +# ) +# result = await patch_session( +# developer_id=developer_id, +# session_id=session.id, +# data=data, +# connection_pool=pool, +# ) + +# assert result is not None +# assert isinstance(result, ResourceUpdatedResponse) +# assert result.updated_at > session.created_at + +# patched_session = await get_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) +# assert patched_session.situation == "patched session" +# assert set(patched_session.agents) == {agent.id} +# assert patched_session.metadata == {"test": "metadata"} + + +# @test("query: delete session sql") +# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): +# """Test that a session can be successfully deleted.""" + +# pool = await create_db_pool(dsn=dsn) +# delete_result = await delete_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) + +# assert delete_result is not None +# assert isinstance(delete_result, ResourceDeletedResponse) + +# with raises(Exception): +# await get_session( +# developer_id=developer_id, +# session_id=session.id, +# connection_pool=pool, +# ) From f974fa0f38bba27c8faafaf50f2a6f1476efd334 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Thu, 19 Dec 2024 06:56:33 +0000 Subject: [PATCH 5/7] refactor: Lint agents-api (CI) --- .../agents_api/queries/agents/delete_agent.py | 1 + .../agents_api/queries/files/create_file.py | 1 + .../agents_api/queries/files/get_file.py | 12 ++++---- agents-api/tests/fixtures.py | 3 +- agents-api/tests/test_files_queries.py | 30 +++++++++---------- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index a957ab2c5..d47711345 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -60,6 +60,7 @@ RETURNING developer_id, agent_id; """).sql(pretty=True) + # @rewrap_exceptions( # @rewrap_exceptions( # { diff --git a/agents-api/agents_api/queries/files/create_file.py b/agents-api/agents_api/queries/files/create_file.py index 8438978e6..48251fa5e 100644 --- a/agents-api/agents_api/queries/files/create_file.py +++ b/agents-api/agents_api/queries/files/create_file.py @@ -58,6 +58,7 @@ JOIN files f ON f.file_id = io.file_id; """).sql(pretty=True) + # Add error handling decorator # @rewrap_exceptions( # { diff --git a/agents-api/agents_api/queries/files/get_file.py b/agents-api/agents_api/queries/files/get_file.py index ace417d5d..4d5dca4c0 100644 --- a/agents-api/agents_api/queries/files/get_file.py +++ b/agents-api/agents_api/queries/files/get_file.py @@ -3,8 +3,8 @@ It constructs and executes SQL queries to fetch file details based on file ID and developer ID. """ -from uuid import UUID from typing import Literal +from uuid import UUID import asyncpg from beartype import beartype @@ -44,20 +44,20 @@ # } # ) @wrap_in_class( - File, - one=True, + File, + one=True, transform=lambda d: { "id": d["file_id"], **d, "hash": d["hash"].hex(), "content": "DUMMY: NEED TO FETCH CONTENT FROM BLOB STORAGE", - } + }, ) @pg_query @beartype async def get_file( - *, - file_id: UUID, + *, + file_id: UUID, developer_id: UUID, owner_type: Literal["user", "agent"] | None = None, owner_id: UUID | None = None, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 2cad999e8..0c904b383 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -9,9 +9,9 @@ from agents_api.autogen.openapi_model import ( CreateAgentRequest, + CreateFileRequest, CreateSessionRequest, CreateUserRequest, - CreateFileRequest, ) from agents_api.clients.pg import create_db_pool from agents_api.env import api_key, api_key_header_name, multi_tenant_mode @@ -27,6 +27,7 @@ # from agents_api.queries.execution.create_execution_transition import create_execution_transition # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup from agents_api.queries.files.create_file import create_file + # from agents_api.queries.files.delete_file import delete_file from agents_api.queries.sessions.create_session import create_session diff --git a/agents-api/tests/test_files_queries.py b/agents-api/tests/test_files_queries.py index dd21be82b..92b52d733 100644 --- a/agents-api/tests/test_files_queries.py +++ b/agents-api/tests/test_files_queries.py @@ -11,7 +11,7 @@ from agents_api.queries.files.delete_file import delete_file from agents_api.queries.files.get_file import get_file from agents_api.queries.files.list_files import list_files -from tests.fixtures import pg_dsn, test_developer, test_file, test_agent, test_user +from tests.fixtures import pg_dsn, test_agent, test_developer, test_file, test_user @test("query: create file") @@ -45,7 +45,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): connection_pool=pool, ) assert file.name == "User File" - + # Verify file appears in user's files files = await list_files( developer_id=developer.id, @@ -59,7 +59,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): @test("query: create agent file") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) - + file = await create_file( developer_id=developer.id, data=CreateFileRequest( @@ -73,7 +73,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): connection_pool=pool, ) assert file.name == "Agent File" - + # Verify file appears in agent's files files = await list_files( developer_id=developer.id, @@ -113,7 +113,7 @@ async def _(dsn=pg_dsn, developer=test_developer, file=test_file): @test("query: list user files") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) - + # Create a file owned by the user file = await create_file( developer_id=developer.id, @@ -127,7 +127,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): owner_id=user.id, connection_pool=pool, ) - + # List user's files files = await list_files( developer_id=developer.id, @@ -142,7 +142,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): @test("query: list agent files") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) - + # Create a file owned by the agent file = await create_file( developer_id=developer.id, @@ -156,7 +156,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): owner_id=agent.id, connection_pool=pool, ) - + # List agent's files files = await list_files( developer_id=developer.id, @@ -171,7 +171,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): @test("query: delete user file") async def _(dsn=pg_dsn, developer=test_developer, user=test_user): pool = await create_db_pool(dsn=dsn) - + # Create a file owned by the user file = await create_file( developer_id=developer.id, @@ -185,7 +185,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): owner_id=user.id, connection_pool=pool, ) - + # Delete the file await delete_file( developer_id=developer.id, @@ -194,7 +194,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): owner_id=user.id, connection_pool=pool, ) - + # Verify file is no longer in user's files files = await list_files( developer_id=developer.id, @@ -208,7 +208,7 @@ async def _(dsn=pg_dsn, developer=test_developer, user=test_user): @test("query: delete agent file") async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) - + # Create a file owned by the agent file = await create_file( developer_id=developer.id, @@ -222,7 +222,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): owner_id=agent.id, connection_pool=pool, ) - + # Delete the file await delete_file( developer_id=developer.id, @@ -231,7 +231,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent): owner_id=agent.id, connection_pool=pool, ) - + # Verify file is no longer in agent's files files = await list_files( developer_id=developer.id, @@ -251,5 +251,3 @@ async def _(dsn=pg_dsn, developer=test_developer, file=test_file): file_id=file.id, connection_pool=pool, ) - - From c88e8d76fe558189194afdb1c1c7fecc592f22af Mon Sep 17 00:00:00 2001 From: vedantsahai18 Date: Thu, 19 Dec 2024 19:10:22 -0500 Subject: [PATCH 6/7] chore: fix conflicts --- agents-api/tests/test_entry_queries.py | 9 +- agents-api/tests/test_session_queries.py | 154 +++++++++++------------ 2 files changed, 81 insertions(+), 82 deletions(-) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 03972cdee..e8286e8bc 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -5,9 +5,9 @@ from uuid import UUID -# from fastapi import HTTPException -# from uuid_extensions import uuid7 -# from ward import raises, test +from fastapi import HTTPException +from uuid_extensions import uuid7 +from ward import raises, test from agents_api.autogen.openapi_model import ( CreateEntryRequest, @@ -23,8 +23,7 @@ ) from tests.fixtures import pg_dsn, test_developer, test_developer_id, test_session -# MODEL = "gpt-4o-mini" - +MODEL = "gpt-4o-mini" @test("query: create entry no session") async def _(dsn=pg_dsn, developer=test_developer): diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 73b232f1f..171e56aa8 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -3,8 +3,8 @@ # Tests verify the SQL queries without actually executing them against a database. # """ -# from uuid_extensions import uuid7 -# from ward import raises, test +from uuid_extensions import uuid7 +from ward import raises, test from agents_api.autogen.openapi_model import ( CreateOrUpdateSessionRequest, @@ -36,11 +36,11 @@ ) -# @test("query: create session sql") -# async def _( -# dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# """Test that a session can be successfully created.""" +@test("query: create session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created.""" pool = await create_db_pool(dsn=dsn) session_id = uuid7() @@ -61,11 +61,11 @@ assert result.id == session_id -# @test("query: create or update session sql") -# async def _( -# dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user -# ): -# """Test that a session can be successfully created or updated.""" +@test("query: create or update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created or updated.""" pool = await create_db_pool(dsn=dsn) session_id = uuid7() @@ -87,39 +87,39 @@ assert result.updated_at is not None -# @test("query: get session exists") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test retrieving an existing session.""" +@test("query: get session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test retrieving an existing session.""" -# pool = await create_db_pool(dsn=dsn) -# result = await get_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) + pool = await create_db_pool(dsn=dsn) + result = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) assert result is not None assert isinstance(result, Session) assert result.id == session.id -# @test("query: get session does not exist") -# async def _(dsn=pg_dsn, developer_id=test_developer_id): -# """Test retrieving a non-existent session.""" +@test("query: get session does not exist") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test retrieving a non-existent session.""" -# session_id = uuid7() -# pool = await create_db_pool(dsn=dsn) -# with raises(Exception): -# await get_session( -# session_id=session_id, -# developer_id=developer_id, -# connection_pool=pool, -# ) + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) + with raises(Exception): + await get_session( + session_id=session_id, + developer_id=developer_id, + connection_pool=pool, + ) -# @test("query: list sessions") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test listing sessions with default pagination.""" +@test("query: list sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with default pagination.""" pool = await create_db_pool(dsn=dsn) result = await list_sessions( @@ -129,14 +129,14 @@ connection_pool=pool, ) -# assert isinstance(result, list) -# assert len(result) >= 1 -# assert any(s.id == session.id for s in result) + assert isinstance(result, list) + assert len(result) >= 1 + assert any(s.id == session.id for s in result) -# @test("query: list sessions with filters") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test listing sessions with specific filters.""" +@test("query: list sessions with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=dsn) result = await list_sessions( @@ -153,15 +153,15 @@ ), f"Result is not a list of sessions, {result}, {session.situation}" -# @test("query: count sessions") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test counting the number of sessions for a developer.""" +@test("query: count sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test counting the number of sessions for a developer.""" -# pool = await create_db_pool(dsn=dsn) -# count = await count_sessions( -# developer_id=developer_id, -# connection_pool=pool, -# ) + pool = await create_db_pool(dsn=dsn) + count = await count_sessions( + developer_id=developer_id, + connection_pool=pool, + ) assert isinstance(count, dict) assert count["count"] >= 1 @@ -190,9 +190,9 @@ async def _( connection_pool=pool, ) -# assert result is not None -# assert isinstance(result, ResourceUpdatedResponse) -# assert result.updated_at > session.created_at + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at updated_session = await get_session( developer_id=developer_id, @@ -202,11 +202,11 @@ async def _( assert updated_session.forward_tool_calls is True -# @test("query: patch session sql") -# async def _( -# dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent -# ): -# """Test that a session can be successfully patched.""" +@test("query: patch session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that a session can be successfully patched.""" pool = await create_db_pool(dsn=dsn) data = PatchSessionRequest( @@ -219,9 +219,9 @@ async def _( connection_pool=pool, ) -# assert result is not None -# assert isinstance(result, ResourceUpdatedResponse) -# assert result.updated_at > session.created_at + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at patched_session = await get_session( developer_id=developer_id, @@ -232,23 +232,23 @@ async def _( assert patched_session.metadata == {"test": "metadata"} -# @test("query: delete session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test that a session can be successfully deleted.""" +@test("query: delete session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test that a session can be successfully deleted.""" -# pool = await create_db_pool(dsn=dsn) -# delete_result = await delete_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) -# assert delete_result is not None -# assert isinstance(delete_result, ResourceDeletedResponse) + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) -# with raises(Exception): -# await get_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) + with raises(Exception): + await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) From 41739ee94dbcfed66dd873db50d628a4810f6a25 Mon Sep 17 00:00:00 2001 From: Vedantsahai18 Date: Fri, 20 Dec 2024 00:11:16 +0000 Subject: [PATCH 7/7] refactor: Lint agents-api (CI) --- agents-api/tests/test_entry_queries.py | 1 + 1 file changed, 1 insertion(+) diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index e8286e8bc..706185c7b 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -25,6 +25,7 @@ MODEL = "gpt-4o-mini" + @test("query: create entry no session") async def _(dsn=pg_dsn, developer=test_developer): """Test the addition of a new entry to the database."""