Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(agents-api): Added mmr search and get history system tool + configurable doc search params in chat.py #940

Merged
merged 11 commits into from
Dec 12, 2024
Merged
81 changes: 81 additions & 0 deletions .github/workflows/generate-changelog.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
name: Julep-Changelog-Generation
run-name: ${{ github.actor }} is generating changelog for the last two weeks using Julep

on:
workflow_dispatch:

jobs:
changelog_generation:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: dev

- name: Setup GitHub CLI
run: |
echo "${{ secrets.GITHUB_TOKEN }}" | gh auth login --with-token

- name: Collect merged PRs from the last two weeks
id: collect_prs
run: |
# Set date threshold for fetching PRs
if [[ "$OSTYPE" == "darwin"* ]]; then
date_threshold=$(date -v-14d +"%Y-%m-%d")
else
date_threshold=$(date -d '-14 days' +"%Y-%m-%d")
fi

echo "Fetching merged PRs since $date_threshold..."

# Find merged PRs from the last two weeks
merged_prs=$(gh pr list --state merged --json number,title,body,author --search "merged:>=$date_threshold" --jq 'map({number, title, body, author: .author.login})')

if [ -z "$merged_prs" ] || [ "$merged_prs" = "null" ]; then
echo "No merged PRs found in the last two weeks."
echo "pr_data=[]" >> $GITHUB_ENV
echo '{"pr_data": []}' > pr_data.json
exit 0
fi

echo "pr_data=$merged_prs" >> $GITHUB_ENV
echo "pr_data=$merged_prs" >> "$GITHUB_OUTPUT"
echo "{\"pr_data\": $merged_prs}" > pr_data.json

- name: Setup Python v3.10.12
uses: actions/setup-python@v5
with:
python-version: "3.10.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install PyYAML julep git+https://github.com/Jwink3101/parmapper

- name: Send PR data to Python script
if: steps.collect_prs.outputs.pr_data != '[]'
id: generate_changelog
run: |
if ! python scripts/generate_changelog.py; then
Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
echo "Error: Failed to generate changelog"
exit 1
fi
env:
JULEP_API_KEY: ${{ secrets.JULEP_API_KEY }}
TASK_UUID: ${{ secrets.TASK_UUID }}
AGENT_UUID: ${{ secrets.AGENT_UUID }}

- name: Create Pull Request
if: success() && steps.collect_prs.outputs.pr_data != '[]'
uses: peter-evans/create-pull-request@v7
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: "chore(changelog): update CHANGELOG.md"
title: "Update CHANGELOG.md"
body: "This PR updates the changelog with PRs from the last two weeks."
branch: "update-changelog"
delete-branch: true
add-paths: |
CHANGELOG.md
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from beartype import beartype
from box import Box, BoxList
from fastapi import HTTPException
from fastapi.background import BackgroundTasks
from temporalio import activity

Expand Down Expand Up @@ -110,6 +109,7 @@ async def execute_system(
await bg_runner()
return res

# Handle create operations
if system.operation == "create" and system.resource == "session":
developer_id = arguments.pop("developer_id")
session_id = arguments.pop("session_id", None)
Expand Down
3 changes: 3 additions & 0 deletions agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def get_handler(system: SystemDef) -> Callable:
from ..models.agent.update_agent import update_agent as update_agent_query
from ..models.docs.delete_doc import delete_doc as delete_doc_query
from ..models.docs.list_docs import list_docs as list_docs_query
from ..models.entry.get_history import get_history as get_history_query
from ..models.session.create_session import create_session as create_session_query
from ..models.session.delete_session import delete_session as delete_session_query
from ..models.session.get_session import get_session as get_session_query
Expand Down Expand Up @@ -376,6 +377,8 @@ def get_handler(system: SystemDef) -> Callable:
return delete_session_query
case ("session", None, "chat"):
return chat
case ("session", None, "history"):
return get_history_query

# TASKS
case ("task", None, "list"):
Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/autogen/Docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ class BaseDocSearchRequest(BaseModel):
populate_by_name=True,
)
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
Expand Down
35 changes: 33 additions & 2 deletions agents-api/agents_api/autogen/Sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,41 @@ class RecallOptions(BaseModel):
populate_by_name=True,
)
mode: Literal["hybrid", "vector", "text"] = "vector"
"""
The mode to use for the search.
"""
num_search_messages: int = 4
"""
The number of search messages to use for the search.
"""
max_query_length: int = 1000
hybrid_alpha: float = 0.7
confidence: float = 0.6
"""
The maximum query length to use for the search.
"""
alpha: Annotated[float, Field(ge=0.0, le=1.0)] = 0.7
"""
The weight to apply to BM25 vs Vector search results. 0 => pure BM25; 1 => pure vector;
"""
confidence: Annotated[float, Field(ge=0.0, le=1.0)] = 0.6
"""
The confidence cutoff level
"""
limit: Annotated[int, Field(ge=1, le=50)] = 10
"""
The limit of documents to return
"""
lang: Literal["en-US"] = "en-US"
"""
The language to be used for text-only search. Support for other languages coming soon.
"""
metadata_filter: dict[str, Any] = {}
"""
Metadata filter to apply to the search
"""
mmr_strength: Annotated[float, Field(ge=0.0, lt=1.0)] = 0
"""
MMR Strength (mmr_strength = 1 - mmr_lambda)
"""


class RecallOptionsUpdate(RecallOptions):
Expand Down
129 changes: 98 additions & 31 deletions agents-api/agents_api/models/chat/gather_messages.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import TypeVar
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from uuid import UUID

import numpy as np
from beartype import beartype
from fastapi import HTTPException
from pycozo.client import QueryException
from pydantic import ValidationError

from ...autogen.openapi_model import ChatInput, DocReference, History
from ...autogen.Sessions import RecallOptions
from ...clients import litellm
from ...common.protocol.developers import Developer
from ...common.protocol.sessions import ChatContext
from ...models.docs.mmr import maximal_marginal_relevance
from ..docs.search_docs_by_embedding import search_docs_by_embedding
from ..docs.search_docs_by_text import search_docs_by_text
from ..docs.search_docs_hybrid import search_docs_hybrid
Expand All @@ -23,6 +26,52 @@
T = TypeVar("T")


def get_search_fn_and_params(
recall_options: RecallOptions,
query_text: str | None,
query_embedding: list[float] | None,
) -> Tuple[
Any,
Optional[Dict[str, Union[float, int, str, Dict[str, float], List[float], None]]],
]:
search_fn, params = None, None

match recall_options.mode:
case "text":
search_fn = search_docs_by_text
params = dict(
query=query_text,
k=recall_options.limit,
metadata_filter=recall_options.metadata_filter,
)

case "vector":
search_fn = search_docs_by_embedding
params = dict(
query_embedding=query_embedding,
k=recall_options.limit * 3
if recall_options.mmr_strength > 0
Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
else recall_options.limit,
confidence=recall_options.confidence,
metadata_filter=recall_options.metadata_filter,
)

case "hybrid":
search_fn = search_docs_hybrid
params = dict(
query=query_text,
query_embedding=query_embedding,
k=recall_options.limit * 3
if recall_options.mmr_strength > 0
else recall_options.limit,
embed_search_options=dict(confidence=recall_options.confidence),
alpha=recall_options.alpha,
metadata_filter=recall_options.metadata_filter,
)

Vedantsahai18 marked this conversation as resolved.
Show resolved Hide resolved
return search_fn, params


@rewrap_exceptions(
{
QueryException: partialclass(HTTPException, status_code=400),
Expand Down Expand Up @@ -98,44 +147,62 @@ async def gather_messages(
]
).strip()

[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)
# Set the query text and embedding
query_text, query_embedding = None, None

# Embed the query
if recall_options.mode != "text":
[query_embedding, *_] = await litellm.aembedding(
# Truncate on the left to keep the last `search_query_chars` characters
inputs=embed_text[-(recall_options.max_query_length) :],
# TODO: Make this configurable once it's added to the ChatInput model
embed_instruction="Represent the query for retrieving supporting documents: ",
)

# Truncate on the right to take only the first `search_query_chars` characters
query_text = search_messages[-1]["content"].strip()[
: recall_options.max_query_length
]
if recall_options.mode == "text" or recall_options.mode == "hybrid":
query_text = search_messages[-1]["content"].strip()[
: recall_options.max_query_length
]

# List all the applicable owners to search docs from
active_agent_id = chat_context.get_active_agent().id
user_ids = [user.id for user in chat_context.users]
owners = [("user", user_id) for user_id in user_ids] + [("agent", active_agent_id)]

# Get the search function and parameters
search_fn, params = get_search_fn_and_params(
recall_options=recall_options,
query_text=query_text,
query_embedding=query_embedding,
)

# Search for doc references
doc_references: list[DocReference] = []
match recall_options.mode:
case "vector":
doc_references: list[DocReference] = search_docs_by_embedding(
developer_id=developer.id,
owners=owners,
query_embedding=query_embedding,
)
case "hybrid":
doc_references: list[DocReference] = search_docs_hybrid(
developer_id=developer.id,
owners=owners,
query=query_text,
query_embedding=query_embedding,
)
case "text":
doc_references: list[DocReference] = search_docs_by_text(
developer_id=developer.id,
owners=owners,
query=query_text,
)
doc_references: list[DocReference] = search_fn(
developer_id=developer.id,
owners=owners,
**params,
)

# Apply MMR if enabled
if (
# MMR is enabled
recall_options.mmr_strength > 0
# The number of doc references is greater than the limit
and len(doc_references) > recall_options.limit
# MMR is not applied to text search
and recall_options.mode != "text"
):
# Apply MMR
indices = maximal_marginal_relevance(
np.asarray(query_embedding),
[doc.snippet.embedding for doc in doc_references],
k=recall_options.limit,
)
# Apply MMR
doc_references = [
doc for i, doc in enumerate(doc_references) if i in set(indices)
]

# Return the past messages and doc references
return past_messages, doc_references
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from ...common.retry_policies import DEFAULT_RETRY_POLICY
from ...env import (
debug,
temporal_activity_after_retry_timeout,
temporal_heartbeat_timeout,
temporal_schedule_to_close_timeout,
testing,
Expand Down
17 changes: 12 additions & 5 deletions agents-api/tests/test_chat_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,13 @@ async def _(
agent=agent.id,
situation="test session about",
recall_options={
"mode": "text",
"num_search_messages": 10,
"max_query_length": 1001,
"mode": "hybrid",
"num_search_messages": 6,
"max_query_length": 800,
"confidence": 0.6,
"alpha": 0.7,
"limit": 10,
"mmr_strength": 0.5,
},
),
client=client,
Expand Down Expand Up @@ -135,9 +139,12 @@ async def _(
agent=agent.id,
situation="test session about",
recall_options={
"mode": "vector",
"num_search_messages": 5,
"mode": "text",
"num_search_messages": 10,
"max_query_length": 1001,
"confidence": 0.6,
"limit": 5,
"mmr_strength": 0.5,
},
),
client=client,
Expand Down
1 change: 0 additions & 1 deletion agents-api/tests/test_execution_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from agents_api.models.task.create_task import create_task
from agents_api.routers.tasks.create_task_execution import start_execution
from tests.fixtures import (
async_cozo_client,
cozo_client,
cozo_clients_with_migrations,
test_agent,
Expand Down
Loading
Loading