From ea378374e0e8d432467fea225617c3e06497e0ad Mon Sep 17 00:00:00 2001 From: Dmitry Paramonov Date: Fri, 27 Dec 2024 16:25:57 +0300 Subject: [PATCH] fix: Apply small fixes to docs logic --- .../agents_api/routers/docs/search_docs.py | 16 ++++++++-------- agents-api/tests/fixtures.py | 4 +++- agents-api/tests/test_chat_routes.py | 1 - 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/agents-api/agents_api/routers/docs/search_docs.py b/agents-api/agents_api/routers/docs/search_docs.py index ead9e1edb..de385690f 100644 --- a/agents-api/agents_api/routers/docs/search_docs.py +++ b/agents-api/agents_api/routers/docs/search_docs.py @@ -31,7 +31,7 @@ async def get_search_fn_and_params( case TextOnlyDocSearchRequest( text=query, limit=k, metadata_filter=metadata_filter ): - search_fn = await search_docs_by_text + search_fn = search_docs_by_text params = dict( query=query, k=k, @@ -44,7 +44,7 @@ async def get_search_fn_and_params( confidence=confidence, metadata_filter=metadata_filter, ): - search_fn = await search_docs_by_embedding + search_fn = search_docs_by_embedding params = dict( query_embedding=query_embedding, k=k * 3 if search_params.mmr_strength > 0 else k, @@ -60,12 +60,12 @@ async def get_search_fn_and_params( alpha=alpha, metadata_filter=metadata_filter, ): - search_fn = await search_docs_hybrid + search_fn = search_docs_hybrid params = dict( - query=query, - query_embedding=query_embedding, + text_query=query, + embedding=query_embedding, k=k * 3 if search_params.mmr_strength > 0 else k, - embed_search_options=dict(confidence=confidence), + confidence=confidence, alpha=alpha, metadata_filter=metadata_filter, ) @@ -97,7 +97,7 @@ async def search_user_docs( search_fn, params = await get_search_fn_and_params(search_params) start = time.time() - docs: list[DocReference] = search_fn( + docs: list[DocReference] = await search_fn( developer_id=x_developer_id, owners=[("user", user_id)], **params, @@ -148,7 +148,7 @@ async def search_agent_docs( search_fn, params = await get_search_fn_and_params(search_params) start = time.time() - docs: list[DocReference] = search_fn( + docs: list[DocReference] = await search_fn( developer_id=x_developer_id, owners=[("agent", agent_id)], **params, diff --git a/agents-api/tests/fixtures.py b/agents-api/tests/fixtures.py index 14daea854..9a5bbb058 100644 --- a/agents-api/tests/fixtures.py +++ b/agents-api/tests/fixtures.py @@ -27,6 +27,7 @@ from agents_api.queries.developers.create_developer import create_developer from agents_api.queries.developers.get_developer import get_developer from agents_api.queries.docs.create_doc import create_doc +from agents_api.queries.docs.get_doc import get_doc from agents_api.queries.executions.create_execution import create_execution from agents_api.queries.executions.create_execution_transition import ( create_execution_transition, @@ -135,7 +136,7 @@ async def test_file(dsn=pg_dsn, developer=test_developer, user=test_user): @fixture(scope="test") async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): pool = await create_db_pool(dsn=dsn) - doc = await create_doc( + resp = await create_doc( developer_id=developer.id, data=CreateDocRequest( title="Hello", @@ -147,6 +148,7 @@ async def test_doc(dsn=pg_dsn, developer=test_developer, agent=test_agent): owner_id=agent.id, connection_pool=pool, ) + doc = await get_doc(developer_id=developer.id, doc_id=resp.id, connection_pool=pool) return doc diff --git a/agents-api/tests/test_chat_routes.py b/agents-api/tests/test_chat_routes.py index d03e2e30a..d91696c15 100644 --- a/agents-api/tests/test_chat_routes.py +++ b/agents-api/tests/test_chat_routes.py @@ -180,6 +180,5 @@ async def _( connection_pool=pool, ) - print("-->", type(context), context) assert isinstance(context, ChatContext) assert len(context.toolsets) > 0