Skip to content

Commit

Permalink
feat: Add SQL validation
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 20, 2024
1 parent ca12d65 commit 0aecd61
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 43 deletions.
9 changes: 9 additions & 0 deletions agents-api/agents_api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,12 @@ class FailedEncodingSentinel:
"""Sentinel object returned when failed to encode payload."""

payload_data: bytes


class QueriesBaseException(AgentsBaseException):
pass


class InvalidSQLQuery(QueriesBaseException):
def __init__(self, query_name: str):
super().__init__(f"invalid query: {query_name}")
90 changes: 47 additions & 43 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Any, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype

from ...common.protocol.sessions import ChatContext, make_session
from ...exceptions import InvalidSQLQuery
from ..utils import (
pg_query,
wrap_in_class,
Expand All @@ -13,19 +15,19 @@
T = TypeVar("T")


query = """
SELECT * FROM
sql_query = sqlvalidator.parse(
"""SELECT * FROM
(
SELECT jsonb_agg(u) AS users FROM (
SELECT
session_lookup.participant_id,
users.user_id AS id,
users.developer_id,
users.name,
users.about,
users.created_at,
users.updated_at,
users.metadata
users.developer_id,
users.name,
users.about,
users.created_at,
users.updated_at,
users.metadata
FROM session_lookup
INNER JOIN users ON session_lookup.participant_id = users.user_id
WHERE
Expand All @@ -39,16 +41,16 @@
SELECT
session_lookup.participant_id,
agents.agent_id AS id,
agents.developer_id,
agents.canonical_name,
agents.name,
agents.about,
agents.instructions,
agents.model,
agents.created_at,
agents.updated_at,
agents.metadata,
agents.default_settings
agents.developer_id,
agents.canonical_name,
agents.name,
agents.about,
agents.instructions,
agents.model,
agents.created_at,
agents.updated_at,
agents.metadata,
agents.default_settings
FROM session_lookup
INNER JOIN agents ON session_lookup.participant_id = agents.agent_id
WHERE
Expand All @@ -58,50 +60,52 @@
) a
) AS agents,
(
SELECT to_jsonb(s) AS session FROM (
SELECT to_jsonb(s) AS session FROM (
SELECT
sessions.session_id AS id,
sessions.developer_id,
sessions.situation,
sessions.system_template,
sessions.created_at,
sessions.metadata,
sessions.render_templates,
sessions.token_budget,
sessions.context_overflow,
sessions.forward_tool_calls,
sessions.recall_options
sessions.developer_id,
sessions.situation,
sessions.system_template,
sessions.created_at,
sessions.metadata,
sessions.render_templates,
sessions.token_budget,
sessions.context_overflow,
sessions.forward_tool_calls,
sessions.recall_options
FROM sessions
WHERE
developer_id = $1 AND
session_id = $2
LIMIT 1
LIMIT 1
) s
) AS session,
(
SELECT jsonb_agg(r) AS toolsets FROM (
SELECT
session_lookup.participant_id,
tools.tool_id as id,
tools.developer_id,
tools.agent_id,
tools.task_id,
tools.task_version,
tools.type,
tools.name,
tools.description,
tools.spec,
tools.updated_at,
tools.created_at
tools.developer_id,
tools.agent_id,
tools.task_id,
tools.task_version,
tools.type,
tools.name,
tools.description,
tools.spec,
tools.updated_at,
tools.created_at
FROM session_lookup
INNER JOIN tools ON session_lookup.participant_id = tools.agent_id
WHERE
session_lookup.developer_id = $1 AND
session_id = $2 AND
session_lookup.participant_type = 'agent'
) r
) AS toolsets
"""
) AS toolsets"""
)
if not sql_query.is_valid():
raise InvalidSQLQuery("prepare_chat_context")


def _transform(d):
Expand Down Expand Up @@ -160,6 +164,6 @@ async def prepare_chat_context(
"""

return (
[query],
[sql_query.format()],
[developer_id, session_id],
)

0 comments on commit 0aecd61

Please sign in to comment.