Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahmad-mtos committed Dec 18, 2024
1 parent 46a3b39 commit 1b7a022
Show file tree
Hide file tree
Showing 21 changed files with 439 additions and 213 deletions.
40 changes: 40 additions & 0 deletions agents-api/agents_api/autogen/Sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel):
"""
A specific situation that sets the background for this session
"""
system_template: str | None = None
"""
System prompt for this session
"""
render_templates: StrictBool = True
"""
Render system and assistant message content as jinja templates
Expand All @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel):
If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
"""
forward_tool_calls: StrictBool = False
"""
Whether to forward tool calls to the model
"""
recall_options: RecallOptions | None = None
metadata: dict[str, Any] | None = None

Expand All @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel):
"""
A specific situation that sets the background for this session
"""
system_template: str | None = None
"""
System prompt for this session
"""
render_templates: StrictBool = True
"""
Render system and assistant message content as jinja templates
Expand All @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel):
If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
"""
forward_tool_calls: StrictBool = False
"""
Whether to forward tool calls to the model
"""
recall_options: RecallOptionsUpdate | None = None
metadata: dict[str, Any] | None = None

Expand Down Expand Up @@ -121,6 +137,10 @@ class Session(BaseModel):
"""
A specific situation that sets the background for this session
"""
system_template: str | None = None
"""
System prompt for this session
"""
summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None
"""
Summary (null at the beginning) - generated automatically after every interaction
Expand All @@ -145,6 +165,10 @@ class Session(BaseModel):
If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
"""
forward_tool_calls: StrictBool = False
"""
Whether to forward tool calls to the model
"""
recall_options: RecallOptions | None = None
id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})]
metadata: dict[str, Any] | None = None
Expand Down Expand Up @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel):
"""
A specific situation that sets the background for this session
"""
system_template: str | None = None
"""
System prompt for this session
"""
render_templates: StrictBool = True
"""
Render system and assistant message content as jinja templates
Expand All @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel):
If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
"""
forward_tool_calls: StrictBool = False
"""
Whether to forward tool calls to the model
"""
recall_options: RecallOptions | None = None
metadata: dict[str, Any] | None = None

Expand All @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest):
"""
A specific situation that sets the background for this session
"""
system_template: str | None = None
"""
System prompt for this session
"""
render_templates: StrictBool = True
"""
Render system and assistant message content as jinja templates
Expand All @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest):
If a tool call is made, the tool's output will be sent back to the model as the model's input.
If a tool call is not made, the model's output will be returned as is.
"""
forward_tool_calls: StrictBool = False
"""
Whether to forward tool calls to the model
"""
recall_options: RecallOptions | None = None
metadata: dict[str, Any] | None = None

Expand Down
7 changes: 3 additions & 4 deletions agents-api/agents_api/queries/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from pydantic import ValidationError
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...metrics.counters import increase_counter

from ...autogen.openapi_model import Agent, CreateAgentRequest
from ..utils import (
generate_canonical_name,
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest
from ...metrics.counters import increase_counter
from ..utils import (
generate_canonical_name,
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
7 changes: 3 additions & 4 deletions agents-api/agents_api/queries/agents/delete_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import ResourceDeletedResponse
from ...metrics.counters import increase_counter
from ...common.utils.datetime import utcnow
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
5 changes: 1 addition & 4 deletions agents-api/agents_api/queries/agents/get_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import Agent
from ...metrics.counters import increase_counter
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

raw_query = """
Expand Down
6 changes: 2 additions & 4 deletions agents-api/agents_api/queries/agents/list_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import Agent
from ...metrics.counters import increase_counter
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
5 changes: 1 addition & 4 deletions agents-api/agents_api/queries/agents/patch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
from ...metrics.counters import increase_counter
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
5 changes: 1 addition & 4 deletions agents-api/agents_api/queries/agents/update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from uuid import UUID

from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from sqlglot.optimizer import optimize

from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
from ...metrics.counters import increase_counter
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

ModelT = TypeVar("ModelT", bound=Any)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/developers/get_developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ..utils import (
partialclass,
pg_query,
rewrap_exceptions,
wrap_in_class,
rewrap_exceptions,
)

# TODO: Add verify_developer
Expand Down
18 changes: 9 additions & 9 deletions agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,10 @@

# Query for checking if the session exists
session_exists_query = """
SELECT CASE
WHEN EXISTS (
SELECT 1 FROM sessions
WHERE session_id = $1 AND developer_id = $2
)
THEN TRUE
ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error
END;
SELECT EXISTS (
SELECT 1 FROM sessions
WHERE session_id = $1 AND developer_id = $2
) AS exists;
"""

# Define the raw SQL query for creating entries
Expand Down Expand Up @@ -71,6 +67,10 @@
status_code=400,
detail=str(exc),
),
asyncpg.NoDataFoundError: lambda exc: HTTPException(
status_code=404,
detail="Session not found",
),
}
)
@wrap_in_class(
Expand Down Expand Up @@ -166,7 +166,7 @@ async def add_entry_relations(
item.get("is_leaf", False), # $5
]
)

return [
(
session_exists_query,
Expand Down
10 changes: 7 additions & 3 deletions agents-api/agents_api/queries/entries/list_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
status_code=400,
detail=str(exc),
),
asyncpg.NoDataFoundError: lambda exc: HTTPException(
status_code=404,
detail="Session not found",
),
}
)
@wrap_in_class(Entry)
Expand All @@ -78,7 +82,7 @@ async def list_entries(
sort_by: Literal["created_at", "timestamp"] = "timestamp",
direction: Literal["asc", "desc"] = "asc",
exclude_relations: list[str] = [],
) -> list[tuple[str, list]]:
) -> list[tuple[str, list] | tuple[str, list, str]]:
if limit < 1 or limit > 1000:
raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000")
if offset < 0:
Expand All @@ -98,14 +102,14 @@ async def list_entries(
developer_id, # $5
exclude_relations, # $6
]

return [
(
session_exists_query,
[session_id, developer_id],
"fetchrow",
),
(
query,
entry_params,
entry_params
),
]
28 changes: 12 additions & 16 deletions agents-api/agents_api/queries/sessions/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,7 @@
participant_type,
participant_id
)
SELECT
$1 as developer_id,
$2 as session_id,
unnest($3::participant_type[]) as participant_type,
unnest($4::uuid[]) as participant_id;
VALUES ($1, $2, $3, $4);
""").sql(pretty=True)


Expand All @@ -67,7 +63,7 @@
),
}
)
@wrap_in_class(Session, one=True, transform=lambda d: {**d, "id": d["session_id"]})
@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]})
@increase_counter("create_session")
@pg_query
@beartype
Expand All @@ -76,7 +72,7 @@ async def create_session(
developer_id: UUID,
session_id: UUID,
data: CreateSessionRequest,
) -> list[tuple[str, list]]:
) -> list[tuple[str, list] | tuple[str, list, str]]:
"""
Constructs SQL queries to create a new session and its participant lookups.
Expand All @@ -86,7 +82,7 @@ async def create_session(
data (CreateSessionRequest): Session creation data
Returns:
list[tuple[str, list]]: SQL queries and their parameters
list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters
"""
# Handle participants
users = data.users or ([data.user] if data.user else [])
Expand Down Expand Up @@ -122,15 +118,15 @@ async def create_session(
data.recall_options or {}, # $10
]

# Prepare lookup parameters
lookup_params = [
developer_id, # $1
session_id, # $2
participant_types, # $3
participant_ids, # $4
]
# Prepare lookup parameters as a list of parameter lists
lookup_params = []
for ptype, pid in zip(participant_types, participant_ids):
lookup_params.append([developer_id, session_id, ptype, pid])

print("*" * 100)
print(lookup_params)
print("*" * 100)
return [
(session_query, session_params),
(lookup_query, lookup_params),
(lookup_query, lookup_params, "fetchmany"),
]
Loading

0 comments on commit 1b7a022

Please sign in to comment.