Skip to content

Enable the recommendation bot to search for specific resource details #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions ai_chatbots/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
RemoveMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import ChatPromptTemplate
from langgraph.utils.runnable import RunnableCallable
from langmem.short_term import RunningSummary
from langmem.short_term.summarization import (
DEFAULT_EXISTING_SUMMARY_PROMPT,
DEFAULT_FINAL_SUMMARY_PROMPT,
DEFAULT_INITIAL_SUMMARY_PROMPT,
SummarizationNode,
SummarizationResult,
TokenCounter,
)
Expand Down Expand Up @@ -60,7 +61,9 @@ def get_search_tool_metadata(thread_id: str, latest_state: TypedDict) -> str:
}
return json.dumps(metadata)
except json.JSONDecodeError:
log.exception("Error parsing tool metadata, not valid JSON")
log.exception(
"Error parsing tool metadata, not valid JSON: %s", msg_content
)
return json.dumps(
{"error": "Error parsing tool metadata", "content": msg_content}
)
Expand Down Expand Up @@ -142,7 +145,8 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901
else:
existing_system_message = None

if not messages:
# Summarize only when last message is a human message
if not messages or (messages and not isinstance(messages[-1], HumanMessage)):
return SummarizationResult(
running_summary=running_summary,
messages=(
Expand Down Expand Up @@ -179,7 +183,6 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901
# map tool call IDs to their corresponding tool messages
tool_call_id_to_tool_message: dict[str, ToolMessage] = {}
should_summarize = False
n_tokens_to_summarize = 0
for i in range(total_summarized_messages, len(messages)):
message = messages[i]
if message.id is None:
Expand All @@ -196,14 +199,16 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901

n_tokens += token_counter([message])

# Check if we've reached max_tokens_to_summarize
# and the remaining messages fit within the max_remaining_tokens budget
# Check if we've reached max_tokens_to_summarize and
# final message is a valid type to end summarization on
# (not a tool message or AI tool)
if (
n_tokens >= max_tokens_before_summary
and total_n_tokens - n_tokens <= max_remaining_tokens
and not should_summarize
and not isinstance(message, ToolMessage)
and (not isinstance(message, AIMessage) or not message.tool_calls)
):
n_tokens_to_summarize = n_tokens
should_summarize = True
idx = i

Expand All @@ -212,21 +217,6 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901
else:
messages_to_summarize = messages[total_summarized_messages : idx + 1]

# If the last message is an AI message with tool calls,
# include subsequent corresponding tool messages in the summary as well,
# to avoid issues w/ the LLM provider
if (
messages_to_summarize
and isinstance(messages_to_summarize[-1], AIMessage)
and (tool_calls := messages_to_summarize[-1].tool_calls)
):
# Add any matching tool messages from our dictionary
for tool_call in tool_calls:
if tool_call["id"] in tool_call_id_to_tool_message:
tool_message = tool_call_id_to_tool_message[tool_call["id"]]
n_tokens_to_summarize += token_counter([tool_message])
messages_to_summarize.append(tool_message)

if messages_to_summarize:
if running_summary:
summary_messages = cast(
Expand All @@ -245,6 +235,7 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901
)
log.debug("messages to summarize: %s", messages_to_summarize)
summary_response = model.invoke(summary_messages.messages)
log.debug("Summarization response: %s", summary_response.content)
summarized_message_ids = summarized_message_ids | {
message.id for message in messages_to_summarize
}
Expand Down Expand Up @@ -293,13 +284,43 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901
)


class CustomSummarizationNode(SummarizationNode):
class CustomSummarizationNode(RunnableCallable):
"""
Customized implementation of langmem.short_term.SummarizationNode. The original
has a bug causing the most recent user question and answer to be lost when the
summary was updated.
"""

def __init__( # noqa: PLR0913
self,
*,
model: LanguageModelLike,
max_tokens: int,
max_tokens_before_summary: int | None = None,
max_summary_tokens: int = 256,
token_counter: TokenCounter = count_tokens_approximately,
initial_summary_prompt: ChatPromptTemplate = DEFAULT_INITIAL_SUMMARY_PROMPT,
existing_summary_prompt: ChatPromptTemplate = DEFAULT_EXISTING_SUMMARY_PROMPT,
final_prompt: ChatPromptTemplate = DEFAULT_FINAL_SUMMARY_PROMPT,
input_messages_key: str = "messages",
output_messages_key: str = "summarized_messages",
name: str = "summarization",
) -> None:
"""
Initialize the CustomSummarizationNode.
"""
super().__init__(self._func, name=name, trace=False)
self.model = model
self.max_tokens = max_tokens
self.max_tokens_before_summary = max_tokens_before_summary
self.max_summary_tokens = max_summary_tokens
self.token_counter = token_counter
self.initial_summary_prompt = initial_summary_prompt
self.existing_summary_prompt = existing_summary_prompt
self.final_prompt = final_prompt
self.input_messages_key = input_messages_key
self.output_messages_key = output_messages_key

def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]:
"""
Generate a summary if needed.
Expand Down
37 changes: 26 additions & 11 deletions ai_chatbots/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,9 @@ def test_subsequent_summarization_with_new_messages_approximate_token_counter():
assert len(updated_summary_value.summarized_message_ids) == len(messages2) - 3


def test_last_ai_with_tool_calls():
@pytest.mark.parametrize("is_tool_call", [True, False])
def test_last_ai_with_tool_calls(is_tool_call):
"""Summarization should be skipped if last message is a tool call."""
model = MockChatModel(responses=[AIMessage(content="Summary without tool calls.")])

messages = [
Expand All @@ -584,37 +586,48 @@ def test_last_ai_with_tool_calls():
HumanMessage(content="Message 2", id="6"),
]

if is_tool_call:
# If the last message is a tool call, we should not summarize
messages.append(
ToolMessage(content="Call tool 3", tool_call_id="3", name="tool_3", id="7")
)

# Call the summarizer
result = summarize_messages(
messages,
running_summary=None,
model=model,
token_counter=len,
max_tokens_before_summary=2,
max_tokens=6,
max_tokens=2,
max_summary_tokens=1,
)

# Check that the AI message with tool calls was summarized together with the tool messages
assert len(result.messages) == 3
assert result.messages[0].type == "system" # Summary
assert result.messages[-2:] == messages[-2:]
assert result.running_summary.summarized_message_ids == {
msg.id for msg in messages[:-2]
}
assert len(result.messages) == (7 if is_tool_call else 2)
assert result.messages[0].type == ("human" if is_tool_call else "system")
assert result.messages[-1].type == ("tool" if is_tool_call else "human")

if is_tool_call:
assert result.running_summary is None
else:
assert result.running_summary.summarized_message_ids == {
msg.id for msg in messages[:-1]
}


def test_missing_message_ids():
messages = [
HumanMessage(content="Message 1", id="1"),
AIMessage(content="Response"), # Missing ID
HumanMessage(content="Message 2", id="1"),
]
with pytest.raises(ValueError, match="Messages are required to have ID field"):
summarize_messages(
messages,
running_summary=None,
model=MockChatModel(responses=[]),
max_tokens=10,
max_tokens=1,
max_summary_tokens=1,
)

Expand All @@ -640,8 +653,10 @@ def test_duplicate_message_ids():

# Second summarization with a duplicate ID
messages2 = [
AIMessage(content="Response 1", id="2"), # Duplicate ID
HumanMessage(content="Message 2", id="3"), # Duplicate ID
AIMessage(content="Response 2", id="4"),
HumanMessage(content="Message 3", id="1"), # Duplicate ID
HumanMessage(content="Message 3", id="5"),
]

with pytest.raises(ValueError, match="has already been summarized"):
Expand All @@ -650,7 +665,7 @@ def test_duplicate_message_ids():
running_summary=result.running_summary,
model=model,
token_counter=len,
max_tokens=5,
max_tokens=6,
max_summary_tokens=1,
)

Expand Down
7 changes: 3 additions & 4 deletions ai_chatbots/chatbots.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from ai_chatbots.api import CustomSummarizationNode, get_search_tool_metadata
from ai_chatbots.models import TutorBotOutput
from ai_chatbots.prompts import PROMPT_MAPPING
from ai_chatbots.tools import get_video_transcript_chunk, search_content_files
from ai_chatbots.utils import get_django_cache

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -405,7 +404,7 @@ def __init__( # noqa: PLR0913

def create_tools(self) -> list[BaseTool]:
"""Create tools required by the agent"""
return [tools.search_courses]
return [tools.search_courses, tools.search_content_files]

async def get_tool_metadata(self) -> str:
"""Return the metadata for the search tool"""
Expand Down Expand Up @@ -457,7 +456,7 @@ def __init__( # noqa: PLR0913

def create_tools(self):
"""Create tools required by the agent"""
return [search_content_files]
return [tools.search_content_files]

async def get_tool_metadata(self) -> str:
"""Return the metadata for the search tool"""
Expand Down Expand Up @@ -687,7 +686,7 @@ def __init__( # noqa: PLR0913

def create_tools(self):
"""Create tools required for the agent"""
return [get_video_transcript_chunk]
return [tools.get_video_transcript_chunk]

async def get_tool_metadata(self) -> str:
"""Return the metadata for the search tool"""
Expand Down
1 change: 1 addition & 0 deletions ai_chatbots/chatbots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ async def test_recommendation_bot_tool(settings, mocker, search_results):
retained_attributes = [
"title",
"id",
"readable_id",
"description",
"offered_by",
"free",
Expand Down
28 changes: 13 additions & 15 deletions ai_chatbots/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@

Your job:
1. Understand the user's intent AND BACKGROUND based on their message.
2. Use the available function to gather information or recommend courses.
2. Use the available tools to gather information or recommend courses.
3. Provide a clear, user-friendly explanation of your recommendations if search results
are found.


Run the tool to find learning resources that the user is interested in,
and answer only based on the function search
results.

VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO
ANSWER QUESTIONS.
Run the "search_courses" tool to find learning resources that the user is interested in,
and answer only based on the function search results. If the user asks for more
specific information about a particular resource, use the "search_content_files" tool
to find an answer.

If no results are returned, say you could not find any relevant
resources. Don't say you're going to try again. Ask the user if they would like to
Expand Down Expand Up @@ -48,10 +46,7 @@
Expected Output: Maybe ask whether the user wants to learn how to program, or just use
AI in their discipline - does this person want to study machine learning? More info
needed. Then perform a relevant search and send back the best results.


AGAIN: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO
ANSWER QUESTIONS."""
"""


PROMPT_SYLLABUS = """You are an assistant named Tim, helping users answer questions
Expand All @@ -64,7 +59,8 @@
answer the user's question.

Always use the tool results to answer questions, and answer only based on the tool
output. Do not include the course_id in the query parameter. The tool always has
access to the course_id.
VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO
ANSWER QUESTIONS. If no results are returned, say you could not find any relevant
information."""
Expand All @@ -87,9 +83,11 @@

# The following prompts are similar or identical to the default ones in
# langmem.short_term.summarization
PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above.
If there are any tool results, include the full output of the latest one in
the summary.
PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above, incorporating
the previous summary, if any.
If there are any tool results, include the full output of the latest tool message in
the summary. You must also retain all title and readable_id field values from all tool
messages and any previous summaries in this new summary.
"""
PROMPT_SUMMARY_EXISTING = """This is a summary of the conversation so far:
{existing_summary}
Expand Down
Loading
Loading