Skip to content

Commit

Permalink
feat: Add tools args from metadata query
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 23, 2024
1 parent 0252a88 commit d209e77
Showing 1 changed file with 59 additions and 92 deletions.
151 changes: 59 additions & 92 deletions agents-api/agents_api/queries/tools/get_tool_args_from_metadata.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,62 @@
from typing import Literal
from uuid import UUID

import sqlvalidator
from beartype import beartype

from ...exceptions import InvalidSQLQuery
from ..utils import (
pg_query,
wrap_in_class,
)

tools_args_for_task_query = sqlvalidator.parse(
"""SELECT COALESCE(agents_md || tasks_md, agents_md, tasks_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup'
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md
FROM agents
WHERE agent_id = $1 AND developer_id = $4 LIMIT 1
) AS agents_md,
(
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup'
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM tasks
WHERE task_id = $2 AND developer_id = $4 LIMIT 1
) AS tasks_md"""
)

def tool_args_for_task(
*,
developer_id: UUID,
agent_id: UUID,
task_id: UUID,
tool_type: Literal["integration", "api_call"] = "integration",
arg_type: Literal["args", "setup"] = "args",
) -> tuple[list[str], dict]:
agent_id = str(agent_id)
task_id = str(task_id)

get_query = f"""
input[agent_id, task_id] <- [[to_uuid($agent_id), to_uuid($task_id)]]
?[values] :=
input[agent_id, task_id],
*tasks {{
task_id,
metadata: task_metadata,
}},
*agents {{
agent_id,
metadata: agent_metadata,
}},
task_{arg_type} = get(task_metadata, "x-{tool_type}-{arg_type}", {{}}),
agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}),
# Right values overwrite left values
# See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat
values = concat(agent_{arg_type}, task_{arg_type}),
:limit 1
"""

queries = [
get_query,
]

return (queries, {"agent_id": agent_id, "task_id": task_id})


def tool_args_for_session(
*,
developer_id: UUID,
session_id: UUID,
agent_id: UUID,
arg_type: Literal["args", "setup"] = "args",
tool_type: Literal["integration", "api_call"] = "integration",
) -> tuple[list[str], dict]:
session_id = str(session_id)

get_query = f"""
input[session_id, agent_id] <- [[to_uuid($session_id), to_uuid($agent_id)]]
?[values] :=
input[session_id, agent_id],
*sessions {{
session_id,
metadata: session_metadata,
}},
*agents {{
agent_id,
metadata: agent_metadata,
}},
session_{arg_type} = get(session_metadata, "x-{tool_type}-{arg_type}", {{}}),
agent_{arg_type} = get(agent_metadata, "x-{tool_type}-{arg_type}", {{}}),
# Right values overwrite left values
# See: https://docs.cozodb.org/en/latest/functions.html#Func.Vector.concat
values = concat(agent_{arg_type}, session_{arg_type}),
:limit 1
"""

queries = [
get_query,
]
if not tools_args_for_task_query.is_valid():
raise InvalidSQLQuery("tools_args_for_task_query")

tool_args_for_session_query = sqlvalidator.parse(
"""SELECT COALESCE(agents_md || sessions_md, agents_md, sessions_md, '{}') as values FROM (
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup'
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS agents_md
FROM agents
WHERE agent_id = $1 AND developer_id = $4 LIMIT 1
) AS agents_md,
(
SELECT
CASE WHEN $3 = 'x-integrations-args' then metadata->'x-integrations-args'
WHEN $3 = 'x-api_call-args' then metadata->'x-api_call-args'
WHEN $3 = 'x-integrations-setup' then metadata->'x-integrations-setup'
WHEN $3 = 'x-api_call-setup' then metadata->'x-api_call-setup' END AS tasks_md
FROM sessions
WHERE session_id = $2 AND developer_id = $4 LIMIT 1
) AS sessions_md"""
)

return (queries, {"agent_id": agent_id, "session_id": session_id})
if not tool_args_for_session_query.is_valid():
raise InvalidSQLQuery("tool_args_for_session")


# @rewrap_exceptions(
Expand All @@ -108,25 +77,23 @@ def get_tool_args_from_metadata(
task_id: UUID | None = None,
tool_type: Literal["integration", "api_call"] = "integration",
arg_type: Literal["args", "setup", "headers"] = "args",
) -> tuple[list[str], dict]:
common: dict = dict(
developer_id=developer_id,
agent_id=agent_id,
tool_type=tool_type,
arg_type=arg_type,
)

) -> tuple[list[str], list]:
match session_id, task_id:
case (None, task_id) if task_id is not None:
return tool_args_for_task(
**common,
task_id=task_id,
return (
tools_args_for_task_query.format(),
[
agent_id,
task_id,
f"x-{tool_type}-{arg_type}",
developer_id,
],
)

case (session_id, None) if session_id is not None:
return tool_args_for_session(
**common,
session_id=session_id,
return (
tool_args_for_session_query.format(),
[agent_id, session_id, f"x-{tool_type}-{arg_type}", developer_id],
)

case (_, _):
Expand Down

0 comments on commit d209e77

Please sign in to comment.