Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Added mmr search and get history system tool + configurable doc search params in chat.py #940

Merged
merged 11 commits into from
Dec 12, 2024
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from beartype import beartype
from box import Box, BoxList
from fastapi import HTTPException
from fastapi.background import BackgroundTasks
from temporalio import activity

Expand Down Expand Up @@ -109,7 +108,8 @@ async def execute_system(
)
await bg_runner()
return res


# Handle create operations
if system.operation == "create" and system.resource == "session":
developer_id = arguments.pop("developer_id")
session_id = arguments.pop("session_id", None)
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ def get_handler(system: SystemDef) -> Callable:
from ..models.session.get_session import get_session as get_session_query
from ..models.session.list_sessions import list_sessions as list_sessions_query
from ..models.session.update_session import update_session as update_session_query
from ..models.entry.get_history import get_history as get_history_query
from ..models.task.create_task import create_task as create_task_query
from ..models.task.delete_task import delete_task as delete_task_query
from ..models.task.get_task import get_task as get_task_query
Expand Down Expand Up @@ -376,6 +377,8 @@ def get_handler(system: SystemDef) -> Callable:
return delete_session_query
case ("session", None, "chat"):
return chat
case ("session", None, "history"):
return get_history_query

# TASKS
case ("task", None, "list"):
Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ class BaseDocSearchRequest(BaseModel):
populate_by_name=True,
)
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
Expand Down
35 changes: 33 additions & 2 deletions agents-api/agents_api/autogen/Sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,41 @@ class RecallOptions(BaseModel):
populate_by_name=True,
)
mode: Literal["hybrid", "vector", "text"] = "vector"
"""
The mode to use for the search.
"""
num_search_messages: int = 4
"""
The number of search messages to use for the search.
"""
max_query_length: int = 1000
hybrid_alpha: float = 0.7
confidence: float = 0.6
"""
The maximum query length to use for the search.
"""
alpha: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
"""
The weight to apply to BM25 vs Vector search results. 0 => pure BM25; 1 => pure vector;
"""
confidence: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
"""
The confidence cutoff level
"""
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
"""


class RecallOptionsUpdate(RecallOptions):
Expand Down
129 changes: 98 additions & 31 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from uuid import UUID

import numpy as np
from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import ChatInput, DocReference, History
from ...autogen.Sessions import RecallOptions
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...models.docs.mmr import maximal_marginal_relevance
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
Expand All @@ -23,6 +26,52 @@
T = TypeVar("T")


def get_search_fn_and_params(
recall_options: RecallOptions,
query_text: str | None,
query_embedding: list[float] | None,
) -> Tuple[
Any,
Optional[Dict[str, Union[float, int, str, Dict[str, float], List[float], None]]],
]:
search_fn, params = None, None

match recall_options.mode:
case "text":
search_fn = search_docs_by_text
params = dict(
query=query_text,
k=recall_options.limit,
metadata_filter=recall_options.metadata_filter,
)

case "vector":
search_fn = search_docs_by_embedding
params = dict(
query_embedding=query_embedding,
k=recall_options.limit * 3
if recall_options.mmr_strength > 0
Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
else recall_options.limit,
confidence=recall_options.confidence,
metadata_filter=recall_options.metadata_filter,
)

case "hybrid":
search_fn = search_docs_hybrid
params = dict(
query=query_text,
query_embedding=query_embedding,
k=recall_options.limit * 3
if recall_options.mmr_strength > 0
else recall_options.limit,
embed_search_options=dict(confidence=recall_options.confidence),
alpha=recall_options.alpha,
metadata_filter=recall_options.metadata_filter,
)

Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
return search_fn, params


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
Expand Down Expand Up @@ -98,44 +147,62 @@ async def gather_messages(
]
).strip()

[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)
# Set the query text and embedding
query_text, query_embedding = None, None

# Embed the query
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)

# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[
: recall_options.max_query_length
]
if recall_options.mode == "text" or recall_options.mode == "hybrid":
query_text = search_messages[-1]["content"].strip()[
: recall_options.max_query_length
]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Get the search function and parameters
search_fn, params = get_search_fn_and_params(
recall_options=recall_options,
query_text=query_text,
query_embedding=query_embedding,
)

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references: list[DocReference] = search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
)
case "hybrid":
doc_references: list[DocReference] = search_docs_hybrid(
developer_id=developer.id,
owners=owners,
query=query_text,
query_embedding=query_embedding,
)
case "text":
doc_references: list[DocReference] = search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
)
doc_references: list[DocReference] = search_fn(
developer_id=developer.id,
owners=owners,
**params,
)

# Apply MMR if enabled
if (
# MMR is enabled
recall_options.mmr_strength > 0
# The number of doc references is greater than the limit
and len(doc_references) > recall_options.limit
# MMR is not applied to text search
and recall_options.mode != "text"
):
# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
# Apply MMR
doc_references = [
doc for i, doc in enumerate(doc_references) if i in set(indices)
]

# Return the past messages and doc references
return past_messages, doc_references
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...env import (
debug,
temporal_activity_after_retry_timeout,
temporal_heartbeat_timeout,
temporal_schedule_to_close_timeout,
testing,
Expand Down
17 changes: 12 additions & 5 deletions agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,13 @@ async def _(
agent=agent.id,
situation="test session about",
recall_options={
"mode": "text",
"num_search_messages": 10,
"max_query_length": 1001,
"mode": "hybrid",
"num_search_messages": 6,
"max_query_length": 800,
"confidence": 0.6,
"alpha": 0.7,
"limit": 10,
"mmr_strength": 0.5,
},
),
client=client,
Expand Down Expand Up @@ -135,9 +139,12 @@ async def _(
agent=agent.id,
situation="test session about",
recall_options={
"mode": "vector",
"num_search_messages": 5,
"mode": "text",
"num_search_messages": 10,
"max_query_length": 1001,
"confidence": 0.6,
"limit": 5,
"mmr_strength": 0.5,
},
),
client=client,
Expand Down
1 change: 0 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from agents_api.models.task.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution
from tests.fixtures import (
async_cozo_client,
cozo_client,
cozo_clients_with_migrations,
test_agent,
Expand Down
6 changes: 6 additions & 0 deletions integrations-service/integrations/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ class BaseDocSearchRequest(BaseModel):
populate_by_name=True,
)
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
Expand Down
35 changes: 33 additions & 2 deletions integrations-service/integrations/autogen/Sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,41 @@ class RecallOptions(BaseModel):
populate_by_name=True,
)
mode: Literal["hybrid", "vector", "text"] = "vector"
"""
The mode to use for the search.
"""
num_search_messages: int = 4
"""
The number of search messages to use for the search.
"""
max_query_length: int = 1000
hybrid_alpha: float = 0.7
confidence: float = 0.6
"""
The maximum query length to use for the search.
"""
alpha: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
"""
The weight to apply to BM25 vs Vector search results. 0 => pure BM25; 1 => pure vector;
"""
confidence: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
"""
The confidence cutoff level
"""
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
"""


class RecallOptionsUpdate(RecallOptions):
Expand Down
5 changes: 4 additions & 1 deletion typespec/docs/models.tsp
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,16 @@ model EmbedQueryResponse {
}

model BaseDocSearchRequest {
/** The limit of documents to return */
@minValue(1)
@maxValue(50)
limit: uint16 = 10;

/** The language to be used for text-only search. Support for other languages coming soon. */
lang: "en-US" = "en-US";
metadata_filter: MetadataFilter = #{},

/** Metadata filter to apply to the search */
metadata_filter: MetadataFilter = #{};

/** MMR Strength (mmr_strength = 1 - mmr_lambda) */
@minValue(0)
Expand Down
Loading
Loading