Skip to content

Commit

Permalink
fix: Apply various fixes to chat routes
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed Dec 27, 2024
1 parent e6abab5 commit 148b3b7
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 11 deletions.
6 changes: 6 additions & 0 deletions agents-api/agents_api/queries/chat/gather_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ async def gather_messages(
session_id: UUID,
chat_context: ChatContext,
chat_input: ChatInput,
connection_pool=None,
) -> tuple[list[dict], list[DocReference]]:
new_raw_messages = [msg.model_dump(mode="json") for msg in chat_input.messages]
recall = chat_input.recall
Expand All @@ -46,6 +47,7 @@ async def gather_messages(
developer_id=developer.id,
session_id=session_id,
allowed_sources=["api_request", "api_response", "tool_response", "summarizer"],
connection_pool=connection_pool,
)

# Keep leaf nodes only
Expand All @@ -72,6 +74,7 @@ async def gather_messages(
session: Session = await get_session(
developer_id=developer.id,
session_id=session_id,
connection_pool=connection_pool,
)
recall_options = session.recall_options

Expand Down Expand Up @@ -121,19 +124,22 @@ async def gather_messages(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
connection_pool=connection_pool,
)
case "hybrid":
doc_references: list[DocReference] = await search_docs_hybrid(
developer_id=developer.id,
owners=owners,
query=query_text,
query_embedding=query_embedding,
connection_pool=connection_pool,
)
case "text":
doc_references: list[DocReference] = await search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
connection_pool=connection_pool,
)

return past_messages, doc_references
13 changes: 10 additions & 3 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
pg_query,
wrap_in_class,
)
from ...common.utils.datetime import utcnow

ModelT = TypeVar("ModelT", bound=Any)
T = TypeVar("T")
Expand Down Expand Up @@ -110,18 +111,24 @@ def _transform(d):
d["users"] = d.get("users") or []
d["agents"] = d.get("agents") or []

for tool in d.get("toolsets") or []:
for tool in d.get("toolsets", []) or []:
if not tool:
continue

agent_id = tool["agent_id"]
if agent_id in toolsets:
toolsets[agent_id].append(tool)
else:
toolsets[agent_id] = [tool]

d["session"]["updated_at"] = utcnow()
d["users"] = d.get("users", []) or []

transformed_data = {
**d,
"session": make_session(
agents=[a["id"] for a in d.get("agents") or []],
users=[u["id"] for u in d.get("users") or []],
agents=[a["id"] for a in d.get("agents", []) or []],
users=[u["id"] for u in d.get("users", []) or []],
**d["session"],
),
"toolsets": [
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/sessions/create_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ async def create_session(
data.token_budget, # $7
data.context_overflow, # $8
data.forward_tool_calls, # $9
data.recall_options or {}, # $10
data.recall_options.model_dump() if data.recall_options else {}, # $10
]

# Prepare lookup parameters as a list of parameter lists
Expand Down
1 change: 0 additions & 1 deletion agents-api/agents_api/routers/sessions/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ async def chat(
developer_id=developer.id,
session_id=session_id,
data=new_entries,
mark_session_as_updated=True,
)

# Adaptive context handling
Expand Down
13 changes: 8 additions & 5 deletions agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def _(
(embed, _) = mocks

pool = await create_db_pool(dsn=dsn)
chat_context = prepare_chat_context(
chat_context = await prepare_chat_context(
developer_id=developer_id,
session_id=session.id,
connection_pool=pool,
Expand All @@ -61,6 +61,7 @@ async def _(
session_id=session_id,
chat_context=chat_context,
chat_input=ChatInput(messages=messages, recall=False),
connection_pool=pool,
)

assert isinstance(past_messages, list)
Expand All @@ -84,7 +85,7 @@ async def _(
mocks=patch_embed_acompletion,
):
pool = await create_db_pool(dsn=dsn)
session = create_session(
session = await create_session(
developer_id=developer_id,
data=CreateSessionRequest(
agent=agent.id,
Expand All @@ -100,7 +101,7 @@ async def _(

(embed, _) = mocks

chat_context = prepare_chat_context(
chat_context = await prepare_chat_context(
developer_id=developer_id,
session_id=session.id,
connection_pool=pool,
Expand All @@ -115,6 +116,7 @@ async def _(
session_id=session_id,
chat_context=chat_context,
chat_input=ChatInput(messages=messages, recall=True),
connection_pool=pool,
)

assert isinstance(past_messages, list)
Expand All @@ -133,7 +135,7 @@ async def _(
dsn=pg_dsn,
):
pool = await create_db_pool(dsn=dsn)
session = create_session(
session = await create_session(
developer_id=developer_id,
data=CreateSessionRequest(
agent=agent.id,
Expand Down Expand Up @@ -172,11 +174,12 @@ async def _(
user=test_user,
):
pool = await create_db_pool(dsn=dsn)
context = prepare_chat_context(
context = await prepare_chat_context(
developer_id=developer_id,
session_id=session.id,
connection_pool=pool,
)

print("-->", type(context), context)
assert isinstance(context, ChatContext)
assert len(context.toolsets) > 0
2 changes: 1 addition & 1 deletion memory-store/migrations/000009_sessions.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ CREATE TABLE IF NOT EXISTS sessions (
developer_id UUID NOT NULL,
session_id UUID NOT NULL,
situation TEXT,
system_template TEXT NOT NULL,
system_template TEXT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
metadata JSONB NOT NULL DEFAULT '{}'::JSONB,
Expand Down

0 comments on commit 148b3b7

Please sign in to comment.