diff --git a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py index a8a9dba1a..57453cd34 100644 --- a/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py +++ b/agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py @@ -1,93 +1,62 @@ from typing import Literal from uuid import UUID +import sqlvalidator from beartype import beartype +from ...exceptions import InvalidSQLQuery from ..utils import ( pg_query, wrap_in_class, ) +tools_args_for_task_query = sqlvalidator.parse( + """SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM tasks + WHERE task_id = $2 AND developer_id = $4 LIMIT 1 +) AS tasks_md""" +) -def tool_args_for_task( - *, - developer_id: UUID, - agent_id: UUID, - task_id: UUID, - tool_type: Literal["integration", "api_call"] = "integration", - arg_type: Literal["args", "setup"] = "args", -) -> tuple[list[str], dict]: - agent_id = str(agent_id) - task_id = str(task_id) - - get_query = f""" - input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]] - - ?[values] := - input[agent_id, task_id], - *tasks {{ - task_id, - metadata: task_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - task_{arg_type} = get(task_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, task_{arg_type}), - - :limit 1 - """ - - queries = [ - get_query, - ] - - return (queries, {"agent_id": agent_id, "task_id": task_id}) - - -def tool_args_for_session( - *, - developer_id: UUID, - session_id: UUID, - agent_id: UUID, - arg_type: Literal["args", "setup"] = "args", - tool_type: Literal["integration", "api_call"] = "integration", -) -> tuple[list[str], dict]: - session_id = str(session_id) - - get_query = f""" - input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]] - - ?[values] := - input[session_id, agent_id], - *sessions {{ - session_id, - metadata: session_metadata, - }}, - *agents {{ - agent_id, - metadata: agent_metadata, - }}, - session_{arg_type} = get(session_metadata, "x-{tool_type}-{arg_type}", {{}}), - agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}), - - # Right values overwrite left values - # See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat - values = concat(agent_{arg_type}, session_{arg_type}), - - :limit 1 - """ - - queries = [ - get_query, - ] +if not tools_args_for_task_query.is_valid(): + raise InvalidSQLQuery("tools_args_for_task_query") + +tool_args_for_session_query = sqlvalidator.parse( + """SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM ( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md + FROM agents + WHERE agent_id = $1 AND developer_id = $4 LIMIT 1 +) AS agents_md, +( + SELECT + CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args' + WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args' + WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup' + WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md + FROM sessions + WHERE session_id = $2 AND developer_id = $4 LIMIT 1 +) AS sessions_md""" +) - return (queries, {"agent_id": agent_id, "session_id": session_id}) +if not tool_args_for_session_query.is_valid(): + raise InvalidSQLQuery("tool_args_for_session") # @rewrap_exceptions( @@ -108,25 +77,23 @@ def get_tool_args_from_metadata( task_id: UUID | None = None, tool_type: Literal["integration", "api_call"] = "integration", arg_type: Literal["args", "setup", "headers"] = "args", -) -> tuple[list[str], dict]: - common: dict = dict( - developer_id=developer_id, - agent_id=agent_id, - tool_type=tool_type, - arg_type=arg_type, - ) - +) -> tuple[list[str], list]: match session_id, task_id: case (None, task_id) if task_id is not None: - return tool_args_for_task( - **common, - task_id=task_id, + return ( + tools_args_for_task_query.format(), + [ + agent_id, + task_id, + f"x-{tool_type}-{arg_type}", + developer_id, + ], ) case (session_id, None) if session_id is not None: - return tool_args_for_session( - **common, - session_id=session_id, + return ( + tool_args_for_session_query.format(), + [agent_id, session_id, f"x-{tool_type}-{arg_type}", developer_id], ) case (_, _):