diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py
index 92b7d4e2..47cde60a 100644
--- a/ai_chatbots/api.py
+++ b/ai_chatbots/api.py
@@ -9,6 +9,7 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
+    HumanMessage,
     RemoveMessage,
     SystemMessage,
     ToolMessage,
@@ -16,12 +17,12 @@
 from langchain_core.messages.utils import count_tokens_approximately
 from langchain_core.prompt_values import ChatPromptValue
 from langchain_core.prompts import ChatPromptTemplate
+from langgraph.utils.runnable import RunnableCallable
 from langmem.short_term import RunningSummary
 from langmem.short_term.summarization import (
     DEFAULT_EXISTING_SUMMARY_PROMPT,
     DEFAULT_FINAL_SUMMARY_PROMPT,
     DEFAULT_INITIAL_SUMMARY_PROMPT,
-    SummarizationNode,
     SummarizationResult,
     TokenCounter,
 )
@@ -60,7 +61,9 @@ def get_search_tool_metadata(thread_id: str, latest_state: TypedDict) -> str:
             }
             return json.dumps(metadata)
         except json.JSONDecodeError:
-            log.exception("Error parsing tool metadata, not valid JSON")
+            log.exception(
+                "Error parsing tool metadata, not valid JSON: %s", msg_content
+            )
             return json.dumps(
                 {"error": "Error parsing tool metadata", "content": msg_content}
             )
@@ -142,7 +145,8 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
     else:
         existing_system_message = None
 
-    if not messages:
+    # Summarize only when the last message is a human message
+    if not messages or not isinstance(messages[-1], HumanMessage):
         return SummarizationResult(
             running_summary=running_summary,
             messages=(
@@ -179,7 +183,6 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
     # map tool call IDs to their corresponding tool messages
     tool_call_id_to_tool_message: dict[str, ToolMessage] = {}
     should_summarize = False
-    n_tokens_to_summarize = 0
     for i in range(total_summarized_messages, len(messages)):
         message = messages[i]
         if message.id is None:
@@ -196,14 +199,16 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
         n_tokens += token_counter([message])
 
-        # Check if we've reached max_tokens_to_summarize
-        # and the remaining messages fit within the max_remaining_tokens budget
+        # Check if we've reached max_tokens_before_summary and whether this
+        # message is a valid type to end the summary on (not a tool message
+        # or an AI message with pending tool calls)
         if (
             n_tokens >= max_tokens_before_summary
             and total_n_tokens - n_tokens <= max_remaining_tokens
             and not should_summarize
+            and not isinstance(message, ToolMessage)
+            and (not isinstance(message, AIMessage) or not message.tool_calls)
         ):
-            n_tokens_to_summarize = n_tokens
             should_summarize = True
             idx = i
@@ -212,21 +217,6 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
     else:
         messages_to_summarize = messages[total_summarized_messages : idx + 1]
 
-    # If the last message is an AI message with tool calls,
-    # include subsequent corresponding tool messages in the summary as well,
-    # to avoid issues w/ the LLM provider
-    if (
-        messages_to_summarize
-        and isinstance(messages_to_summarize[-1], AIMessage)
-        and (tool_calls := messages_to_summarize[-1].tool_calls)
-    ):
-        # Add any matching tool messages from our dictionary
-        for tool_call in tool_calls:
-            if tool_call["id"] in tool_call_id_to_tool_message:
-                tool_message = tool_call_id_to_tool_message[tool_call["id"]]
-                n_tokens_to_summarize += token_counter([tool_message])
-                messages_to_summarize.append(tool_message)
-
     if messages_to_summarize:
         if running_summary:
             summary_messages = cast(
@@ -245,6 +235,7 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
             )
         log.debug("messages to summarize: %s", messages_to_summarize)
         summary_response = model.invoke(summary_messages.messages)
+        log.debug("Summarization response: %s", summary_response.content)
         summarized_message_ids = summarized_message_ids | {
             message.id for message in messages_to_summarize
         }
@@ -293,13 +284,43 @@ def summarize_messages(  # noqa: PLR0912, PLR0913, PLR0915, C901
     )
 
 
-class CustomSummarizationNode(SummarizationNode):
+class CustomSummarizationNode(RunnableCallable):
     """
     Customized implementation of langmem.short_term.SummarizationNode.
     The original has a bug causing the most recent user question and answer
     to be lost when the summary was updated.
     """
 
+    def __init__(  # noqa: PLR0913
+        self,
+        *,
+        model: LanguageModelLike,
+        max_tokens: int,
+        max_tokens_before_summary: int | None = None,
+        max_summary_tokens: int = 256,
+        token_counter: TokenCounter = count_tokens_approximately,
+        initial_summary_prompt: ChatPromptTemplate = DEFAULT_INITIAL_SUMMARY_PROMPT,
+        existing_summary_prompt: ChatPromptTemplate = DEFAULT_EXISTING_SUMMARY_PROMPT,
+        final_prompt: ChatPromptTemplate = DEFAULT_FINAL_SUMMARY_PROMPT,
+        input_messages_key: str = "messages",
+        output_messages_key: str = "summarized_messages",
+        name: str = "summarization",
+    ) -> None:
+        """
+        Initialize the CustomSummarizationNode.
+        """
+        super().__init__(self._func, name=name, trace=False)
+        self.model = model
+        self.max_tokens = max_tokens
+        self.max_tokens_before_summary = max_tokens_before_summary
+        self.max_summary_tokens = max_summary_tokens
+        self.token_counter = token_counter
+        self.initial_summary_prompt = initial_summary_prompt
+        self.existing_summary_prompt = existing_summary_prompt
+        self.final_prompt = final_prompt
+        self.input_messages_key = input_messages_key
+        self.output_messages_key = output_messages_key
+
     def _func(self, node_input: dict[str, Any] | BaseModel) -> dict[str, Any]:
         """
         Generate a summary if needed.
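
For review purposes, the two gating rules this patch adds to summarize_messages can be read in isolation as the sketch below. The helper names are illustrative assumptions, not functions in ai_chatbots/api.py; the real checks run inline inside the token-counting loop.

    # Standalone sketch of the new gating rules (illustrative names, not app code)
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        ToolMessage,
    )


    def should_attempt_summary(messages: list[BaseMessage]) -> bool:
        # Summarization now runs only when the conversation ends on a human turn
        return bool(messages) and isinstance(messages[-1], HumanMessage)


    def is_valid_summary_cutoff(message: BaseMessage) -> bool:
        # The summary boundary may not land on a tool message, nor on an AI
        # message with pending tool calls, which some providers reject as orphaned
        if isinstance(message, ToolMessage):
            return False
        return not (isinstance(message, AIMessage) and message.tool_calls)
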
diff --git a/ai_chatbots/api_test.py b/ai_chatbots/api_test.py
index ac3f731c..5d41d649 100644
--- a/ai_chatbots/api_test.py
+++ b/ai_chatbots/api_test.py
@@ -563,7 +563,9 @@ def test_subsequent_summarization_with_new_messages_approximate_token_counter():
     assert len(updated_summary_value.summarized_message_ids) == len(messages2) - 3
 
 
-def test_last_ai_with_tool_calls():
+@pytest.mark.parametrize("is_tool_call", [True, False])
+def test_last_ai_with_tool_calls(is_tool_call):
+    """Summarization should be skipped if the last message is a tool call."""
     model = MockChatModel(responses=[AIMessage(content="Summary without tool calls.")])
 
     messages = [
@@ -584,6 +586,12 @@ def test_last_ai_with_tool_calls():
         HumanMessage(content="Message 2", id="6"),
     ]
 
+    if is_tool_call:
+        # If the last message is a tool call, we should not summarize
+        messages.append(
+            ToolMessage(content="Call tool 3", tool_call_id="3", name="tool_3", id="7")
+        )
+
     # Call the summarizer
     result = summarize_messages(
         messages,
@@ -591,30 +599,35 @@ def test_last_ai_with_tool_calls():
         model=model,
         token_counter=len,
         max_tokens_before_summary=2,
-        max_tokens=6,
+        max_tokens=2,
         max_summary_tokens=1,
     )
 
-    # Check that the AI message with tool calls was summarized together with the tool messages
-    assert len(result.messages) == 3
-    assert result.messages[0].type == "system"  # Summary
-    assert result.messages[-2:] == messages[-2:]
-    assert result.running_summary.summarized_message_ids == {
-        msg.id for msg in messages[:-2]
-    }
+    # Summarization is skipped entirely when the last message is a tool call
+    assert len(result.messages) == (7 if is_tool_call else 2)
+    assert result.messages[0].type == ("human" if is_tool_call else "system")
+    assert result.messages[-1].type == ("tool" if is_tool_call else "human")
+
+    if is_tool_call:
+        assert result.running_summary is None
+    else:
+        assert result.running_summary.summarized_message_ids == {
+            msg.id for msg in messages[:-1]
+        }
 
 
 def test_missing_message_ids():
     messages = [
         HumanMessage(content="Message 1", id="1"),
         AIMessage(content="Response"),  # Missing ID
+        HumanMessage(content="Message 2", id="3"),
     ]
 
     with pytest.raises(ValueError, match="Messages are required to have ID field"):
         summarize_messages(
             messages,
             running_summary=None,
             model=MockChatModel(responses=[]),
-            max_tokens=10,
+            max_tokens=1,
             max_summary_tokens=1,
         )
@@ -640,8 +653,10 @@ def test_duplicate_message_ids():
     # Second summarization with a duplicate ID
     messages2 = [
+        AIMessage(content="Response 1", id="2"),  # Duplicate ID
+        HumanMessage(content="Message 2", id="3"),  # Duplicate ID
         AIMessage(content="Response 2", id="4"),
-        HumanMessage(content="Message 3", id="1"),  # Duplicate ID
+        HumanMessage(content="Message 3", id="5"),
     ]
 
     with pytest.raises(ValueError, match="has already been summarized"):
@@ -650,7 +665,7 @@ def test_duplicate_message_ids():
             running_summary=result.running_summary,
             model=model,
             token_counter=len,
-            max_tokens=5,
+            max_tokens=6,
             max_summary_tokens=1,
         )
diff --git a/ai_chatbots/chatbots.py b/ai_chatbots/chatbots.py
index 4d12b536..a94edbaa 100644
--- a/ai_chatbots/chatbots.py
+++ b/ai_chatbots/chatbots.py
@@ -41,7 +41,6 @@
 from ai_chatbots.api import CustomSummarizationNode, get_search_tool_metadata
 from ai_chatbots.models import TutorBotOutput
 from ai_chatbots.prompts import PROMPT_MAPPING
-from ai_chatbots.tools import get_video_transcript_chunk, search_content_files
 from ai_chatbots.utils import get_django_cache
 
 log = logging.getLogger(__name__)
@@ -405,7 +404,7 @@ def __init__(  # noqa: PLR0913
 
     def create_tools(self) -> list[BaseTool]:
         """Create tools required by the agent"""
-        return [tools.search_courses]
+        return [tools.search_courses, tools.search_content_files]
 
     async def get_tool_metadata(self) -> str:
         """Return the metadata for the search tool"""
@@ -457,7 +456,7 @@ def __init__(  # noqa: PLR0913
 
     def create_tools(self):
         """Create tools required by the agent"""
-        return [search_content_files]
+        return [tools.search_content_files]
 
     async def get_tool_metadata(self) -> str:
         """Return the metadata for the search tool"""
@@ -687,7 +686,7 @@ def __init__(  # noqa: PLR0913
 
     def create_tools(self):
         """Create tools required for the agent"""
-        return [get_video_transcript_chunk]
+        return [tools.get_video_transcript_chunk]
 
     async def get_tool_metadata(self) -> str:
         """Return the metadata for the search tool"""
diff --git a/ai_chatbots/chatbots_test.py b/ai_chatbots/chatbots_test.py
index 5ba08e7d..cccdfd1a 100644
--- a/ai_chatbots/chatbots_test.py
+++ b/ai_chatbots/chatbots_test.py
@@ -147,6 +147,7 @@ async def test_recommendation_bot_tool(settings, mocker, search_results):
     retained_attributes = [
         "title",
         "id",
+        "readable_id",
         "description",
         "offered_by",
         "free",
diff --git a/ai_chatbots/prompts.py b/ai_chatbots/prompts.py
index b4464142..b7877676 100644
--- a/ai_chatbots/prompts.py
+++ b/ai_chatbots/prompts.py
@@ -7,17 +7,15 @@
 Your job:
 1. Understand the user's intent AND BACKGROUND based on their message.
-2. Use the available function to gather information or recommend courses.
+2. Use the available tools to gather information or recommend courses.
 3. Provide a clear, user-friendly explanation of your recommendations if search
 results are found.
 
-Run the tool to find learning resources that the user is interested in,
-and answer only based on the function search
-results.
-
-VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO
-ANSWER QUESTIONS.
+Run the "search_courses" tool to find learning resources that the user is interested
+in, and answer only based on the tool's search results. If the user asks for more
+specific information about a particular resource, use the "search_content_files" tool
+to find an answer.
 
 If no results are returned, say you could not find any relevant
 resources. Don't say you're going to try again. Ask the user if they would like to
@@ -48,10 +46,7 @@
 Expected Output: Maybe ask whether the user wants to learn how to program, or just
 use AI in their discipline - does this person want to study machine learning? More
 info needed. Then perform a relevant search and send back the best results.
-
-
-AGAIN: NEVER USE ANY INFORMATION OUTSIDE OF THE MIT SEARCH RESULTS TO
-ANSWER QUESTIONS."""
+"""
 
 
 PROMPT_SYLLABUS = """You are an assistant named Tim, helping users answer questions
@@ -64,7 +59,8 @@
 answer the user's question.
 
 Always use the tool results to answer questions, and answer only based on the tool
-output. Do not include the course id in the query parameter.
+output. Do not include the course_id in the query parameter. The tool already has
+access to the course_id.
 
 VERY IMPORTANT: NEVER USE ANY INFORMATION OUTSIDE OF THE TOOL OUTPUT TO
 ANSWER QUESTIONS. If no results are returned, say you could not find any relevant
 information."""
@@ -87,9 +83,11 @@
 # The following prompts are similar or identical to the default ones in
 # langmem.short_term.summarization
 
-PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above.
-If there are any tool results, include the full output of the latest one in
-the summary.
+PROMPT_SUMMARY_INITIAL = """Create a summary of the conversation above, incorporating +thevprevious summary if any. +If there are any tool results, include the full output of the latest tool message in +the summary. You must also retain all title and readable_id field values from all tool +messages and any previous summaries in this new summary. """ PROMPT_SUMMARY_EXISTING = """This is summary of the conversation so far: {existing_summary} diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 4fe10b54..2258c005 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -18,7 +18,7 @@ class SearchToolSchema(pydantic.BaseModel): - """Schema for searching MIT learning resources. + """Schema to search for MIT learning resources. Attributes: q: The search query string @@ -128,7 +128,6 @@ def search_courses( Query the MIT API for learning resources, and return simplified results as a JSON string """ - params = {"q": q, "limit": settings.AI_MIT_SEARCH_LIMIT} valid_params = { @@ -139,7 +138,7 @@ def search_courses( } params.update({k: v for k, v in valid_params.items() if v is not None}) search_url = state["search_url"][-1] if state else settings.AI_MIT_SEARCH_URL - log.debug("Searching MIT API at %s with params: %s", search_url, params) + log.debug("Searching MIT resources API at %s with params: %s", search_url, params) try: response = requests.get(search_url, params=params, timeout=30) response.raise_for_status() @@ -148,6 +147,7 @@ def search_courses( main_properties = [ "title", "id", + "readable_id", "description", "offered_by", "free", @@ -185,15 +185,23 @@ def search_courses( class SearchContentFilesToolSchema(pydantic.BaseModel): - """Schema for searching MIT contentfiles related to a particular course.""" + """ + Schema for searching MIT contentfiles related to a particular learning resource. + """ q: str = Field( - description=( - "Query to find course information that might answer the user's question." - ) + description=("Query to find requested information about a learning resource.") ) + + readable_id: Optional[str] = Field( + description=("The readable_id of the learning resource."), + default=None, + ) + state: Annotated[dict, InjectedState] = Field( - description="The agent state, including course_id and collection_name params" + description=( + "Agent state, which may include course_id (readable_id) and collection_name" + ) ) @@ -212,15 +220,18 @@ class VideoGPTToolSchema(pydantic.BaseModel): @tool(args_schema=SearchContentFilesToolSchema) -def search_content_files(q: str, state: Annotated[dict, InjectedState]) -> str: +def search_content_files( + q: str, state: Annotated[dict, InjectedState], readable_id: str | None = None +) -> str: """ - Query the MIT contentfile vector endpoint API, and return results as a - JSON string, along with metadata about the query parameters used. + Search for detailed information about a particular MIT learning resource. + The resource is identified by its readable_id or course_id. 
""" url = settings.AI_MIT_SYLLABUS_URL - course_id = state["course_id"][-1] - collection_name = state["collection_name"][-1] + # Use the state course_id if available, otherwise use the provided course_id + course_id = state.get("course_id", [None])[-1] or readable_id + collection_name = state.get("collection_name", [None])[-1] params = { "q": q, "resource_readable_id": course_id, @@ -228,7 +239,7 @@ def search_content_files(q: str, state: Annotated[dict, InjectedState]) -> str: } if collection_name: params["collection_name"] = collection_name - log.debug("Searching MIT API with params: %s", params) + log.debug("Searching MIT content API with params: %s", params) try: response = requests.get(url, params=params, timeout=30) response.raise_for_status()