Skip to content

Commit

Permalink
feat: Add patch tool query
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 20, 2024
1 parent 83f58ac commit 59b24ac
Showing 1 changed file with 43 additions and 51 deletions.
94 changes: 43 additions & 51 deletions agents-api/agents_api/queries/tools/patch_tool.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,61 @@
from typing import Any, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import PatchToolRequest, ResourceUpdatedResponse
from ...common.utils.cozo import cozo_process_mutate_data
from ...exceptions import InvalidSQLQuery
from ...metrics.counters import increase_counter
from ..utils import (
cozo_query,
partialclass,
rewrap_exceptions,
verify_developer_id_query,
verify_developer_owns_resource_query,
pg_query,
wrap_in_class,
)

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
ValidationError: partialclass(HTTPException, status_code=400),
TypeError: partialclass(HTTPException, status_code=400),
}
sql_query = sqlvalidator.parse("""
WITH updated_tools AS (
UPDATE tools
SET
type = COALESCE($4, type),
name = COALESCE($5, name),
description = COALESCE($6, description),
spec = COALESCE($7, spec)
WHERE
developer_id = $1 AND
agent_id = $2 AND
tool_id = $3
RETURNING *
)
SELECT * FROM updated_tools;
""")

if not sql_query.is_valid():
raise InvalidSQLQuery("patch_tool")


# @rewrap_exceptions(
# {
# QueryException: partialclass(HTTPException, status_code=400),
# ValidationError: partialclass(HTTPException, status_code=400),
# TypeError: partialclass(HTTPException, status_code=400),
# }
# )
@wrap_in_class(
ResourceUpdatedResponse,
one=True,
transform=lambda d: {"id": d["tool_id"], "jobs": [], **d},
_kind="inserted",
)
@cozo_query
@pg_query
@increase_counter("patch_tool")
@beartype
def patch_tool(
*, developer_id: UUID, agent_id: UUID, tool_id: UUID, data: PatchToolRequest
) -> tuple[list[str], dict]:
) -> tuple[list[str], list]:
"""
Execute the datalog query and return the results as a DataFrame
Updates the tool information for a given agent and tool ID in the 'cozodb' database.
Expand All @@ -54,6 +69,7 @@ def patch_tool(
ResourceUpdatedResponse: The updated tool data.
"""

developer_id = str(developer_id)
agent_id = str(agent_id)
tool_id = str(tool_id)

Expand All @@ -78,39 +94,15 @@ def patch_tool(
if tool_spec:
del patch_data[tool_type]

tool_cols, tool_vals = cozo_process_mutate_data(
{
**patch_data,
"agent_id": agent_id,
"tool_id": tool_id,
}
)

# Construct the datalog query for updating the tool information
patch_query = f"""
input[{tool_cols}] <- $input
?[{tool_cols}, spec, updated_at] :=
*tools {{
agent_id: to_uuid($agent_id),
tool_id: to_uuid($tool_id),
spec: old_spec,
}},
input[{tool_cols}],
spec = concat(old_spec, $spec),
updated_at = now()
:update tools {{ {tool_cols}, spec, updated_at }}
:returning
"""

queries = [
verify_developer_id_query(developer_id),
verify_developer_owns_resource_query(developer_id, "agents", agent_id=agent_id),
patch_query,
]

return (
queries,
dict(input=tool_vals, spec=tool_spec, agent_id=agent_id, tool_id=tool_id),
sql_query.format(),
[
developer_id,
agent_id,
tool_id,
tool_type,
data.name,
data.description,
tool_spec,
],
)

0 comments on commit 59b24ac

Please sign in to comment.