Skip to content

Commit

Permalink
fix(agents-api): fix sessions and agents queries / tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Dec 19, 2024
1 parent 57e453f commit bbdbb4b
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 167 deletions.
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ async def add_entry_relations(
(
session_exists_query,
[session_id, developer_id],
"fetch",
"fetchrow",
),
(
entry_relation_query,
Expand Down
32 changes: 16 additions & 16 deletions agents-api/agents_api/queries/sessions/create_or_update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@
participant_type,
participant_id
)
SELECT
$1 as developer_id,
$2 as session_id,
unnest($3::participant_type[]) as participant_type,
unnest($4::uuid[]) as participant_id;
VALUES ($1, $2, $3, $4);
""").sql(pretty=True)


Expand All @@ -83,16 +79,23 @@
),
}
)
@wrap_in_class(ResourceUpdatedResponse, one=True)
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {
"id": d["session_id"],
"updated_at": d["updated_at"],
},
)
@increase_counter("create_or_update_session")
@pg_query
@pg_query(return_index=0)
@beartype
async def create_or_update_session(
*,
developer_id: UUID,
session_id: UUID,
data: CreateOrUpdateSessionRequest,
) -> list[tuple[str, list]]:
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""
Constructs SQL queries to create or update a session and its participant lookups.
Expand Down Expand Up @@ -139,14 +142,11 @@ async def create_or_update_session(
]

# Prepare lookup parameters
lookup_params = [
developer_id, # $1
session_id, # $2
participant_types, # $3
participant_ids, # $4
]
lookup_params = []
for participant_type, participant_id in zip(participant_types, participant_ids):
lookup_params.append([developer_id, session_id, participant_type, participant_id])

return [
(session_query, session_params),
(lookup_query, lookup_params),
(session_query, session_params, "fetch"),
(lookup_query, lookup_params, "fetchmany"),
]
23 changes: 15 additions & 8 deletions agents-api/agents_api/queries/sessions/create_session.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from uuid import UUID
from uuid_extensions import uuid7

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import CreateSessionRequest, Session
from ...autogen.openapi_model import CreateSessionRequest, Session, ResourceCreatedResponse
from ...metrics.counters import increase_counter
from ...common.utils.datetime import utcnow
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class

# Define the raw SQL queries
Expand Down Expand Up @@ -63,14 +65,21 @@
),
}
)
@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]})
@wrap_in_class(
Session,
one=True,
transform=lambda d: {
**d,
"id": d["session_id"],
},
)
@increase_counter("create_session")
@pg_query
@pg_query(return_index=0)
@beartype
async def create_session(
*,
developer_id: UUID,
session_id: UUID,
session_id: UUID | None = None,
data: CreateSessionRequest,
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""
Expand All @@ -87,6 +96,7 @@ async def create_session(
# Handle participants
users = data.users or ([data.user] if data.user else [])
agents = data.agents or ([data.agent] if data.agent else [])
session_id = session_id or uuid7()

if not agents:
raise HTTPException(
Expand Down Expand Up @@ -123,10 +133,7 @@ async def create_session(
for ptype, pid in zip(participant_types, participant_ids):
lookup_params.append([developer_id, session_id, ptype, pid])

print("*" * 100)
print(lookup_params)
print("*" * 100)
return [
(session_query, session_params),
(session_query, session_params, "fetch"),
(lookup_query, lookup_params, "fetchmany"),
]
51 changes: 2 additions & 49 deletions agents-api/agents_api/queries/sessions/patch_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,6 @@
SELECT * FROM updated_session;
""").sql(pretty=True)

lookup_query = parse_one("""
WITH deleted_lookups AS (
DELETE FROM session_lookup
WHERE developer_id = $1 AND session_id = $2
)
INSERT INTO session_lookup (
developer_id,
session_id,
participant_type,
participant_id
)
SELECT
$1 as developer_id,
$2 as session_id,
unnest($3::participant_type[]) as participant_type,
unnest($4::uuid[]) as participant_id;
""").sql(pretty=True)


@rewrap_exceptions(
{
asyncpg.ForeignKeyViolationError: partialclass(
Expand All @@ -64,7 +45,7 @@
),
}
)
@wrap_in_class(ResourceUpdatedResponse, one=True)
@wrap_in_class(ResourceUpdatedResponse, one=True, transform=lambda d: {"id": d["session_id"], "updated_at": d["updated_at"]},)
@increase_counter("patch_session")
@pg_query
@beartype
Expand All @@ -85,22 +66,6 @@ async def patch_session(
Returns:
list[tuple[str, list]]: List of SQL queries and their parameters
"""
# Handle participants
users = data.users or ([data.user] if data.user else [])
agents = data.agents or ([data.agent] if data.agent else [])

if data.agent and data.agents:
raise HTTPException(
status_code=400,
detail="Only one of 'agent' or 'agents' should be provided",
)

# Prepare participant arrays for lookup query if participants are provided
participant_types = []
participant_ids = []
if users or agents:
participant_types = ["user"] * len(users) + ["agent"] * len(agents)
participant_ids = [str(u) for u in users] + [str(a) for a in agents]

# Extract fields from data, using None for unset fields
session_params = [
Expand All @@ -116,16 +81,4 @@ async def patch_session(
data.recall_options or {}, # $10
]

queries = [(session_query, session_params)]

# Only add lookup query if participants are provided
if participant_types:
lookup_params = [
developer_id, # $1
session_id, # $2
participant_types, # $3
participant_ids, # $4
]
queries.append((lookup_query, lookup_params))

return queries
return [(session_query, session_params)]
56 changes: 8 additions & 48 deletions agents-api/agents_api/queries/sessions/update_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,6 @@
RETURNING *;
""").sql(pretty=True)

lookup_query = parse_one("""
WITH deleted_lookups AS (
DELETE FROM session_lookup
WHERE developer_id = $1 AND session_id = $2
)
INSERT INTO session_lookup (
developer_id,
session_id,
participant_type,
participant_id
)
SELECT
$1 as developer_id,
$2 as session_id,
unnest($3::participant_type[]) as participant_type,
unnest($4::uuid[]) as participant_id;
""").sql(pretty=True)


@rewrap_exceptions(
{
Expand All @@ -60,7 +42,14 @@
),
}
)
@wrap_in_class(ResourceUpdatedResponse, one=True)
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {
"id": d["session_id"],
"updated_at": d["updated_at"],
},
)
@increase_counter("update_session")
@pg_query
@beartype
Expand All @@ -81,26 +70,6 @@ async def update_session(
Returns:
list[tuple[str, list]]: List of SQL queries and their parameters
"""
# Handle participants
users = data.users or ([data.user] if data.user else [])
agents = data.agents or ([data.agent] if data.agent else [])

if not agents:
raise HTTPException(
status_code=400,
detail="At least one agent must be provided",
)

if data.agent and data.agents:
raise HTTPException(
status_code=400,
detail="Only one of 'agent' or 'agents' should be provided",
)

# Prepare participant arrays for lookup query
participant_types = ["user"] * len(users) + ["agent"] * len(agents)
participant_ids = [str(u) for u in users] + [str(a) for a in agents]

# Prepare session parameters
session_params = [
developer_id, # $1
Expand All @@ -115,15 +84,6 @@ async def update_session(
data.recall_options or {}, # $10
]

# Prepare lookup parameters
lookup_params = [
developer_id, # $1
session_id, # $2
participant_types, # $3
participant_ids, # $4
]

return [
(session_query, session_params),
(lookup_query, lookup_params),
]
17 changes: 9 additions & 8 deletions agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def pg_query(
debug: bool | None = None,
only_on_error: bool = False,
timeit: bool = False,
return_index: int = -1,
) -> Callable[..., Callable[P, list[Record]]] | Callable[P, list[Record]]:
def pg_query_dec(
func: Callable[P, PGQueryArgs | list[PGQueryArgs]],
Expand Down Expand Up @@ -159,6 +160,8 @@ async def wrapper(
async with pool.acquire() as conn:
async with conn.transaction():
start = timeit and time.perf_counter()
all_results = []

for method_name, payload in batch:
method = getattr(conn, method_name)

Expand All @@ -169,11 +172,7 @@ async def wrapper(
results: list[Record] = await method(
query, *args, timeout=timeout
)

print("%" * 100)
print(results)
print(*args)
print("%" * 100)
all_results.append(results)

if method_name == "fetchrow" and (
len(results) == 0 or results.get("bool") is None
Expand Down Expand Up @@ -204,9 +203,11 @@ async def wrapper(

raise

not only_on_error and debug and pprint(results)

return results
# Return results from specified index
results_to_return = all_results[return_index] if all_results else []
not only_on_error and debug and pprint(results_to_return)

return results_to_return

# Set the wrapped function as an attribute of the wrapper,
# forwards the __wrapped__ attribute if it exists.
Expand Down
10 changes: 4 additions & 6 deletions agents-api/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def patch_embed_acompletion():
yield embed, acompletion


@fixture(scope="global")
@fixture(scope="test")
async def test_agent(dsn=pg_dsn, developer=test_developer):
pool = await create_db_pool(dsn=dsn)

Expand All @@ -105,18 +105,16 @@ async def test_agent(dsn=pg_dsn, developer=test_developer):
data=CreateAgentRequest(
model="gpt-4o-mini",
name="test agent",
canonical_name=f"test_agent_{str(int(time.time()))}",
about="test agent about",
metadata={"test": "test"},
),
connection_pool=pool,
)

yield agent
await pool.close()
return agent


@fixture(scope="global")
@fixture(scope="test")
async def test_user(dsn=pg_dsn, developer=test_developer):
pool = await create_db_pool(dsn=dsn)

Expand Down Expand Up @@ -153,7 +151,7 @@ async def test_new_developer(dsn=pg_dsn, email=random_email):
return developer


@fixture(scope="global")
@fixture(scope="test")
async def test_session(
dsn=pg_dsn,
developer_id=test_developer_id,
Expand Down
5 changes: 3 additions & 2 deletions agents-api/tests/test_agent_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id):
)


@test("query: create agent with instructions sql")

@test("query: create or update agent sql")
async def _(dsn=pg_dsn, developer_id=test_developer_id):
"""Test that an agent can be successfully created or updated."""

Expand All @@ -60,6 +61,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id):
)



@test("query: update agent sql")
async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent):
"""Test that an existing agent's information can be successfully updated."""
Expand All @@ -81,7 +83,6 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent):
assert result is not None
assert isinstance(result, ResourceUpdatedResponse)


@test("query: get agent not exists sql")
async def _(dsn=pg_dsn, developer_id=test_developer_id):
"""Test that retrieving a non-existent agent raises an exception."""
Expand Down
Loading

0 comments on commit bbdbb4b

Please sign in to comment.