Skip to content

Commit

Permalink
fix: Miscellaneous fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Diwank Singh Tomer <[email protected]>
  • Loading branch information
creatorrr committed Dec 21, 2024
1 parent 2900786 commit c2d54a4
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 123 deletions.
6 changes: 1 addition & 5 deletions agents-api/agents_api/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,11 @@ class Doc(BaseModel):
"""
Language of the document
"""
index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Index of the document
"""
embedding_model: Annotated[
str | None, Field(json_schema_extra={"readOnly": True})
] = None
"""
Embedding model to use for the document
Embedding model used for the document
"""
embedding_dimensions: Annotated[
int | None, Field(json_schema_extra={"readOnly": True})
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/queries/developers/get_developer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Module for retrieving document snippets from the CozoDB based on document IDs."""

from typing import Any, TypeVar
from uuid import UUID

import asyncpg
Expand Down
63 changes: 14 additions & 49 deletions agents-api/agents_api/queries/docs/create_doc.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import ast
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateDocRequest, Doc
from ...autogen.openapi_model import CreateDocRequest, ResourceCreatedResponse
from ...common.utils.datetime import utcnow
from ...metrics.counters import increase_counter
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Base INSERT for docs
doc_query = parse_one("""
doc_query = """
INSERT INTO docs (
developer_id,
doc_id,
Expand All @@ -38,48 +37,15 @@
$9, -- language
$10 -- metadata (JSONB)
)
RETURNING *;
""").sql(pretty=True)
"""

# Owner association query for doc_owners
doc_owner_query = parse_one("""
WITH inserted_owner AS (
INSERT INTO doc_owners (
developer_id,
doc_id,
index,
owner_type,
owner_id
)
VALUES ($1, $2, $3, $4, $5)
RETURNING doc_id
)
SELECT DISTINCT ON (docs.doc_id)
docs.doc_id,
docs.developer_id,
docs.title,
array_agg(docs.content ORDER BY docs.index) as content,
array_agg(docs.index ORDER BY docs.index) as indices,
docs.modality,
docs.embedding_model,
docs.embedding_dimensions,
docs.language,
docs.metadata,
docs.created_at
FROM inserted_owner io
JOIN docs ON docs.doc_id = io.doc_id
GROUP BY
docs.doc_id,
docs.developer_id,
docs.title,
docs.modality,
docs.embedding_model,
docs.embedding_dimensions,
docs.language,
docs.metadata,
docs.created_at;
""").sql(pretty=True)
doc_owner_query = """
INSERT INTO doc_owners (developer_id, doc_id, owner_type, owner_id)
VALUES ($1, $2, $3, $4)
ON CONFLICT DO NOTHING
RETURNING *;
"""


@rewrap_exceptions(
Expand All @@ -102,12 +68,12 @@
}
)
@wrap_in_class(
Doc,
ResourceCreatedResponse,
one=True,
transform=lambda d: {
"id": d["doc_id"],
"index": d["indices"][0],
"content": d["content"][0] if len(d["content"]) == 1 else d["content"],
"jobs": [],
"created_at": utcnow(),
**d,
},
)
Expand Down Expand Up @@ -146,6 +112,7 @@ async def create_doc(
list[tuple[str, list] | tuple[str, list, str]]: SQL query and parameters for creating the document.
"""
queries = []

# Generate a UUID if not provided
current_doc_id = uuid7() if doc_id is None else doc_id

Expand All @@ -172,7 +139,6 @@ async def create_doc(
owner_params = [
developer_id,
current_doc_id,
idx,
owner_type,
owner_id,
]
Expand Down Expand Up @@ -202,7 +168,6 @@ async def create_doc(
owner_params = [
developer_id,
current_doc_id,
index,
owner_type,
owner_id,
]
Expand Down
32 changes: 19 additions & 13 deletions agents-api/agents_api/queries/docs/get_doc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import ast
from typing import Literal
from uuid import UUID

from beartype import beartype
Expand All @@ -11,7 +9,7 @@
# Update the query to use DISTINCT ON to prevent duplicates
doc_with_embedding_query = parse_one("""
WITH doc_data AS (
SELECT DISTINCT ON (d.doc_id)
SELECT
d.doc_id,
d.developer_id,
d.title,
Expand Down Expand Up @@ -44,18 +42,26 @@
""").sql(pretty=True)


def transform_get_doc(d: dict) -> dict:
content = d["content"][0] if len(d["content"]) == 1 else d["content"]

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]
if embeddings and all((e is None) for e in embeddings):
embeddings = None

transformed = {
**d,
"id": d["doc_id"],
"content": content,
"embeddings": embeddings,
}
return transformed


@wrap_in_class(
Doc,
one=True, # Changed to True since we're now returning one grouped record
transform=lambda d: {
"id": d["doc_id"],
"index": d["indices"][0],
"content": d["content"][0] if len(d["content"]) == 1 else d["content"],
"embeddings": d["embeddings"][0]
if len(d["embeddings"]) == 1
else d["embeddings"],
**d,
},
one=True,
transform=transform_get_doc,
)
@pg_query
@beartype
Expand Down
28 changes: 18 additions & 10 deletions agents-api/agents_api/queries/docs/list_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Base query for listing docs with aggregated content and embeddings
base_docs_query = parse_one("""
WITH doc_data AS (
SELECT DISTINCT ON (d.doc_id)
SELECT
d.doc_id,
d.developer_id,
d.title,
Expand Down Expand Up @@ -54,6 +54,22 @@
""").sql(pretty=True)


def transform_list_docs(d: dict) -> dict:
content = d["content"][0] if len(d["content"]) == 1 else d["content"]

embeddings = d["embeddings"][0] if len(d["embeddings"]) == 1 else d["embeddings"]
if embeddings and all((e is None) for e in embeddings):
embeddings = None

transformed = {
**d,
"id": d["doc_id"],
"content": content,
"embeddings": embeddings,
}
return transformed


@rewrap_exceptions(
{
asyncpg.NoDataFoundError: partialclass(
Expand All @@ -71,15 +87,7 @@
@wrap_in_class(
Doc,
one=False,
transform=lambda d: {
"id": d["doc_id"],
"index": d["indices"][0],
"content": d["content"][0] if len(d["content"]) == 1 else d["content"],
"embedding": d["embeddings"][0]
if d.get("embeddings") and len(d["embeddings"]) == 1
else d.get("embeddings"),
**d,
},
transform=transform_list_docs,
)
@pg_query
@beartype
Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/queries/docs/search_docs_by_text.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import json
from typing import Any, List, Literal
from typing import Any, Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import DocReference
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/files/create_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import base64
import hashlib
from typing import Any, Literal
from typing import Literal
from uuid import UUID

import asyncpg
Expand Down
20 changes: 10 additions & 10 deletions agents-api/tests/test_docs_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@
@test("query: create user doc")
async def _(dsn=pg_dsn, developer=test_developer, user=test_user):
pool = await create_db_pool(dsn=dsn)
doc = await create_doc(
doc_created = await create_doc(
developer_id=developer.id,
data=CreateDocRequest(
title="User Doc",
content="Docs for user testing",
content=["Docs for user testing", "Docs for user testing 2"],
metadata={"test": "test"},
embed_instruction="Embed the document",
),
owner_type="user",
owner_id=user.id,
connection_pool=pool,
)
assert doc.title == "User Doc"

assert doc_created.id is not None

# Verify doc appears in user's docs
docs_list = await list_docs(
found = await get_doc(
developer_id=developer.id,
owner_type="user",
owner_id=user.id,
doc_id=doc_created.id,
connection_pool=pool,
)
assert any(d.id == doc.id for d in docs_list)
assert found.id == doc_created.id


@test("query: create agent doc")
Expand All @@ -58,7 +58,7 @@ async def _(dsn=pg_dsn, developer=test_developer, agent=test_agent):
owner_id=agent.id,
connection_pool=pool,
)
assert doc.title == "Agent Doc"
assert doc.id is not None

# Verify doc appears in agent's docs
docs_list = await list_docs(
Expand All @@ -79,8 +79,8 @@ async def _(dsn=pg_dsn, developer=test_developer, doc=test_doc):
connection_pool=pool,
)
assert doc_test.id == doc.id
assert doc_test.title == doc.title
assert doc_test.content == doc.content
assert doc_test.title is not None
assert doc_test.content is not None


@test("query: list user docs")
Expand Down
6 changes: 1 addition & 5 deletions integrations-service/integrations/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,11 @@ class Doc(BaseModel):
"""
Language of the document
"""
index: Annotated[int | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Index of the document
"""
embedding_model: Annotated[
str | None, Field(json_schema_extra={"readOnly": True})
] = None
"""
Embedding model to use for the document
Embedding model used for the document
"""
embedding_dimensions: Annotated[
int | None, Field(json_schema_extra={"readOnly": True})
Expand Down
25 changes: 12 additions & 13 deletions memory-store/migrations/000006_docs.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@ CREATE TABLE IF NOT EXISTS docs (
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id),
CONSTRAINT pk_docs PRIMARY KEY (developer_id, doc_id, index),
CONSTRAINT ct_docs_embedding_dimensions_positive CHECK (embedding_dimensions > 0),
CONSTRAINT ct_docs_valid_modality CHECK (modality IN ('text', 'image', 'mixed')),
CONSTRAINT ct_docs_index_positive CHECK (index >= 0),
CONSTRAINT ct_docs_valid_language CHECK (is_valid_language (language)),
UNIQUE (developer_id, doc_id, index)
CONSTRAINT ct_docs_valid_language CHECK (is_valid_language (language))
);

-- Create foreign key constraint if not exists (using DO block for safety)
Expand Down Expand Up @@ -62,20 +61,20 @@ END $$;
CREATE TABLE IF NOT EXISTS doc_owners (
developer_id UUID NOT NULL,
doc_id UUID NOT NULL,
owner_type TEXT NOT NULL, -- 'user' or 'agent'
owner_type TEXT NOT NULL, -- 'user' or 'agent'
owner_id UUID NOT NULL,
CONSTRAINT pk_doc_owners PRIMARY KEY (developer_id, doc_id),
CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id),
-- TODO: Ensure that doc exists (this constraint is not working)
-- CONSTRAINT fk_doc_owners_doc FOREIGN KEY (developer_id, doc_id) REFERENCES docs (developer_id, doc_id),
CONSTRAINT ct_doc_owners_owner_type CHECK (owner_type IN ('user', 'agent'))
);

-- Create indexes
CREATE INDEX IF NOT EXISTS idx_doc_owners_owner
ON doc_owners (developer_id, owner_type, owner_id);
CREATE INDEX IF NOT EXISTS idx_doc_owners_owner ON doc_owners (developer_id, owner_type, owner_id);

-- Create function to validate owner reference
CREATE OR REPLACE FUNCTION validate_doc_owner()
RETURNS TRIGGER AS $$
CREATE
OR REPLACE FUNCTION validate_doc_owner () RETURNS TRIGGER AS $$
BEGIN
IF NEW.owner_type = 'user' THEN
IF NOT EXISTS (
Expand All @@ -97,10 +96,10 @@ END;
$$ LANGUAGE plpgsql;

-- Create trigger for validation
CREATE TRIGGER trg_validate_doc_owner
BEFORE INSERT OR UPDATE ON doc_owners
FOR EACH ROW
EXECUTE FUNCTION validate_doc_owner();
CREATE TRIGGER trg_validate_doc_owner BEFORE INSERT
OR
UPDATE ON doc_owners FOR EACH ROW
EXECUTE FUNCTION validate_doc_owner ();

-- Create indexes if not exists
CREATE INDEX IF NOT EXISTS idx_docs_metadata ON docs USING GIN (metadata);
Expand Down
Loading

0 comments on commit c2d54a4

Please sign in to comment.