From 148b3b72e358a9c2b75714359ce879052b82ab71 Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 27 Dec 2024 09:56:23 +0300 Subject: [PATCH] fix: Apply various fixes to chat routes --- .../agents_api/queries/chat/gather_messages.py | 6 ++++++ .../agents_api/queries/chat/prepare_chat_context.py | 13 ++++++++++--- .../agents_api/queries/sessions/create_session.py | 2 +- agents-api/agents_api/routers/sessions/chat.py | 1 - agents-api/tests/test_chat_routes.py | 13 ++++++++----- memory-store/migrations/000009_sessions.up.sql | 2 +- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/agents-api/agents_api/queries/chat/gather_messages.py b/agents-api/agents_api/queries/chat/gather_messages.py index fb3205acf..dd3c08439 100644 --- a/agents-api/agents_api/queries/chat/gather_messages.py +++ b/agents-api/agents_api/queries/chat/gather_messages.py @@ -35,6 +35,7 @@ async def gather_messages( session_id: UUID, chat_context: ChatContext, chat_input: ChatInput, + connection_pool=None, ) -> tuple[list[dict], list[DocReference]]: new_raw_messages = [msg.model_dump(mode="json") for msg in chat_input.messages] recall = chat_input.recall @@ -46,6 +47,7 @@ async def gather_messages( developer_id=developer.id, session_id=session_id, allowed_sources=["api_request", "api_response", "tool_response", "summarizer"], + connection_pool=connection_pool, ) # Keep leaf nodes only @@ -72,6 +74,7 @@ async def gather_messages( session: Session = await get_session( developer_id=developer.id, session_id=session_id, + connection_pool=connection_pool, ) recall_options = session.recall_options @@ -121,6 +124,7 @@ async def gather_messages( developer_id=developer.id, owners=owners, query_embedding=query_embedding, + connection_pool=connection_pool, ) case "hybrid": doc_references: list[DocReference] = await search_docs_hybrid( @@ -128,12 +132,14 @@ async def gather_messages( owners=owners, query=query_text, query_embedding=query_embedding, + connection_pool=connection_pool, ) case "text": doc_references: list[DocReference] = await search_docs_by_text( developer_id=developer.id, owners=owners, query=query_text, + connection_pool=connection_pool, ) return past_messages, doc_references diff --git a/agents-api/agents_api/queries/chat/prepare_chat_context.py b/agents-api/agents_api/queries/chat/prepare_chat_context.py index e56e66abe..ccd4052fa 100644 --- a/agents-api/agents_api/queries/chat/prepare_chat_context.py +++ b/agents-api/agents_api/queries/chat/prepare_chat_context.py @@ -8,6 +8,7 @@ pg_query, wrap_in_class, ) +from ...common.utils.datetime import utcnow ModelT = TypeVar("ModelT", bound=Any) T = TypeVar("T") @@ -110,18 +111,24 @@ def _transform(d): d["users"] = d.get("users") or [] d["agents"] = d.get("agents") or [] - for tool in d.get("toolsets") or []: + for tool in d.get("toolsets", []) or []: + if not tool: + continue + agent_id = tool["agent_id"] if agent_id in toolsets: toolsets[agent_id].append(tool) else: toolsets[agent_id] = [tool] + + d["session"]["updated_at"] = utcnow() + d["users"] = d.get("users", []) or [] transformed_data = { **d, "session": make_session( - agents=[a["id"] for a in d.get("agents") or []], - users=[u["id"] for u in d.get("users") or []], + agents=[a["id"] for a in d.get("agents", []) or []], + users=[u["id"] for u in d.get("users", []) or []], **d["session"], ), "toolsets": [ diff --git a/agents-api/agents_api/queries/sessions/create_session.py b/agents-api/agents_api/queries/sessions/create_session.py index edfe9e1bb..b7196459a 100644 --- a/agents-api/agents_api/queries/sessions/create_session.py +++ b/agents-api/agents_api/queries/sessions/create_session.py @@ -138,7 +138,7 @@ async def create_session( data.token_budget, # $7 data.context_overflow, # $8 data.forward_tool_calls, # $9 - data.recall_options or {}, # $10 + data.recall_options.model_dump() if data.recall_options else {}, # $10 ] # Prepare lookup parameters as a list of parameter lists diff --git a/agents-api/agents_api/routers/sessions/chat.py b/agents-api/agents_api/routers/sessions/chat.py index 2fc5a859e..b5ded8522 100644 --- a/agents-api/agents_api/routers/sessions/chat.py +++ b/agents-api/agents_api/routers/sessions/chat.py @@ -219,7 +219,6 @@ async def chat( developer_id=developer.id, session_id=session_id, data=new_entries, - mark_session_as_updated=True, ) # Adaptive context handling diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index 5ba06eb80..d03e2e30a 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -46,7 +46,7 @@ async def _( (embed, _) = mocks pool = await create_db_pool(dsn=dsn) - chat_context = prepare_chat_context( + chat_context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, connection_pool=pool, @@ -61,6 +61,7 @@ async def _( session_id=session_id, chat_context=chat_context, chat_input=ChatInput(messages=messages, recall=False), + connection_pool=pool, ) assert isinstance(past_messages, list) @@ -84,7 +85,7 @@ async def _( mocks=patch_embed_acompletion, ): pool = await create_db_pool(dsn=dsn) - session = create_session( + session = await create_session( developer_id=developer_id, data=CreateSessionRequest( agent=agent.id, @@ -100,7 +101,7 @@ async def _( (embed, _) = mocks - chat_context = prepare_chat_context( + chat_context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, connection_pool=pool, @@ -115,6 +116,7 @@ async def _( session_id=session_id, chat_context=chat_context, chat_input=ChatInput(messages=messages, recall=True), + connection_pool=pool, ) assert isinstance(past_messages, list) @@ -133,7 +135,7 @@ async def _( dsn=pg_dsn, ): pool = await create_db_pool(dsn=dsn) - session = create_session( + session = await create_session( developer_id=developer_id, data=CreateSessionRequest( agent=agent.id, @@ -172,11 +174,12 @@ async def _( user=test_user, ): pool = await create_db_pool(dsn=dsn) - context = prepare_chat_context( + context = await prepare_chat_context( developer_id=developer_id, session_id=session.id, connection_pool=pool, ) + print("-->", type(context), context) assert isinstance(context, ChatContext) assert len(context.toolsets) > 0 diff --git a/memory-store/migrations/000009_sessions.up.sql b/memory-store/migrations/000009_sessions.up.sql index b5554b26f..5c7a8717b 100644 --- a/memory-store/migrations/000009_sessions.up.sql +++ b/memory-store/migrations/000009_sessions.up.sql @@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS sessions ( developer_id UUID NOT NULL, session_id UUID NOT NULL, situation TEXT, - system_template TEXT NOT NULL, + system_template TEXT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, metadata JSONB NOT NULL DEFAULT '{}'::JSONB,