Skip to content

Commit

Permalink
Merge branch 'f/migrate-queries-to-pg' into f/add-user-queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Vedantsahai18 committed Dec 16, 2024
2 parents 55500d9 + 1907787 commit ceca67a
Show file tree
Hide file tree
Showing 10 changed files with 552 additions and 118 deletions.
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..common.storage_handler import auto_blob_store, load_from_blob_store_if_remote
from ..env import testing
from ..models.developer import get_developer
from ..queries.developer import get_developer
from .utils import get_handler

# For running synchronous code in the background
Expand Down Expand Up @@ -94,7 +94,7 @@ async def execute_system(

# Handle chat operations
if system.operation == "chat" and system.resource == "session":
developer = get_developer(developer_id=arguments.get("developer_id"))
developer = await get_developer(developer_id=arguments.get("developer_id"))
session_id = arguments.get("session_id")
x_custom_api_key = arguments.get("x_custom_api_key", None)
chat_input = ChatInput(**arguments)
Expand Down
12 changes: 12 additions & 0 deletions agents-api/agents_api/clients/pg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import asyncpg

from ..env import db_dsn
from ..web import app


async def get_pg_client():
client = getattr(app.state, "pg_client", await asyncpg.connect(db_dsn))
if not hasattr(app.state, "pg_client"):
app.state.pg_client = client

return client
8 changes: 5 additions & 3 deletions agents-api/agents_api/dependencies/developer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..common.protocol.developers import Developer
from ..env import multi_tenant_mode
from ..models.developer.get_developer import get_developer, verify_developer
from ..queries.developer.get_developer import get_developer, verify_developer
from .exceptions import InvalidHeaderFormat


Expand Down Expand Up @@ -36,7 +36,9 @@ async def get_developer_data(
assert (
not x_developer_id
), "X-Developer-Id header not allowed in multi-tenant mode"
return get_developer(developer_id=UUID("00000000-0000-0000-0000-000000000000"))
return await get_developer(
developer_id=UUID("00000000-0000-0000-0000-000000000000")
)

if not x_developer_id:
raise InvalidHeaderFormat("X-Developer-Id header required")
Expand All @@ -47,6 +49,6 @@ async def get_developer_data(
except ValueError as e:
raise InvalidHeaderFormat("X-Developer-Id must be a valid UUID") from e

developer = get_developer(developer_id=x_developer_id)
developer = await get_developer(developer_id=x_developer_id)

return developer
7 changes: 7 additions & 0 deletions agents-api/agents_api/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@
"DO_VERIFY_DEVELOPER_OWNS_RESOURCE", default=True
)

# PostgreSQL
# ----
db_dsn: str = env.str(
"DB_DSN",
default="postgres://postgres:[email protected]:5432/postgres?sslmode=disable",
)


# Auth
# ----
Expand Down
9 changes: 5 additions & 4 deletions agents-api/agents_api/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from uuid import UUID

import pandas as pd
from asyncpg import Record
from fastapi import HTTPException
from httpcore import ConnectError, NetworkError, TimeoutException
from httpx import ConnectError as HttpxConnectError
Expand Down Expand Up @@ -463,12 +464,12 @@ def wrap_in_class(
transform: Callable[[dict], dict] | None = None,
_kind: str | None = None,
):
def _return_data(df: pd.DataFrame):
def _return_data(rec: Record):
# Convert df to list of dicts
if _kind:
df = df[df["_kind"] == _kind]
# if _kind:
# rec = rec[rec["_kind"] == _kind]

data = df.to_dict(orient="records")
data = list(rec.items())

nonlocal transform
transform = transform or (lambda x: x)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError
from sqlglot import parse_one

from ...common.protocol.developers import Developer
from ..utils import (
cozo_query,
partialclass,
pg_query,
rewrap_exceptions,
verify_developer_id_query,
wrap_in_class,
)

query = parse_one("SELECT * FROM developers WHERE developer_id = $1").sql(pretty=True)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")

Expand All @@ -38,37 +42,15 @@ def verify_developer(
}
)
@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]})
@cozo_query
@pg_query
@beartype
def get_developer(
async def get_developer(
*,
developer_id: UUID,
) -> tuple[str, dict]:
) -> tuple[str, list]:
developer_id = str(developer_id)

query = """
input[developer_id] <- [[to_uuid($developer_id)]]
?[
developer_id,
email,
active,
tags,
settings,
created_at,
updated_at,
] :=
input[developer_id],
*developers {
developer_id,
email,
active,
tags,
settings,
created_at,
updated_at,
}
:limit 1
"""

return (query, {"developer_id": developer_id})
return (
query,
[developer_id],
)
Loading

0 comments on commit ceca67a

Please sign in to comment.