Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api,memory-store): Tasks queries #978

Merged
merged 6 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions agents-api/agents_api/autogen/Tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,21 @@ class CreateTaskRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: str
name: Annotated[str, Field(max_length=255, min_length=1)]
"""
The name of the task.
"""
canonical_name: Annotated[
str | None,
Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"),
] = None
"""
The canonical name of the task.
"""
description: str = ""
"""
The description of the task.
"""
main: Annotated[
list[
EvaluateStep
Expand Down Expand Up @@ -650,7 +663,21 @@ class PatchTaskRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str | None, Field(max_length=255, min_length=1)] = None
"""
The name of the task.
"""
canonical_name: Annotated[
str | None,
Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"),
] = None
"""
The canonical name of the task.
"""
description: str = ""
"""
The description of the task.
"""
main: Annotated[
list[
EvaluateStep
Expand Down Expand Up @@ -966,8 +993,21 @@ class Task(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: str
name: Annotated[str, Field(max_length=255, min_length=1)]
"""
The name of the task.
"""
canonical_name: Annotated[
str | None,
Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"),
] = None
"""
The canonical name of the task.
"""
description: str = ""
"""
The description of the task.
"""
main: Annotated[
list[
EvaluateStep
Expand Down Expand Up @@ -1124,7 +1164,21 @@ class UpdateTaskRequest(BaseModel):
model_config = ConfigDict(
populate_by_name=True,
)
name: Annotated[str, Field(max_length=255, min_length=1)]
"""
The name of the task.
"""
canonical_name: Annotated[
str | None,
Field(max_length=255, min_length=1, pattern="^[a-zA-Z][a-zA-Z0-9_]*$"),
] = None
"""
The canonical name of the task.
"""
description: str = ""
"""
The description of the task.
"""
main: Annotated[
list[
EvaluateStep
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/agents/create_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...autogen.openapi_model import Agent, CreateAgentRequest
from ...autogen.openapi_model import CreateAgentRequest, ResourceCreatedResponse
from ...metrics.counters import increase_counter
from ..utils import (
generate_canonical_name,
Expand Down Expand Up @@ -75,9 +75,9 @@
# }
# )
@wrap_in_class(
Agent,
ResourceCreatedResponse,
one=True,
transform=lambda d: {"id": d["agent_id"], **d},
transform=lambda d: {"id": d["agent_id"], "created_at": d["created_at"]},
)
@increase_counter("create_agent")
@pg_query
Expand Down
26 changes: 17 additions & 9 deletions agents-api/agents_api/queries/agents/create_or_update_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

# Define the raw SQL query
agent_query = parse_one("""
WITH existing_agent AS (
SELECT canonical_name
FROM agents
WHERE developer_id = $1 AND agent_id = $2
)
INSERT INTO agents (
developer_id,
agent_id,
Expand All @@ -30,15 +35,18 @@
default_settings
)
VALUES (
$1,
$2,
$3,
$4,
$5,
$6,
$7,
$8,
$9
$1, -- developer_id
$2, -- agent_id
COALESCE( -- canonical_name
(SELECT canonical_name FROM existing_agent),
$3
),
$4, -- name
$5, -- about
$6, -- instructions
$7, -- model
$8, -- metadata
$9 -- default_settings
)
RETURNING *;
""").sql(pretty=True)
Expand Down
8 changes: 6 additions & 2 deletions agents-api/agents_api/queries/developers/create_developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...common.protocol.developers import Developer
from ...autogen.openapi_model import ResourceCreatedResponse
from ..utils import (
partialclass,
pg_query,
Expand Down Expand Up @@ -43,7 +43,11 @@
)
}
)
@wrap_in_class(Developer, one=True, transform=lambda d: {**d, "id": d["developer_id"]})
@wrap_in_class(
ResourceCreatedResponse,
one=True,
transform=lambda d: {**d, "id": d["developer_id"], "created_at": d["created_at"]},
)
@pg_query
@beartype
async def create_developer(
Expand Down
9 changes: 7 additions & 2 deletions agents-api/agents_api/queries/entries/create_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from litellm.utils import _select_tokenizer as select_tokenizer
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateEntryRequest, Entry, Relation
from ...autogen.openapi_model import (
CreateEntryRequest,
Relation,
ResourceCreatedResponse,
)
from ...common.utils.datetime import utcnow
from ...common.utils.messages import content_to_json
from ...metrics.counters import increase_counter
Expand Down Expand Up @@ -79,9 +83,10 @@
}
)
@wrap_in_class(
Entry,
ResourceCreatedResponse,
transform=lambda d: {
"id": d.pop("entry_id"),
"created_at": d.pop("created_at"),
**d,
},
)
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/queries/entries/get_history.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from typing import Any, List, Tuple
from uuid import UUID

import asyncpg
Expand Down
12 changes: 4 additions & 8 deletions agents-api/agents_api/queries/entries/list_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@

# Query for checking if the session exists
session_exists_query = """
SELECT CASE
WHEN EXISTS (
SELECT 1 FROM sessions
WHERE session_id = $1 AND developer_id = $2
)
THEN TRUE
ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error
END;
SELECT EXISTS (
SELECT 1 FROM sessions
WHERE session_id = $1 AND developer_id = $2
) AS exists;
"""

list_entries_query = """
Expand Down
6 changes: 2 additions & 4 deletions agents-api/agents_api/queries/files/create_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@

import base64
import hashlib
from typing import Any, Literal
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one
from uuid_extensions import uuid7

from ...autogen.openapi_model import CreateFileRequest, File
from ...metrics.counters import increase_counter
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, wrap_in_class

# Create file
file_query = parse_one("""
Expand Down
4 changes: 1 addition & 3 deletions agents-api/agents_api/queries/files/get_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import File
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, wrap_in_class

# Define the raw SQL query
file_query = parse_one("""
Expand Down
5 changes: 2 additions & 3 deletions agents-api/agents_api/queries/files/list_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
It constructs and executes SQL queries to fetch a list of files based on developer ID with pagination.
"""

from typing import Any, Literal
from typing import Literal
from uuid import UUID

import asyncpg
from beartype import beartype
from fastapi import HTTPException
from sqlglot import parse_one

from ...autogen.openapi_model import File
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
from ..utils import pg_query, wrap_in_class

# Query to list all files for a developer (uses developer_id index)
developer_files_query = parse_one("""
Expand Down
7 changes: 4 additions & 3 deletions agents-api/agents_api/queries/sessions/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from uuid_extensions import uuid7

from ...autogen.openapi_model import (
CreateSessionRequest,
Session,
Create`Request,
ResourceCreatedResponse,
)
from ...metrics.counters import increase_counter
from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class
Expand Down Expand Up @@ -68,11 +68,12 @@
}
)
@wrap_in_class(
Session,
ResourceCreatedResponse,
one=True,
transform=lambda d: {
**d,
"id": d["session_id"],
"created_at": d["created_at"],
},
)
@increase_counter("create_session")
Expand Down
28 changes: 28 additions & 0 deletions agents-api/agents_api/queries/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
The `task` module within the `queries` package provides SQL query functions for managing tasks
in the TimescaleDB database. This includes operations for:

- Creating new tasks
- Updating existing tasks
- Retrieving task details
- Listing tasks with filtering and pagination
- Deleting tasks
"""

from .create_or_update_task import create_or_update_task
from .create_task import create_task
from .delete_task import delete_task
from .get_task import get_task
from .list_tasks import list_tasks
from .patch_task import patch_task
from .update_task import update_task

__all__ = [
"create_or_update_task",
"create_task",
"delete_task",
"get_task",
"list_tasks",
"patch_task",
"update_task",
]
Loading
Loading