From 14f1410d7368d574c0ccb6c5ae86149467909e08 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Thu, 29 May 2025 14:40:01 -0400 Subject: [PATCH 01/11] Initial approach, multiple tools in recommendation bot --- ai_chatbots/api.py | 4 +++- ai_chatbots/chatbots.py | 7 +++---- ai_chatbots/prompts.py | 20 ++++++++------------ ai_chatbots/tools.py | 25 ++++++++++++++++++------- 4 files changed, 32 insertions(+), 24 deletions(-) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 92b7d4e2..43d91ef9 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -60,7 +60,9 @@ def get_search_tool_metadata(thread_id: str, latest_state: TypedDict) -> str: } return json.dumps(metadata) except json.JSONDecodeError: - log.exception("Error parsing tool metadata, not valid JSON") + log.exception( + "Error parsing tool metadata, not valid JSON: %s", msg_content + ) return json.dumps( {"error": "Error parsing tool metadata", "content": msg_content} ) diff --git a/ai_chatbots/chatbots.py b/ai_chatbots/chatbots.py index 4d12b536..a94edbaa 100644 --- a/ai_chatbots/chatbots.py +++ b/ai_chatbots/chatbots.py @@ -41,7 +41,6 @@ from ai_chatbots.api import CustomSummarizationNode, get_search_tool_metadata from ai_chatbots.models import TutorBotOutput from ai_chatbots.prompts import PROMPT_MAPPING -from ai_chatbots.tools import get_video_transcript_chunk, search_content_files from ai_chatbots.utils import get_django_cache log = logging.getLogger(__name__) @@ -405,7 +404,7 @@ def __init__( # noqa: PLR0913 def create_tools(self) -> list[BaseTool]: """Create tools required by the agent""" - return [tools.search_courses] + return [tools.search_courses, tools.search_content_files] async def get_tool_metadata(self) -> str: """Return the metadata for the search tool""" @@ -457,7 +456,7 @@ def __init__( # noqa: PLR0913 def create_tools(self): """Create tools required by the agent""" - return [search_content_files] + return [tools.search_content_files] async def get_tool_metadata(self) -> str: """Return the metadata for the search tool""" @@ -687,7 +686,7 @@ def __init__( # noqa: PLR0913 def create_tools(self): """Create tools required for the agent""" - return [get_video_transcript_chunk] + return [tools.get_video_transcript_chunk] async def get_tool_metadata(self) -> str: """Return the metadata for the search tool""" diff --git a/ai_chatbots/prompts.py b/ai_chatbots/prompts.py index b4464142..46b49d6f 100644 --- a/ai_chatbots/prompts.py +++ b/ai_chatbots/prompts.py @@ -7,17 +7,15 @@ Your job: 1. Understand the user's intent AND BACKGROUND based on their message. -2. Use the available function to gather information or recommend courses. +2. Use the available tools to gather information or recommend courses. 3. Provide a clear, user-friendly explanation of your recommendations if search results are found. -Run the tool to find learning resources that the user is interested in, -and answer only based on the function search -results. - -VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO -ANSWER QUESTIONS. +Run the "search_courses" tool to find learning resources that the user is interested in, +and answer only based on the function search results. If the user asks for more +specific information about a particular resource, use the "search_content_files" tool +to find an answer, using the course_id from the search results as the course_id. If no results are returned, say you could not find any relevant resources. Don't say you're going to try again. 
Ask the user if they would like to @@ -48,10 +46,7 @@ Expected Output: Maybe ask whether the user wants to learn how to program, or just use AI in their discipline - does this person want to study machine learning? More info needed. Then perform a relevant search and send back the best results. - - -AGAIN: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO -ANSWER QUESTIONS.""" +""" PROMPT_SYLLABUS = """You are an assistant named Tim, helping users answer questions @@ -64,7 +59,8 @@ answer the user's question. Always use the tool results to answer questions, and answer only based on the tool -output. Do not include the course id in the query parameter. +output. Do not include the course_id in the query parameter. The tool always has +access to the course id. VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO ANSWER QUESTIONS. If no results are returned, say you could not find any relevant information.""" diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 4fe10b54..0335dd2b 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -18,7 +18,7 @@ class SearchToolSchema(pydantic.BaseModel): - """Schema for searching MIT learning resources. + """Schema to search for MIT learning resources. Attributes: q: The search query string @@ -128,7 +128,6 @@ def search_courses( Query the MIT API for learning resources, and return simplified results as a JSON string """ - params = {"q": q, "limit": settings.AI_MIT_SEARCH_LIMIT} valid_params = { @@ -139,7 +138,7 @@ def search_courses( } params.update({k: v for k, v in valid_params.items() if v is not None}) search_url = state["search_url"][-1] if state else settings.AI_MIT_SEARCH_URL - log.debug("Searching MIT API at %s with params: %s", search_url, params) + log.debug("Searching MIT resources API at %s with params: %s", search_url, params) try: response = requests.get(search_url, params=params, timeout=30) response.raise_for_status() @@ -148,6 +147,7 @@ def search_courses( main_properties = [ "title", "id", + "readable_id", "description", "offered_by", "free", @@ -160,6 +160,7 @@ def search_courses( simplified_result["url"] = ( f"{settings.AI_MIT_SEARCH_DETAIL_URL}{result.pop('id')}" ) + simplified_result["course_id"] = result.pop("readable_id", None) # Instructors and level will be in the runs data if present next_date = result.get("next_start_date", None) raw_runs = result.get("runs", []) @@ -192,6 +193,14 @@ class SearchContentFilesToolSchema(pydantic.BaseModel): "Query to find course information that might answer the user's question." ) ) + + course_id: Optional[str] = Field( + description=( + "The course ID to search for content files related to the course." + "Do not include the course ID in the q parameter." + ) + ) + state: Annotated[dict, InjectedState] = Field( description="The agent state, including course_id and collection_name params" ) @@ -212,15 +221,17 @@ class VideoGPTToolSchema(pydantic.BaseModel): @tool(args_schema=SearchContentFilesToolSchema) -def search_content_files(q: str, state: Annotated[dict, InjectedState]) -> str: +def search_content_files( + q: str, state: Annotated[dict, InjectedState], course_id: str | None = None +) -> str: """ Query the MIT contentfile vector endpoint API, and return results as a JSON string, along with metadata about the query parameters used. 
""" url = settings.AI_MIT_SYLLABUS_URL - course_id = state["course_id"][-1] - collection_name = state["collection_name"][-1] + course_id = state.get("course_id", [None])[-1] or course_id + collection_name = state.get("collection_name", [None])[-1] params = { "q": q, "resource_readable_id": course_id, @@ -228,7 +239,7 @@ def search_content_files(q: str, state: Annotated[dict, InjectedState]) -> str: } if collection_name: params["collection_name"] = collection_name - log.debug("Searching MIT API with params: %s", params) + log.debug("Searching MIT content API with params: %s", params) try: response = requests.get(url, params=params, timeout=30) response.raise_for_status() From b8c992128c18f033ecb62c766a6d535ef7883a11 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Thu, 29 May 2025 15:53:54 -0400 Subject: [PATCH 02/11] A couple more tweaks --- ai_chatbots/prompts.py | 2 +- ai_chatbots/tools.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/ai_chatbots/prompts.py b/ai_chatbots/prompts.py index 46b49d6f..e54ad4df 100644 --- a/ai_chatbots/prompts.py +++ b/ai_chatbots/prompts.py @@ -15,7 +15,7 @@ Run the "search_courses" tool to find learning resources that the user is interested in, and answer only based on the function search results. If the user asks for more specific information about a particular resource, use the "search_content_files" tool -to find an answer, using the course_id from the search results as the course_id. +to find an answer. If no results are returned, say you could not find any relevant resources. Don't say you're going to try again. Ask the user if they would like to diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 0335dd2b..9b8ba3c0 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -195,10 +195,7 @@ class SearchContentFilesToolSchema(pydantic.BaseModel): ) course_id: Optional[str] = Field( - description=( - "The course ID to search for content files related to the course." - "Do not include the course ID in the q parameter." - ) + description=("The course_id to use if not provided in the agent state. "), ) state: Annotated[dict, InjectedState] = Field( From bb2016c4ac715939ba3207bfd3b4b31b1dd6e674 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 30 May 2025 08:46:34 -0400 Subject: [PATCH 03/11] more tweaks --- ai_chatbots/tools.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 9b8ba3c0..3adac461 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -186,20 +186,22 @@ def search_courses( class SearchContentFilesToolSchema(pydantic.BaseModel): - """Schema for searching MIT contentfiles related to a particular course.""" + """ + Schema for searching MIT contentfiles related to a particular learning resource. + """ q: str = Field( - description=( - "Query to find course information that might answer the user's question." - ) + description=("Query to find requested information about a learning resource.") ) - course_id: Optional[str] = Field( - description=("The course_id to use if not provided in the agent state. 
"), + readable_id: Optional[str] = Field( + description=("The readable_id of the learning resource."), ) state: Annotated[dict, InjectedState] = Field( - description="The agent state, including course_id and collection_name params" + description=( + "Agent state, which may include course_id (readable_id) and collection_name" + ) ) @@ -219,15 +221,16 @@ class VideoGPTToolSchema(pydantic.BaseModel): @tool(args_schema=SearchContentFilesToolSchema) def search_content_files( - q: str, state: Annotated[dict, InjectedState], course_id: str | None = None + q: str, state: Annotated[dict, InjectedState], readable_id: str | None = None ) -> str: """ - Query the MIT contentfile vector endpoint API, and return results as a - JSON string, along with metadata about the query parameters used. + Search for detailed information about a particular MIT learning resource. + The resource is identified by its readable_id or course_id. """ url = settings.AI_MIT_SYLLABUS_URL - course_id = state.get("course_id", [None])[-1] or course_id + # Use the state course_id if available, otherwise use the provided course_id + course_id = state.get("course_id", [None])[-1] or readable_id collection_name = state.get("collection_name", [None])[-1] params = { "q": q, From 645d8588166d2a3fcf652f17d0f66657d1c1490f Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 30 May 2025 08:58:56 -0400 Subject: [PATCH 04/11] Fix tests --- ai_chatbots/chatbots_test.py | 1 + ai_chatbots/tools.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/ai_chatbots/chatbots_test.py b/ai_chatbots/chatbots_test.py index 5ba08e7d..cccdfd1a 100644 --- a/ai_chatbots/chatbots_test.py +++ b/ai_chatbots/chatbots_test.py @@ -147,6 +147,7 @@ async def test_recommendation_bot_tool(settings, mocker, search_results): retained_attributes = [ "title", "id", + "readable_id", "description", "offered_by", "free", diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 3adac461..2258c005 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -160,7 +160,6 @@ def search_courses( simplified_result["url"] = ( f"{settings.AI_MIT_SEARCH_DETAIL_URL}{result.pop('id')}" ) - simplified_result["course_id"] = result.pop("readable_id", None) # Instructors and level will be in the runs data if present next_date = result.get("next_start_date", None) raw_runs = result.get("runs", []) @@ -196,6 +195,7 @@ class SearchContentFilesToolSchema(pydantic.BaseModel): readable_id: Optional[str] = Field( description=("The readable_id of the learning resource."), + default=None, ) state: Annotated[dict, InjectedState] = Field( From 2c688ba3c7c548daca39e36b04986e752c68da2e Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Fri, 30 May 2025 16:23:39 -0400 Subject: [PATCH 05/11] Resolve some summarization issues --- ai_chatbots/api.py | 53 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 43d91ef9..103eb125 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -16,12 +16,12 @@ from langchain_core.messages.utils import count_tokens_approximately from langchain_core.prompt_values import ChatPromptValue from langchain_core.prompts import ChatPromptTemplate +from langgraph.utils.runnable import RunnableCallable from langmem.short_term import RunningSummary from langmem.short_term.summarization import ( DEFAULT_EXISTING_SUMMARY_PROMPT, DEFAULT_FINAL_SUMMARY_PROMPT, DEFAULT_INITIAL_SUMMARY_PROMPT, - SummarizationNode, SummarizationResult, TokenCounter, ) @@ 
-181,7 +181,6 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 # map tool call IDs to their corresponding tool messages tool_call_id_to_tool_message: dict[str, ToolMessage] = {} should_summarize = False - n_tokens_to_summarize = 0 for i in range(total_summarized_messages, len(messages)): message = messages[i] if message.id is None: @@ -205,7 +204,6 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 and total_n_tokens - n_tokens <= max_remaining_tokens and not should_summarize ): - n_tokens_to_summarize = n_tokens should_summarize = True idx = i @@ -215,19 +213,13 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 messages_to_summarize = messages[total_summarized_messages : idx + 1] # If the last message is an AI message with tool calls, - # include subsequent corresponding tool messages in the summary as well, - # to avoid issues w/ the LLM provider + # wait until the next user message to summarize. if ( messages_to_summarize and isinstance(messages_to_summarize[-1], AIMessage) - and (tool_calls := messages_to_summarize[-1].tool_calls) + and messages_to_summarize[-1].tool_calls ): - # Add any matching tool messages from our dictionary - for tool_call in tool_calls: - if tool_call["id"] in tool_call_id_to_tool_message: - tool_message = tool_call_id_to_tool_message[tool_call["id"]] - n_tokens_to_summarize += token_counter([tool_message]) - messages_to_summarize.append(tool_message) + messages_to_summarize = [] if messages_to_summarize: if running_summary: @@ -295,13 +287,43 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 ) -class CustomSummarizationNode(SummarizationNode): +class CustomSummarizationNode(RunnableCallable): """ Customized implementation of langmem.short_term.SummarizationNode. The original has a bug causing the most recent user question and answer to be lost when the summary was updated. """ + def __init__( # noqa: PLR0913 + self, + *, + model: LanguageModelLike, + max_tokens: int, + max_tokens_before_summary: int | None = None, + max_summary_tokens: int = 256, + token_counter: TokenCounter = count_tokens_approximately, + initial_summary_prompt: ChatPromptTemplate = DEFAULT_INITIAL_SUMMARY_PROMPT, + existing_summary_prompt: ChatPromptTemplate = DEFAULT_EXISTING_SUMMARY_PROMPT, + final_prompt: ChatPromptTemplate = DEFAULT_FINAL_SUMMARY_PROMPT, + input_messages_key: str = "messages", + output_messages_key: str = "summarized_messages", + name: str = "summarization", + ) -> None: + """ + Initialize the CustomSummarizationNode. + """ + super().__init__(self._func, name=name, trace=False) + self.model = model + self.max_tokens = max_tokens + self.max_tokens_before_summary = max_tokens_before_summary + self.max_summary_tokens = max_summary_tokens + self.token_counter = token_counter + self.initial_summary_prompt = initial_summary_prompt + self.existing_summary_prompt = existing_summary_prompt + self.final_prompt = final_prompt + self.input_messages_key = input_messages_key + self.output_messages_key = output_messages_key + def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]: """ Generate a summary if needed. 
@@ -323,6 +345,11 @@ def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]: raise ValueError(error) last_message = messages[-1] if messages else None + log.debug( + "SummarizationNode called with %d messages, last message: %s", + len(messages), + last_message if last_message else "N/A", + ) previous_summary = context.get("running_summary") log.debug("Previous summary:\n\n%s\n\n", previous_summary or "N/A") summarization_result = summarize_messages( From da7d68d83f37852b3efaed551b77601d2f52e119 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Mon, 2 Jun 2025 09:41:19 -0400 Subject: [PATCH 06/11] Fix tests --- ai_chatbots/api.py | 12 +----------- ai_chatbots/api_test.py | 30 +++++++++++++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 103eb125..7e51eb46 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -7,7 +7,6 @@ from langchain_core.language_models import LanguageModelLike from langchain_core.messages import ( - AIMessage, AnyMessage, RemoveMessage, SystemMessage, @@ -144,7 +143,7 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 else: existing_system_message = None - if not messages: + if not messages or isinstance(messages[-1], ToolMessage): return SummarizationResult( running_summary=running_summary, messages=( @@ -212,15 +211,6 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 else: messages_to_summarize = messages[total_summarized_messages : idx + 1] - # If the last message is an AI message with tool calls, - # wait until the next user message to summarize. - if ( - messages_to_summarize - and isinstance(messages_to_summarize[-1], AIMessage) - and messages_to_summarize[-1].tool_calls - ): - messages_to_summarize = [] - if messages_to_summarize: if running_summary: summary_messages = cast( diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index ac3f731c..11294bd5 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -563,7 +563,9 @@ def test_subsequent_summarization_with_new_messages_approximate_token_counter(): assert len(updated_summary_value.summarized_message_ids) == len(messages2) - 3 -def test_last_ai_with_tool_calls(): +@pytest.mark.parametrize("is_tool_call", [True, False]) +def test_last_ai_with_tool_calls(is_tool_call): + """Summarization should be skipped if last message is a tool call.""" model = MockChatModel(responses=[AIMessage(content="Summary without tool calls.")]) messages = [ @@ -584,24 +586,34 @@ def test_last_ai_with_tool_calls(): HumanMessage(content="Message 2", id="6"), ] + if is_tool_call: + # If the last message is a tool call, we should not summarize + messages.append( + ToolMessage(content="Call tool 3", tool_call_id="3", name="tool_3", id="7") + ) + # Call the summarizer result = summarize_messages( messages, running_summary=None, model=model, token_counter=len, - max_tokens_before_summary=2, + max_tokens_before_summary=6, max_tokens=6, - max_summary_tokens=1, + max_summary_tokens=3, ) # Check that the AI message with tool calls was summarized together with the tool messages - assert len(result.messages) == 3 - assert result.messages[0].type == "system" # Summary - assert result.messages[-2:] == messages[-2:] - assert result.running_summary.summarized_message_ids == { - msg.id for msg in messages[:-2] - } + assert len(result.messages) == (7 if is_tool_call else 1) + assert result.messages[0].type == ("human" if is_tool_call else "system") + assert result.messages[-1].type == ("tool" 
if is_tool_call else "system") + + if is_tool_call: + assert result.running_summary is None + else: + assert result.running_summary.summarized_message_ids == { + msg.id for msg in messages + } def test_missing_message_ids(): From 76d75e259d9b68a4e3e5a8a3dbb86558c994c9c9 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Mon, 2 Jun 2025 12:07:54 -0400 Subject: [PATCH 07/11] Minor tweaks --- ai_chatbots/api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 7e51eb46..c97a79b5 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -143,6 +143,8 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 else: existing_system_message = None + # if there are no messages to summarize, or the last message + # is a tool call, do not invoke the summarization model. if not messages or isinstance(messages[-1], ToolMessage): return SummarizationResult( running_summary=running_summary, @@ -229,6 +231,7 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 ) log.debug("messages to summarize: %s", messages_to_summarize) summary_response = model.invoke(summary_messages.messages) + log.debug("Summarization response: %s", summary_response.content) summarized_message_ids = summarized_message_ids | { message.id for message in messages_to_summarize } From d4f343b24ccea7c6275e4338b5c4e7e8f30e2dbb Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Mon, 2 Jun 2025 16:20:48 -0400 Subject: [PATCH 08/11] experiment --- ai_chatbots/api.py | 37 +++++++++++-------------------------- ai_chatbots/api_test.py | 27 ++++++++++++++------------- 2 files changed, 25 insertions(+), 39 deletions(-) diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index c97a79b5..201b8ff7 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -8,6 +8,7 @@ from langchain_core.language_models import LanguageModelLike from langchain_core.messages import ( AnyMessage, + HumanMessage, RemoveMessage, SystemMessage, ToolMessage, @@ -131,21 +132,16 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 max_tokens_before_summary = max_tokens max_tokens_to_summarize = max_tokens - # Adjust the remaining token budget to account for the summary to be added - max_remaining_tokens = max_tokens - max_summary_tokens # First handle system message if present if messages and isinstance(messages[0], SystemMessage): existing_system_message = messages[0] # remove the system message from the list of messages to summarize messages = messages[1:] - # adjust remaining token budget for the system msg to be re-added - max_remaining_tokens -= token_counter([existing_system_message]) else: existing_system_message = None - # if there are no messages to summarize, or the last message - # is a tool call, do not invoke the summarization model. - if not messages or isinstance(messages[-1], ToolMessage): + # Summarize only when last message is a human message + if not messages or (messages and not isinstance(messages[-1], HumanMessage)): return SummarizationResult( running_summary=running_summary, messages=( @@ -172,16 +168,13 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 total_summarized_messages = i + 1 break - # We will use this to ensure that the total number of resulting tokens - # will fit into max_tokens window. 
- total_n_tokens = token_counter(messages[total_summarized_messages:]) - # Go through messages to count tokens and find cutoff point n_tokens = 0 idx = max(0, total_summarized_messages - 1) # map tool call IDs to their corresponding tool messages tool_call_id_to_tool_message: dict[str, ToolMessage] = {} should_summarize = False + # Iterate through all but the most recent message for i in range(total_summarized_messages, len(messages)): message = messages[i] if message.id is None: @@ -199,19 +192,18 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 n_tokens += token_counter([message]) # Check if we've reached max_tokens_to_summarize - # and the remaining messages fit within the max_remaining_tokens budget - if ( - n_tokens >= max_tokens_before_summary - and total_n_tokens - n_tokens <= max_remaining_tokens - and not should_summarize - ): + if n_tokens >= max_tokens_before_summary and not should_summarize: should_summarize = True idx = i if not should_summarize: messages_to_summarize = [] else: - messages_to_summarize = messages[total_summarized_messages : idx + 1] + log.debug( + f"{total_summarized_messages} summarized messages out of {len(messages)} messages, idx is {idx}" + ) + # Summarize all but most recent message that hasn't already been summarized + messages_to_summarize = messages[total_summarized_messages:-1] if messages_to_summarize: if running_summary: @@ -338,17 +330,10 @@ def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]: raise ValueError(error) last_message = messages[-1] if messages else None - log.debug( - "SummarizationNode called with %d messages, last message: %s", - len(messages), - last_message if last_message else "N/A", - ) previous_summary = context.get("running_summary") log.debug("Previous summary:\n\n%s\n\n", previous_summary or "N/A") summarization_result = summarize_messages( - # If we are returning here from a tool call, don't include the last tool - # message or the preceding AI message that called it. 
- messages[:-2] if isinstance(last_message, ToolMessage) else messages, + messages, running_summary=previous_summary, model=self.model, max_tokens=self.max_tokens, diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index 11294bd5..26229a28 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -111,9 +111,9 @@ def test_summarize_first_time(): AIMessage(content="Response 2", id="4"), HumanMessage(content="Message 3", id="5"), AIMessage(content="Response 3", id="6"), - # these messages will be added to the result post-summarization HumanMessage(content="Message 4", id="7"), AIMessage(content="Response 4", id="8"), + # this message should be added to the result post-summarization HumanMessage(content="Latest message", id="9"), ] @@ -132,20 +132,21 @@ def test_summarize_first_time(): assert len(model.invoke_calls) == 1 # Check that the result has the expected structure: - # - First message should be a summary - # - Last 3 messages should be the last 3 original messages - assert len(result.messages) == 4 + # - First message should be a summary of all but last human message + # - Last message should be the last human message + assert len(result.messages) == 2 assert result.messages[0].type == "system" assert "summary" in result.messages[0].content.lower() - assert result.messages[-3:] == messages[-3:] + assert result.messages[-1] == messages[-1] # Check the summary value summary_value = result.running_summary assert summary_value is not None assert summary_value.summary == "This is a summary of the conversation." - assert summary_value.summarized_message_ids == {msg.id for msg in messages[:-3]} + assert summary_value.summarized_message_ids == {msg.id for msg in messages[:-1]} # Test subsequent invocation (no new summary needed) + messages.append(factories.HumanMessageFactory.create()) result = summarize_messages( messages, running_summary=summary_value, @@ -154,13 +155,13 @@ def test_summarize_first_time(): max_tokens=6, max_summary_tokens=max_summary_tokens, ) - assert len(result.messages) == 4 + assert len(result.messages) == 3 assert result.messages[0].type == "system" assert ( result.messages[0].content == "Summary of the conversation so far: This is a summary of the conversation." 
)
-    assert result.messages[-3:] == messages[-3:]
+    assert result.messages[-1:] == messages[-1:]
 
 
 def test_max_tokens_before_summary():
@@ -250,9 +250,9 @@
         AIMessage(content="Response 2", id="4"),
         HumanMessage(content="Message 3", id="5"),
         AIMessage(content="Response 3", id="6"),
-        # these messages will be added to the result post-summarization
         HumanMessage(content="Message 4", id="7"),
         AIMessage(content="Response 4", id="8"),
+        # this message will be added to the result post-summarization
         HumanMessage(content="Latest message", id="9"),
     ]
 
@@ -271,7 +272,7 @@
 
     # Check that model was called
     assert len(model.invoke_calls) == 1
-    assert model.invoke_calls[0] == messages[1:7] + [
+    assert model.invoke_calls[0] == messages[1:-1] + [
         HumanMessage(content="Create a summary of the conversation above:")
     ]
 
     # Check that the result has the expected structure:
     # - System message should be preserved
     # - Second message should be a summary of messages 2-5
     # - Last 3 messages should be the last 3 original messages
-    assert len(result.messages) == 5
+    assert len(result.messages) == 3
     assert result.messages[0].type == "system"
     assert result.messages[1].type == "system"  # Summary message
     assert "summary" in result.messages[1].content.lower()
-    assert result.messages[2:] == messages[-3:]
+    assert result.messages[-1] == messages[-1]
 
 
 def test_approximate_token_counter():
@@ -627,7 +628,7 @@ def test_missing_message_ids():
         running_summary=None,
         model=MockChatModel(responses=[]),
         max_tokens=10,
-        max_summary_tokens=1,
+        max_summary_tokens=10,
     )
 

From 496530c412882b7e7fd22d8525ef65ad9edc8b57 Mon Sep 17 00:00:00 2001
From: Matt Bertrand
Date: Mon, 2 Jun 2025 17:08:05 -0400
Subject: [PATCH 09/11] Working multi-tool summarization

---
 ai_chatbots/api.py      | 27 +++++++++++++++++----------
 ai_chatbots/api_test.py | 30 +++++++++++++++---------------
 2 files changed, 32 insertions(+), 25 deletions(-)

diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py
index 201b8ff7..2e6c7085 100644
--- a/ai_chatbots/api.py
+++ b/ai_chatbots/api.py
@@ -7,6 +7,7 @@
 
 from langchain_core.language_models import LanguageModelLike
 from langchain_core.messages import (
+    AIMessage,
     AnyMessage,
     HumanMessage,
     RemoveMessage,
@@ -22,6 +23,7 @@
     DEFAULT_EXISTING_SUMMARY_PROMPT,
     DEFAULT_FINAL_SUMMARY_PROMPT,
     DEFAULT_INITIAL_SUMMARY_PROMPT,
+    SummarizationNode,
     SummarizationResult,
     TokenCounter,
 )
@@ -188,22 +190,25 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
         # Store tool messages by their tool_call_id for later reference
         if isinstance(message, ToolMessage) and message.tool_call_id:
             tool_call_id_to_tool_message[message.tool_call_id] = message
+
         n_tokens += token_counter([message])
 
-        # Check if we've reached max_tokens_to_summarize
-        if n_tokens >= max_tokens_before_summary and not should_summarize:
-            should_summarize = True
-            idx = i
+        # Check if we've reached max_tokens_to_summarize and
+        # final message is a valid type to end summarization on
+        # (not a tool message or AI tool call)
+        if n_tokens >= max_tokens_before_summary and \
+            not should_summarize and not isinstance(message, ToolMessage) and \
+            (
+                not isinstance(message, AIMessage) or not message.tool_calls
+            ):
+            should_summarize = True
+            idx = i
 
     if not should_summarize:
         messages_to_summarize = []
     else:
-        log.debug(
-            f"{total_summarized_messages} summarized messages out of {len(messages)} messages, idx is {idx}"
-        )
-        # Summarize all but most recent message that hasn't already been summarized
-        messages_to_summarize = 
messages[total_summarized_messages:-1] + messages_to_summarize = messages[total_summarized_messages : idx + 1] if messages_to_summarize: if running_summary: @@ -333,7 +338,9 @@ def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]: previous_summary = context.get("running_summary") log.debug("Previous summary:\n\n%s\n\n", previous_summary or "N/A") summarization_result = summarize_messages( - messages, + # If we are returning here from a tool call, don't include the last tool + # message or the preceding AI message that called it. + messages[:-2] if isinstance(last_message, ToolMessage) else messages, running_summary=previous_summary, model=self.model, max_tokens=self.max_tokens, diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index 26229a28..86cd440f 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -111,9 +111,9 @@ def test_summarize_first_time(): AIMessage(content="Response 2", id="4"), HumanMessage(content="Message 3", id="5"), AIMessage(content="Response 3", id="6"), + # these messages will be added to the result post-summarization HumanMessage(content="Message 4", id="7"), AIMessage(content="Response 4", id="8"), - # this message should be added to the result post-summarization HumanMessage(content="Latest message", id="9"), ] @@ -132,21 +132,20 @@ def test_summarize_first_time(): assert len(model.invoke_calls) == 1 # Check that the result has the expected structure: - # - First message should be a summary of all but last human message - # - Last message should be the last human message - assert len(result.messages) == 2 + # - First message should be a summary + # - Last 3 messages should be the last 3 original messages + assert len(result.messages) == 4 assert result.messages[0].type == "system" assert "summary" in result.messages[0].content.lower() - assert result.messages[-1] == messages[-1] + assert result.messages[-3:] == messages[-3:] # Check the summary value summary_value = result.running_summary assert summary_value is not None assert summary_value.summary == "This is a summary of the conversation." - assert summary_value.summarized_message_ids == {msg.id for msg in messages[:-1]} + assert summary_value.summarized_message_ids == {msg.id for msg in messages[:-3]} # Test subsequent invocation (no new summary needed) - messages.append(factories.HumanMessageFactory.create()) result = summarize_messages( messages, running_summary=summary_value, @@ -155,13 +154,13 @@ def test_summarize_first_time(): max_tokens=6, max_summary_tokens=max_summary_tokens, ) - assert len(result.messages) == 3 + assert len(result.messages) == 4 assert result.messages[0].type == "system" assert ( result.messages[0].content == "Summary of the conversation so far: This is a summary of the conversation." 
)
-    assert result.messages[-1:] == messages[-1:]
+    assert result.messages[-3:] == messages[-3:]
 
 
 def test_max_tokens_before_summary():
@@ -250,9 +250,9 @@
         AIMessage(content="Response 2", id="4"),
         HumanMessage(content="Message 3", id="5"),
         AIMessage(content="Response 3", id="6"),
+        # these messages will be added to the result post-summarization
         HumanMessage(content="Message 4", id="7"),
         AIMessage(content="Response 4", id="8"),
-        # this message will be added to the result post-summarization
         HumanMessage(content="Latest message", id="9"),
     ]
 
@@ -272,7 +271,7 @@
 
     # Check that model was called
     assert len(model.invoke_calls) == 1
-    assert model.invoke_calls[0] == messages[1:-1] + [
+    assert model.invoke_calls[0] == messages[1:7] + [
         HumanMessage(content="Create a summary of the conversation above:")
     ]
 
     # Check that the result has the expected structure:
     # - System message should be preserved
     # - Second message should be a summary of messages 2-5
     # - Last 3 messages should be the last 3 original messages
-    assert len(result.messages) == 3
+    assert len(result.messages) == 5
     assert result.messages[0].type == "system"
     assert result.messages[1].type == "system"  # Summary message
     assert "summary" in result.messages[1].content.lower()
-    assert result.messages[-1] == messages[-1]
+    assert result.messages[2:] == messages[-3:]
 
 
 def test_approximate_token_counter():
@@ -621,14 +620,15 @@ def test_missing_message_ids():
     messages = [
         HumanMessage(content="Message 1", id="1"),
         AIMessage(content="Response"),  # Missing ID
+        HumanMessage(content="Message 2", id="1"),
     ]
     with pytest.raises(ValueError, match="Messages are required to have ID field"):
         summarize_messages(
             messages,
             running_summary=None,
             model=MockChatModel(responses=[]),
-            max_tokens=10,
-            max_summary_tokens=10,
+            max_tokens=1,
+            max_summary_tokens=1,
         )
 

From 928d0dda2eb4c22b1d46cd57d94f90a54b1797a9 Mon Sep 17 00:00:00 2001
From: Matt Bertrand
Date: Mon, 2 Jun 2025 17:31:02 -0400
Subject: [PATCH 10/11] Fix some more summarization/multi-tool call issues

---
 ai_chatbots/api.py      | 31 +++++++++++++++++++------------
 ai_chatbots/api_test.py | 18 ++++++++++--------
 2 files changed, 29 insertions(+), 20 deletions(-)

diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py
index 2e6c7085..47cde60a 100644
--- a/ai_chatbots/api.py
+++ b/ai_chatbots/api.py
@@ -23,7 +23,6 @@
     DEFAULT_EXISTING_SUMMARY_PROMPT,
     DEFAULT_FINAL_SUMMARY_PROMPT,
     DEFAULT_INITIAL_SUMMARY_PROMPT,
-    SummarizationNode,
     SummarizationResult,
     TokenCounter,
 )
@@ -134,11 +133,15 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
         max_tokens_before_summary = max_tokens
     max_tokens_to_summarize = max_tokens
 
+    # Adjust the remaining token budget to account for the summary to be added
+    max_remaining_tokens = max_tokens - max_summary_tokens
     # First handle system message if present
     if messages and isinstance(messages[0], SystemMessage):
        existing_system_message = messages[0]
         # remove the system message from the list of messages to summarize
         messages = messages[1:]
+        # adjust remaining token budget for the system msg to be re-added
+        max_remaining_tokens -= token_counter([existing_system_message])
     else:
         existing_system_message = None
 
@@ -170,13 +173,16 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
             total_summarized_messages = i + 1
             break
 
+    # We will use this to ensure that the total number of resulting tokens
+    # will fit into max_tokens window. 
+ total_n_tokens = token_counter(messages[total_summarized_messages:]) + # Go through messages to count tokens and find cutoff point n_tokens = 0 idx = max(0, total_summarized_messages - 1) # map tool call IDs to their corresponding tool messages tool_call_id_to_tool_message: dict[str, ToolMessage] = {} should_summarize = False - # Iterate through all but the most recent message for i in range(total_summarized_messages, len(messages)): message = messages[i] if message.id is None: @@ -190,20 +196,21 @@ def summarize_messages( # noqa: PLR0912, PLR0913, PLR0915, C901 # Store tool messages by their tool_call_id for later reference if isinstance(message, ToolMessage) and message.tool_call_id: tool_call_id_to_tool_message[message.tool_call_id] = message - n_tokens += token_counter([message]) - # Check if we've reached max_tokens_to_summarize and + # Check if we've reached max_tokens_to_summarize and # final message is a valid type to end summarization on - # (not a tool message or AI tool call) - if n_tokens >= max_tokens_before_summary and \ - not should_summarize and not isinstance(message, ToolMessage) and \ - ( - not isinstance(message, AIMessage) or not message.tool_calls - ): - should_summarize = True - idx = i + # (not a tool message or AI tool) + if ( + n_tokens >= max_tokens_before_summary + and total_n_tokens - n_tokens <= max_remaining_tokens + and not should_summarize + and not isinstance(message, ToolMessage) + and (not isinstance(message, AIMessage) or not message.tool_calls) + ): + should_summarize = True + idx = i if not should_summarize: messages_to_summarize = [] diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py index 86cd440f..5d41d649 100644 --- a/ai_chatbots/api_test.py +++ b/ai_chatbots/api_test.py @@ -598,21 +598,21 @@ def test_last_ai_with_tool_calls(is_tool_call): running_summary=None, model=model, token_counter=len, - max_tokens_before_summary=6, - max_tokens=6, - max_summary_tokens=3, + max_tokens_before_summary=2, + max_tokens=2, + max_summary_tokens=1, ) # Check that the AI message with tool calls was summarized together with the tool messages - assert len(result.messages) == (7 if is_tool_call else 1) + assert len(result.messages) == (7 if is_tool_call else 2) assert result.messages[0].type == ("human" if is_tool_call else "system") - assert result.messages[-1].type == ("tool" if is_tool_call else "system") + assert result.messages[-1].type == ("tool" if is_tool_call else "human") if is_tool_call: assert result.running_summary is None else: assert result.running_summary.summarized_message_ids == { - msg.id for msg in messages + msg.id for msg in messages[:-1] } @@ -653,8 +653,10 @@ def test_duplicate_message_ids(): # Second summarization with a duplicate ID messages2 = [ + AIMessage(content="Response 1", id="2"), # Duplicate ID + HumanMessage(content="Message 2", id="3"), # Duplicate ID AIMessage(content="Response 2", id="4"), - HumanMessage(content="Message 3", id="1"), # Duplicate ID + HumanMessage(content="Message 3", id="5"), ] with pytest.raises(ValueError, match="has already been summarized"): @@ -663,7 +665,7 @@ def test_duplicate_message_ids(): running_summary=result.running_summary, model=model, token_counter=len, - max_tokens=5, + max_tokens=6, max_summary_tokens=1, ) From 60332909ee3441cbc35e16741d93e54c7f6ee04a Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Tue, 3 Jun 2025 10:40:36 -0400 Subject: [PATCH 11/11] Improve summary prompt --- ai_chatbots/prompts.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/ai_chatbots/prompts.py b/ai_chatbots/prompts.py
index e54ad4df..b7877676 100644
--- a/ai_chatbots/prompts.py
+++ b/ai_chatbots/prompts.py
@@ -83,9 +83,11 @@
 # The following prompts are similar or identical to the default ones in
 # langmem.short_term.summarization
-PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above.
-If there are any tool results, include the full output of the latest one in
-the summary.
+PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above, incorporating
+the previous summary, if any.
+If there are any tool results, include the full output of the latest tool message in
+the summary. You must also retain all title and readable_id field values from all tool
+messages and any previous summaries in this new summary.
 """
 PROMPT_SUMMARY_EXISTING = """This is summary of the conversation so far: {existing_summary}
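
Reviewer note (not part of the patch series itself): the state-versus-argument
precedence for course ids, introduced in PATCH 03 and relied on by the prompt changes
in PATCH 01 and 02, is easy to invert, so a minimal self-contained sketch of the
intended behavior follows. The helper name resolve_course_id and the example course
ids are hypothetical, used here only for illustration; just the one-line precedence
expression appears in the actual diff.

    # Hypothetical sketch of the precedence rule from tools.search_content_files:
    # a course_id pinned in the agent state (syllabus bot) always wins over the
    # readable_id argument supplied by the model (recommendation bot).
    def resolve_course_id(state: dict, readable_id: str | None) -> str | None:
        # State values are lists of historical values; the latest entry is last.
        return state.get("course_id", [None])[-1] or readable_id

    # Syllabus bot: the state carries the course, so the model argument is ignored.
    assert resolve_course_id({"course_id": ["6.0001"]}, "18.06") == "6.0001"
    # Recommendation bot: no state course_id, so the model-supplied id is used.
    assert resolve_course_id({}, "18.06") == "18.06"
    # Neither present: the contentfile search runs unscoped.
    assert resolve_course_id({}, None) is None

This mirrors the line "course_id = state.get('course_id', [None])[-1] or readable_id"
added to search_content_files in PATCH 03.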