diff --git a/agents-api/agents_api/autogen/Sessions.py b/agents-api/agents_api/autogen/Sessions.py index 460fd25ce..e2a9ce164 100644 --- a/agents-api/agents_api/autogen/Sessions.py +++ b/agents-api/agents_api/autogen/Sessions.py @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -121,6 +137,10 @@ class Session(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ Summary (null at the beginning) - generated automatically after every interaction @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/agents-api/agents_api/queries/agents/create_agent.py b/agents-api/agents_api/queries/agents/create_agent.py index 81a408f30..bb111b0df 100644 --- a/agents-api/agents_api/queries/agents/create_agent.py +++ b/agents-api/agents_api/queries/agents/create_agent.py @@ -7,18 +7,17 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException -from pydantic import ValidationError from sqlglot import parse_one from uuid_extensions import uuid7 +from ...metrics.counters import increase_counter + from ...autogen.openapi_model import Agent, CreateAgentRequest from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/create_or_update_agent.py b/agents-api/agents_api/queries/agents/create_or_update_agent.py index d74cd57c2..6cfb83767 100644 --- a/agents-api/agents_api/queries/agents/create_or_update_agent.py +++ b/agents-api/agents_api/queries/agents/create_or_update_agent.py @@ -7,17 +7,15 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent, CreateOrUpdateAgentRequest +from ...metrics.counters import increase_counter from ..utils import ( generate_canonical_name, - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/delete_agent.py b/agents-api/agents_api/queries/agents/delete_agent.py index db4a3ab4f..9c3ee5585 100644 --- a/agents-api/agents_api/queries/agents/delete_agent.py +++ b/agents-api/agents_api/queries/agents/delete_agent.py @@ -7,16 +7,15 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceDeletedResponse +from ...metrics.counters import increase_counter +from ...common.utils.datetime import utcnow from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/get_agent.py b/agents-api/agents_api/queries/agents/get_agent.py index a9893d747..dce424771 100644 --- a/agents-api/agents_api/queries/agents/get_agent.py +++ b/agents-api/agents_api/queries/agents/get_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) raw_query = """ diff --git a/agents-api/agents_api/queries/agents/list_agents.py b/agents-api/agents_api/queries/agents/list_agents.py index 69e91f206..3698c68f1 100644 --- a/agents-api/agents_api/queries/agents/list_agents.py +++ b/agents-api/agents_api/queries/agents/list_agents.py @@ -8,15 +8,13 @@ from beartype import beartype from fastapi import HTTPException -from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import Agent +from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/patch_agent.py b/agents-api/agents_api/queries/agents/patch_agent.py index d2a172838..6f9cb3b9c 100644 --- a/agents-api/agents_api/queries/agents/patch_agent.py +++ b/agents-api/agents_api/queries/agents/patch_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/agents/update_agent.py b/agents-api/agents_api/queries/agents/update_agent.py index d03994e9c..cd15313a2 100644 --- a/agents-api/agents_api/queries/agents/update_agent.py +++ b/agents-api/agents_api/queries/agents/update_agent.py @@ -7,17 +7,14 @@ from uuid import UUID from beartype import beartype -from fastapi import HTTPException from sqlglot import parse_one -from sqlglot.optimizer import optimize from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest from ...metrics.counters import increase_counter from ..utils import ( - partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) ModelT = TypeVar("ModelT", bound=Any) diff --git a/agents-api/agents_api/queries/developers/get_developer.py b/agents-api/agents_api/queries/developers/get_developer.py index 373a2fb36..28be9a4b1 100644 --- a/agents-api/agents_api/queries/developers/get_developer.py +++ b/agents-api/agents_api/queries/developers/get_developer.py @@ -12,8 +12,8 @@ from ..utils import ( partialclass, pg_query, - rewrap_exceptions, wrap_in_class, + rewrap_exceptions, ) # TODO: Add verify_developer diff --git a/agents-api/agents_api/queries/entries/create_entries.py b/agents-api/agents_api/queries/entries/create_entries.py index 24c0be26e..a54104274 100644 --- a/agents-api/agents_api/queries/entries/create_entries.py +++ b/agents-api/agents_api/queries/entries/create_entries.py @@ -14,14 +14,10 @@ # Query for checking if the session exists session_exists_query = """ -SELECT CASE - WHEN EXISTS ( - SELECT 1 FROM sessions - WHERE session_id = $1 AND developer_id = $2 - ) - THEN TRUE - ELSE (SELECT NULL::boolean WHERE FALSE) -- This raises a NO_DATA_FOUND error -END; +SELECT EXISTS ( + SELECT 1 FROM sessions + WHERE session_id = $1 AND developer_id = $2 +) AS exists; """ # Define the raw SQL query for creating entries @@ -71,6 +67,10 @@ status_code=400, detail=str(exc), ), + asyncpg.NoDataFoundError: lambda exc: HTTPException( + status_code=404, + detail="Session not found", + ), } ) @wrap_in_class( @@ -166,7 +166,7 @@ async def add_entry_relations( item.get("is_leaf", False), # $5 ] ) - + return [ ( session_exists_query, diff --git a/agents-api/agents_api/queries/entries/list_entries.py b/agents-api/agents_api/queries/entries/list_entries.py index 0aeb92a25..3f4a0699e 100644 --- a/agents-api/agents_api/queries/entries/list_entries.py +++ b/agents-api/agents_api/queries/entries/list_entries.py @@ -62,6 +62,10 @@ status_code=400, detail=str(exc), ), + asyncpg.NoDataFoundError: lambda exc: HTTPException( + status_code=404, + detail="Session not found", + ), } ) @wrap_in_class(Entry) @@ -78,7 +82,7 @@ async def list_entries( sort_by: Literal["created_at", "timestamp"] = "timestamp", direction: Literal["asc", "desc"] = "asc", exclude_relations: list[str] = [], -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: if limit < 1 or limit > 1000: raise HTTPException(status_code=400, detail="Limit must be between 1 and 1000") if offset < 0: @@ -98,14 +102,14 @@ async def list_entries( developer_id, # $5 exclude_relations, # $6 ] - return [ ( session_exists_query, [session_id, developer_id], + "fetchrow", ), ( query, - entry_params, + entry_params ), ] diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index 3074f087b..baa3f09d1 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -45,11 +45,7 @@ participant_type, participant_id ) -SELECT - $1 as developer_id, - $2 as session_id, - unnest($3::participant_type[]) as participant_type, - unnest($4::uuid[]) as participant_id; +VALUES ($1, $2, $3, $4); """).sql(pretty=True) @@ -67,7 +63,7 @@ ), } ) -@wrap_in_class(Session, one=True, transform=lambda d: {**d, "id": d["session_id"]}) +@wrap_in_class(Session, transform=lambda d: {**d, "id": d["session_id"]}) @increase_counter("create_session") @pg_query @beartype @@ -76,7 +72,7 @@ async def create_session( developer_id: UUID, session_id: UUID, data: CreateSessionRequest, -) -> list[tuple[str, list]]: +) -> list[tuple[str, list] | tuple[str, list, str]]: """ Constructs SQL queries to create a new session and its participant lookups. @@ -86,7 +82,7 @@ async def create_session( data (CreateSessionRequest): Session creation data Returns: - list[tuple[str, list]]: SQL queries and their parameters + list[tuple[str, list] | tuple[str, list, str]]: SQL queries and their parameters """ # Handle participants users = data.users or ([data.user] if data.user else []) @@ -122,15 +118,15 @@ async def create_session( data.recall_options or {}, # $10 ] - # Prepare lookup parameters - lookup_params = [ - developer_id, # $1 - session_id, # $2 - participant_types, # $3 - participant_ids, # $4 - ] + # Prepare lookup parameters as a list of parameter lists + lookup_params = [] + for ptype, pid in zip(participant_types, participant_ids): + lookup_params.append([developer_id, session_id, ptype, pid]) + print("*" * 100) + print(lookup_params) + print("*" * 100) return [ (session_query, session_params), - (lookup_query, lookup_params), + (lookup_query, lookup_params, "fetchmany"), ] diff --git a/agents-api/agents_api/queries/utils.py b/agents-api/agents_api/queries/utils.py index ba9bade9e..194cba7bc 100644 --- a/agents-api/agents_api/queries/utils.py +++ b/agents-api/agents_api/queries/utils.py @@ -69,7 +69,7 @@ class AsyncPGFetchArgs(TypedDict): type SQLQuery = str -type FetchMethod = Literal["fetch", "fetchmany"] +type FetchMethod = Literal["fetch", "fetchmany", "fetchrow"] type PGQueryArgs = tuple[SQLQuery, list[Any]] | tuple[SQLQuery, list[Any], FetchMethod] type PreparedPGQueryArgs = tuple[FetchMethod, AsyncPGFetchArgs] type BatchedPreparedPGQueryArgs = list[PreparedPGQueryArgs] @@ -102,6 +102,13 @@ def prepare_pg_query_args( ), ) ) + case (query, variables, "fetchrow"): + batch.append( + ( + "fetchrow", + AsyncPGFetchArgs(query=query, args=variables, timeout=query_timeout), + ) + ) case _: raise ValueError("Invalid query arguments") @@ -161,6 +168,14 @@ async def wrapper( query, *args, timeout=timeout ) + print("%" * 100) + print(results) + print(*args) + print("%" * 100) + + if method_name == "fetchrow" and (len(results) == 0 or results.get("bool") is None): + raise asyncpg.NoDataFoundError + end = timeit and time.perf_counter() timeit and print( diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 25892d959..9153785a4 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -1,5 +1,6 @@ import random import string +import time from uuid import UUID from fastapi.testclient import TestClient @@ -7,6 +8,8 @@ from ward import fixture from agents_api.autogen.openapi_model import ( + CreateAgentRequest, + CreateSessionRequest, CreateUserRequest, ) from agents_api.clients.pg import create_db_pool @@ -24,8 +27,8 @@ # from agents_api.queries.execution.create_temporal_lookup import create_temporal_lookup # from agents_api.queries.files.create_file import create_file # from agents_api.queries.files.delete_file import delete_file -# from agents_api.queries.session.create_session import create_session -# from agents_api.queries.session.delete_session import delete_session +from agents_api.queries.sessions.create_session import create_session + # from agents_api.queries.task.create_task import create_task # from agents_api.queries.task.delete_task import delete_task # from agents_api.queries.tools.create_tools import create_tools @@ -150,22 +153,27 @@ async def test_new_developer(dsn=pg_dsn, email=random_email): return developer -# @fixture(scope="global") -# async def test_session( -# dsn=pg_dsn, -# developer_id=test_developer_id, -# test_user=test_user, -# test_agent=test_agent, -# ): -# async with get_pg_client(dsn=dsn) as client: -# session = await create_session( -# developer_id=developer_id, -# data=CreateSessionRequest( -# agent=test_agent.id, user=test_user.id, metadata={"test": "test"} -# ), -# client=client, -# ) -# yield session +@fixture(scope="global") +async def test_session( + dsn=pg_dsn, + developer_id=test_developer_id, + test_user=test_user, + test_agent=test_agent, +): + pool = await create_db_pool(dsn=dsn) + + session = await create_session( + developer_id=developer_id, + data=CreateSessionRequest( + agent=test_agent.id, + user=test_user.id, + metadata={"test": "test"}, + system_template="test system template", + ), + connection_pool=pool, + ) + + return session # @fixture(scope="global") diff --git a/agents-api/tests/test_agent_queries.py b/agents-api/tests/test_agent_queries.py index 56a07ed03..b6cb7aedc 100644 --- a/agents-api/tests/test_agent_queries.py +++ b/agents-api/tests/test_agent_queries.py @@ -1,7 +1,5 @@ # Tests for agent queries -from uuid import UUID -import asyncpg from uuid_extensions import uuid7 from ward import raises, test diff --git a/agents-api/tests/test_entry_queries.py b/agents-api/tests/test_entry_queries.py index 87d9cdb4f..da53ce06d 100644 --- a/agents-api/tests/test_entry_queries.py +++ b/agents-api/tests/test_entry_queries.py @@ -3,7 +3,7 @@ It verifies the functionality of adding, retrieving, and processing entries as defined in the schema. """ -from uuid import uuid4 +from uuid_extensions import uuid7 from fastapi import HTTPException from ward import raises, test @@ -11,7 +11,7 @@ from agents_api.autogen.openapi_model import CreateEntryRequest from agents_api.clients.pg import create_db_pool from agents_api.queries.entries import create_entries, list_entries -from tests.fixtures import pg_dsn, test_developer # , test_session +from tests.fixtures import pg_dsn, test_developer, test_session # , test_session MODEL = "gpt-4o-mini" @@ -31,11 +31,10 @@ async def _(dsn=pg_dsn, developer=test_developer): with raises(HTTPException) as exc_info: await create_entries( developer_id=developer.id, - session_id=uuid4(), + session_id=uuid7(), data=[test_entry], connection_pool=pool, ) - assert exc_info.raised.status_code == 404 @@ -48,10 +47,9 @@ async def _(dsn=pg_dsn, developer=test_developer): with raises(HTTPException) as exc_info: await list_entries( developer_id=developer.id, - session_id=uuid4(), + session_id=uuid7(), connection_pool=pool, ) - assert exc_info.raised.status_code == 404 diff --git a/agents-api/tests/test_messages_truncation.py b/agents-api/tests/test_messages_truncation.py index 39cc02c2c..bb1eaee30 100644 --- a/agents-api/tests/test_messages_truncation.py +++ b/agents-api/tests/test_messages_truncation.py @@ -1,4 +1,4 @@ -# from uuid import uuid4 + # from uuid_extensions import uuid7 # from ward import raises, test diff --git a/agents-api/tests/test_session_queries.py b/agents-api/tests/test_session_queries.py index 4fdc7e6e4..b85268434 100644 --- a/agents-api/tests/test_session_queries.py +++ b/agents-api/tests/test_session_queries.py @@ -8,79 +8,116 @@ from agents_api.autogen.openapi_model import ( Session, + CreateSessionRequest, + CreateOrUpdateSessionRequest, + UpdateSessionRequest, + PatchSessionRequest, + ResourceUpdatedResponse, + ResourceDeletedResponse, ) from agents_api.clients.pg import create_db_pool from agents_api.queries.sessions import ( count_sessions, get_session, list_sessions, + create_session, + create_or_update_session, + update_session, + patch_session, + delete_session, ) from tests.fixtures import ( pg_dsn, test_developer_id, -) # , test_session, test_agent, test_user - -# @test("query: create session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): -# """Test that a session can be successfully created.""" - -# pool = await create_db_pool(dsn=dsn) -# await create_session( -# developer_id=developer_id, -# session_id=uuid7(), -# data=CreateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session", -# ), -# connection_pool=pool, -# ) - - -# @test("query: create or update session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user): -# """Test that a session can be successfully created or updated.""" - -# pool = await create_db_pool(dsn=dsn) -# await create_or_update_session( -# developer_id=developer_id, -# session_id=uuid7(), -# data=CreateOrUpdateSessionRequest( -# users=[user.id], -# agents=[agent.id], -# situation="test session", -# ), -# connection_pool=pool, -# ) - - -# @test("query: update session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): -# """Test that an existing session's information can be successfully updated.""" - -# pool = await create_db_pool(dsn=dsn) -# update_result = await update_session( -# session_id=session.id, -# developer_id=developer_id, -# data=UpdateSessionRequest( -# agents=[agent.id], -# situation="updated session", -# ), -# connection_pool=pool, -# ) - -# assert update_result is not None -# assert isinstance(update_result, ResourceUpdatedResponse) -# assert update_result.updated_at > session.created_at - - -@test("query: get session not exists sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that retrieving a non-existent session returns an empty result.""" + test_developer, + test_user, + test_agent, + test_session, +) + +@test("query: create session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created.""" + + pool = await create_db_pool(dsn=dsn) session_id = uuid7() + data = CreateSessionRequest( + users=[user.id], + agents=[agent.id], + situation="test session", + system_template="test system template", + ) + result = await create_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session), f"Result is not a Session, {result}" + assert result.id == session_id + assert result.developer_id == developer_id + assert result.situation == "test session" + assert set(result.users) == {user.id} + assert set(result.agents) == {agent.id} + + +@test("query: create or update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, agent=test_agent, user=test_user +): + """Test that a session can be successfully created or updated.""" + pool = await create_db_pool(dsn=dsn) + session_id = uuid7() + data = CreateOrUpdateSessionRequest( + users=[user.id], + agents=[agent.id], + situation="test session", + ) + result = await create_or_update_session( + developer_id=developer_id, + session_id=session_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session) + assert result.id == session_id + assert result.developer_id == developer_id + assert result.situation == "test session" + assert set(result.users) == {user.id} + assert set(result.agents) == {agent.id} + + +@test("query: get session exists") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test retrieving an existing session.""" + pool = await create_db_pool(dsn=dsn) + result = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, Session) + assert result.id == session.id + assert result.developer_id == developer_id + + +@test("query: get session does not exist") +async def _(dsn=pg_dsn, developer_id=test_developer_id): + """Test retrieving a non-existent session.""" + + session_id = uuid7() + pool = await create_db_pool(dsn=dsn) with raises(Exception): await get_session( session_id=session_id, @@ -89,90 +126,136 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id): ) -# @test("query: get session exists sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test that retrieving an existing session returns the correct session information.""" +@test("query: list sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with default pagination.""" -# pool = await create_db_pool(dsn=dsn) -# result = await get_session( -# session_id=session.id, -# developer_id=developer_id, -# connection_pool=pool, -# ) + pool = await create_db_pool(dsn=dsn) + result, _ = await list_sessions( + developer_id=developer_id, + limit=10, + offset=0, + connection_pool=pool, + ) -# assert result is not None -# assert isinstance(result, Session) + assert isinstance(result, list) + assert len(result) >= 1 + assert any(s.id == session.id for s in result) -@test("query: list sessions when none exist sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that listing sessions returns a collection of session information.""" +@test("query: list sessions with filters") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test listing sessions with specific filters.""" pool = await create_db_pool(dsn=dsn) - result = await list_sessions( + result, _ = await list_sessions( developer_id=developer_id, + limit=10, + offset=0, + filters={"situation": "test session"}, connection_pool=pool, ) assert isinstance(result, list) assert len(result) >= 1 - assert all(isinstance(session, Session) for session in result) - - -# @test("query: patch session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent): -# """Test that a session can be successfully patched.""" - -# pool = await create_db_pool(dsn=dsn) -# patch_result = await patch_session( -# developer_id=developer_id, -# session_id=session.id, -# data=PatchSessionRequest( -# agents=[agent.id], -# situation="patched session", -# metadata={"test": "metadata"}, -# ), -# connection_pool=pool, -# ) - -# assert patch_result is not None -# assert isinstance(patch_result, ResourceUpdatedResponse) -# assert patch_result.updated_at > session.created_at - - -# @test("query: delete session sql") -# async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): -# """Test that a session can be successfully deleted.""" - -# pool = await create_db_pool(dsn=dsn) -# delete_result = await delete_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) - -# assert delete_result is not None -# assert isinstance(delete_result, ResourceDeletedResponse) - -# # Verify the session no longer exists -# with raises(Exception): -# await get_session( -# developer_id=developer_id, -# session_id=session.id, -# connection_pool=pool, -# ) - - -@test("query: count sessions sql") -async def _(dsn=pg_dsn, developer_id=test_developer_id): - """Test that sessions can be counted.""" + assert all(s.situation == "test session" for s in result) + + +@test("query: count sessions") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test counting the number of sessions for a developer.""" pool = await create_db_pool(dsn=dsn) - result = await count_sessions( + count = await count_sessions( developer_id=developer_id, connection_pool=pool, ) - assert isinstance(result, dict) - assert "count" in result - assert isinstance(result["count"], int) + assert isinstance(count, int) + assert count >= 1 + + +@test("query: update session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that an existing session's information can be successfully updated.""" + + pool = await create_db_pool(dsn=dsn) + data = UpdateSessionRequest( + agents=[agent.id], + situation="updated session", + ) + result = await update_session( + session_id=session.id, + developer_id=developer_id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + updated_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert updated_session.situation == "updated session" + assert set(updated_session.agents) == {agent.id} + + +@test("query: patch session sql") +async def _( + dsn=pg_dsn, developer_id=test_developer_id, session=test_session, agent=test_agent +): + """Test that a session can be successfully patched.""" + + pool = await create_db_pool(dsn=dsn) + data = PatchSessionRequest( + agents=[agent.id], + situation="patched session", + metadata={"test": "metadata"}, + ) + result = await patch_session( + developer_id=developer_id, + session_id=session.id, + data=data, + connection_pool=pool, + ) + + assert result is not None + assert isinstance(result, ResourceUpdatedResponse) + assert result.updated_at > session.created_at + + patched_session = await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + assert patched_session.situation == "patched session" + assert set(patched_session.agents) == {agent.id} + assert patched_session.metadata == {"test": "metadata"} + + +@test("query: delete session sql") +async def _(dsn=pg_dsn, developer_id=test_developer_id, session=test_session): + """Test that a session can be successfully deleted.""" + + pool = await create_db_pool(dsn=dsn) + delete_result = await delete_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) + + assert delete_result is not None + assert isinstance(delete_result, ResourceDeletedResponse) + + with raises(Exception): + await get_session( + developer_id=developer_id, + session_id=session.id, + connection_pool=pool, + ) diff --git a/integrations-service/integrations/autogen/Sessions.py b/integrations-service/integrations/autogen/Sessions.py index 460fd25ce..e2a9ce164 100644 --- a/integrations-service/integrations/autogen/Sessions.py +++ b/integrations-service/integrations/autogen/Sessions.py @@ -31,6 +31,10 @@ class CreateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -51,6 +55,10 @@ class CreateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -67,6 +75,10 @@ class PatchSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -87,6 +99,10 @@ class PatchSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptionsUpdate | None = None metadata: dict[str, Any] | None = None @@ -121,6 +137,10 @@ class Session(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ summary: Annotated[str | None, Field(json_schema_extra={"readOnly": True})] = None """ Summary (null at the beginning) - generated automatically after every interaction @@ -145,6 +165,10 @@ class Session(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None id: Annotated[UUID, Field(json_schema_extra={"readOnly": True})] metadata: dict[str, Any] | None = None @@ -197,6 +221,10 @@ class UpdateSessionRequest(BaseModel): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -217,6 +245,10 @@ class UpdateSessionRequest(BaseModel): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None @@ -240,6 +272,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): """ A specific situation that sets the background for this session """ + system_template: str | None = None + """ + System prompt for this session + """ render_templates: StrictBool = True """ Render system and assistant message content as jinja templates @@ -260,6 +296,10 @@ class CreateOrUpdateSessionRequest(CreateSessionRequest): If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. """ + forward_tool_calls: StrictBool = False + """ + Whether to forward tool calls to the model + """ recall_options: RecallOptions | None = None metadata: dict[str, Any] | None = None diff --git a/typespec/sessions/models.tsp b/typespec/sessions/models.tsp index f15453a5f..720625f3b 100644 --- a/typespec/sessions/models.tsp +++ b/typespec/sessions/models.tsp @@ -63,6 +63,9 @@ model Session { /** A specific situation that sets the background for this session */ situation: string = defaultSessionSystemMessage; + /** System prompt for this session */ + system_template: string | null = null; + /** Summary (null at the beginning) - generated automatically after every interaction */ @visibility("read") summary: string | null = null; @@ -83,6 +86,9 @@ model Session { * If a tool call is not made, the model's output will be returned as is. */ auto_run_tools: boolean = false; + /** Whether to forward tool calls to the model */ + forward_tool_calls: boolean = false; + recall_options?: RecallOptions | null = null; ...HasId; diff --git a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml index 9298ab458..d4835a695 100644 --- a/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml +++ b/typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml @@ -3761,10 +3761,12 @@ components: required: - id - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: id: $ref: '#/components/schemas/Common.uuid' @@ -3840,6 +3842,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3865,6 +3872,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -3880,10 +3891,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: user: allOf: @@ -3957,6 +3970,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -3982,6 +4000,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4096,6 +4118,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4121,6 +4148,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4189,11 +4220,13 @@ components: type: object required: - situation + - system_template - summary - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls - id - created_at - updated_at @@ -4254,6 +4287,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null summary: type: string nullable: true @@ -4285,6 +4323,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: @@ -4360,10 +4402,12 @@ components: type: object required: - situation + - system_template - render_templates - token_budget - context_overflow - auto_run_tools + - forward_tool_calls properties: situation: type: string @@ -4421,6 +4465,11 @@ components: {{"---"}} {%- endfor -%} {%- endif -%} + system_template: + type: string + nullable: true + description: System prompt for this session + default: null render_templates: type: boolean description: Render system and assistant message content as jinja templates @@ -4446,6 +4495,10 @@ components: If a tool call is made, the tool's output will be sent back to the model as the model's input. If a tool call is not made, the model's output will be returned as is. default: false + forward_tool_calls: + type: boolean + description: Whether to forward tool calls to the model + default: false recall_options: type: object allOf: