Skip to content

Commit

Permalink
feat: Add developer queries, create db connection pool
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 17, 2024
1 parent 6c37070 commit 8901a59
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 53 deletions.
35 changes: 35 additions & 0 deletions agents-api/agents_api/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
import asyncpg
from contextlib import asynccontextmanager
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from .env import api_prefix, db_dsn
from .clients.pg import create_db_pool


@asynccontextmanager
async def lifespan(app: FastAPI):
app.state.postgres_pool = await create_db_pool(db_dsn)
yield
await app.state.postgres_pool.close()


app: FastAPI = FastAPI(
docs_url="/swagger",
openapi_prefix=api_prefix,
redoc_url=None,
title="Julep Agents API",
description="API for Julep Agents",
version="0.4.0",
terms_of_service="https://www.julep.ai/terms",
contact={
"name": "Julep",
"url": "https://www.julep.ai",
"email": "[email protected]",
},
root_path=api_prefix,
lifespan=lifespan,
)

# Enable metrics
Instrumentator().instrument(app).expose(app, include_in_schema=False)
24 changes: 9 additions & 15 deletions agents-api/agents_api/clients/pg.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import json

import asyncpg

from ..env import db_dsn
from ..web import app

async def _init_conn(conn):
await conn.set_type_codec(
"jsonb",
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)

async def get_pg_client():
# TODO: Create a postgres connection pool
client = getattr(app.state, "pg_client", await asyncpg.connect(db_dsn))
if not hasattr(app.state, "pg_client"):
await client.set_type_codec(
"jsonb",
encoder=json.dumps,
decoder=json.loads,
schema="pg_catalog",
)
app.state.pg_client = client

return client
async def create_db_pool(dsn: str):
return await asyncpg.create_pool(dsn, init=_init_conn)
4 changes: 1 addition & 3 deletions agents-api/agents_api/dependencies/developer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..common.protocol.developers import Developer
from ..env import multi_tenant_mode
from ..queries.developer.get_developer import get_developer, verify_developer
from ..queries.developer.get_developer import get_developer
from .exceptions import InvalidHeaderFormat


Expand All @@ -24,8 +24,6 @@ async def get_developer_id(
except ValueError as e:
raise InvalidHeaderFormat("X-Developer-Id must be a valid UUID") from e

verify_developer(developer_id=x_developer_id)

return x_developer_id


Expand Down
5 changes: 4 additions & 1 deletion agents-api/agents_api/queries/developer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,7 @@

# ruff: noqa: F401, F403, F405

from .get_developer import get_developer, verify_developer
from .get_developer import get_developer
from .create_developer import create_developer
from .update_developer import update_developer
from .patch_developer import patch_developer
54 changes: 54 additions & 0 deletions agents-api/agents_api/queries/developer/create_developer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...common.protocol.developers import Developer
from ..utils import (
pg_query,
wrap_in_class,
)

query = parse_one("""
INSERT INTO developers (
developer_id,
email,
active,
tags,
settings
)
VALUES (
$1,
$2,
$3,
$4,
$5::jsonb
)
RETURNING *;
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=403),
# ValidationError: partialclass(HTTPException, status_code=500),
# }
# )
@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]})
@pg_query
@beartype
async def create_developer(
*,
email: str,
active: bool = True,
tags: list[str] | None = None,
settings: dict | None = None,
developer_id: UUID | None = None,
) -> tuple[str, list]:
developer_id = str(developer_id or uuid7())

return (
query,
[developer_id, email, active, tags or [], settings or {}],
)
42 changes: 42 additions & 0 deletions agents-api/agents_api/queries/developer/patch_developer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ...common.protocol.developers import Developer
from ..utils import (
pg_query,
wrap_in_class,
)

query = parse_one("""
UPDATE developers
SET email = $1, active = $2, tags = tags || $3, settings = settings || $4
WHERE developer_id = $5
RETURNING *;
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=403),
# ValidationError: partialclass(HTTPException, status_code=500),
# }
# )
@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]})
@pg_query
@beartype
async def patch_developer(
*,
developer_id: UUID,
email: str,
active: bool = True,
tags: list[str] | None = None,
settings: dict | None = None,
) -> tuple[str, list]:
developer_id = str(developer_id)

return (
query,
[email, active, tags or [], settings or {}, developer_id],
)
42 changes: 42 additions & 0 deletions agents-api/agents_api/queries/developer/update_developer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from uuid import UUID

from beartype import beartype
from sqlglot import parse_one

from ...common.protocol.developers import Developer
from ..utils import (
pg_query,
wrap_in_class,
)

query = parse_one("""
UPDATE developers
SET email = $1, active = $2, tags = $3, settings = $4
WHERE developer_id = $5
RETURNING *;
""").sql(pretty=True)


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=403),
# ValidationError: partialclass(HTTPException, status_code=500),
# }
# )
@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]})
@pg_query
@beartype
async def update_developer(
*,
developer_id: UUID,
email: str,
active: bool = True,
tags: list[str] | None = None,
settings: dict | None = None,
) -> tuple[str, list]:
developer_id = str(developer_id)

return (
query,
[email, active, tags or [], settings or {}, developer_id],
)
34 changes: 18 additions & 16 deletions agents-api/agents_api/queries/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import concurrent.futures
import inspect
import socket
import asyncpg
import time
from functools import partialmethod, wraps
from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar
from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeVar, cast

import pandas as pd
from asyncpg import Record
from fastapi import HTTPException
from pydantic import BaseModel
from ..app import app

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -44,12 +46,12 @@ def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):

from pprint import pprint

from tenacity import (
retry,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)
# from tenacity import (
# retry,
# retry_if_exception,
# stop_after_attempt,
# wait_exponential,
# )

# TODO: Remove all tenacity decorators
# @retry(
Expand All @@ -59,7 +61,7 @@ def pg_query_dec(func: Callable[P, tuple[str | list[Any], dict]]):
# )
@wraps(func)
async def wrapper(
*args: P.args, client=None, **kwargs: P.kwargs
*args: P.args, connection_pool: asyncpg.Pool | None =None, **kwargs: P.kwargs
) -> list[Record]:
query, variables = await func(*args, **kwargs)

Expand All @@ -71,16 +73,16 @@ async def wrapper(
)

# Run the query
from ..clients import pg

try:
client = client or await pg.get_pg_client()

start = timeit and time.perf_counter()
results: list[Record] = await client.fetch(query, *variables)
end = timeit and time.perf_counter()

timeit and print(f"PostgreSQL query time: {end - start:.2f} seconds")
pool = connection_pool if connection_pool is not None else cast(asyncpg.Pool, app.state.postgres_pool)
async with pool.acquire() as conn:
async with conn.transaction():
start = timeit and time.perf_counter()
results: list[Record] = await conn.fetch(query, *variables)
end = timeit and time.perf_counter()

timeit and print(f"PostgreSQL query time: {end - start:.2f} seconds")

except Exception as e:
if only_on_error and debug:
Expand Down
19 changes: 1 addition & 18 deletions agents-api/agents_api/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from litellm.exceptions import APIError
from prometheus_fastapi_instrumentator import Instrumentator
from pycozo.client import QueryException
from pydantic import ValidationError
from scalar_fastapi import get_scalar_api_reference
from temporalio.service import RPCError

from .app import app
from .common.exceptions import BaseCommonException
from .dependencies.auth import get_api_key
from .env import api_prefix, hostname, protocol, public_port, sentry_dsn
Expand Down Expand Up @@ -144,24 +144,7 @@ def register_exceptions(app: FastAPI) -> None:
# Because some routes don't require auth
# See: https://fastapi.tiangolo.com/tutorial/bigger-applications/
#
app: FastAPI = FastAPI(
docs_url="/swagger",
openapi_prefix=api_prefix,
redoc_url=None,
title="Julep Agents API",
description="API for Julep Agents",
version="0.4.0",
terms_of_service="https://www.julep.ai/terms",
contact={
"name": "Julep",
"url": "https://www.julep.ai",
"email": "[email protected]",
},
root_path=api_prefix,
)

# Enable metrics
Instrumentator().instrument(app).expose(app, include_in_schema=False)

# Create a new router for the docs
scalar_router = APIRouter()
Expand Down

0 comments on commit 8901a59

Please sign in to comment.