From 1cfb0aa235dd3f8eb4fc9cb38fee54011439b450 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 19:31:20 +0000 Subject: [PATCH 01/15] Backport conversation_memory.py from PR #6909 --- .../agenthub/codeact_agent/codeact_agent.py | 38 +- openhands/memory/conversation_memory.py | 408 ++++++++++++++++++ tests/unit/test_conversation_memory.py | 105 +++++ 3 files changed, 527 insertions(+), 24 deletions(-) create mode 100644 openhands/memory/conversation_memory.py create mode 100644 tests/unit/test_conversation_memory.py diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index b636e40cb9f6..6b707aecb258 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -9,16 +9,13 @@ from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger from openhands.core.message import Message, TextContent -from openhands.core.message_utils import ( - apply_prompt_caching, - events_to_messages, -) from openhands.events.action import ( Action, AgentFinishAction, ) from openhands.llm.llm import LLM from openhands.memory.condenser import Condenser +from openhands.memory.conversation_memory import ConversationMemory from openhands.runtime.plugins import ( AgentSkillsRequirement, JupyterRequirement, @@ -90,6 +87,9 @@ def __init__( disabled_microagents=self.config.disabled_microagents, ) + # Create a ConversationMemory instance + self.conversation_memory = ConversationMemory(self.prompt_manager) + self.condenser = Condenser.from_config(self.config.condenser) logger.debug(f'Using condenser: {self.condenser}') @@ -168,13 +168,18 @@ def _get_messages(self, state: State) -> list[Message]: if not self.prompt_manager: raise Exception('Prompt Manager not instantiated.') - messages: list[Message] = self._initial_messages() + # Use conversation_memory to process events instead of calling events_to_messages directly + messages = self.conversation_memory.process_initial_messages( + with_caching=self.llm.is_caching_prompt_active() + ) # Condense the events from the state. events = self.condenser.condensed_history(state) - messages += events_to_messages( - events, + messages = self.conversation_memory.process_events( + state=state, + condensed_history=events, + initial_messages=messages, max_message_chars=self.llm.config.max_message_chars, vision_is_active=self.llm.vision_is_active(), enable_som_visual_browsing=self.config.enable_som_visual_browsing, @@ -183,26 +188,11 @@ def _get_messages(self, state: State) -> list[Message]: messages = self._enhance_messages(messages) if self.llm.is_caching_prompt_active(): - apply_prompt_caching(messages) + # Use conversation_memory to apply caching instead of calling apply_prompt_caching directly + self.conversation_memory.apply_prompt_caching(messages) return messages - def _initial_messages(self) -> list[Message]: - """Creates the initial messages (including the system prompt) for the LLM conversation.""" - assert self.prompt_manager, 'Prompt Manager not instantiated.' - - return [ - Message( - role='system', - content=[ - TextContent( - text=self.prompt_manager.get_system_message(), - cache_prompt=self.llm.is_caching_prompt_active(), - ) - ], - ) - ] - def _enhance_messages(self, messages: list[Message]) -> list[Message]: """Enhances the user message with additional context based on keywords matched. diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py new file mode 100644 index 000000000000..3653e7ba70c2 --- /dev/null +++ b/openhands/memory/conversation_memory.py @@ -0,0 +1,408 @@ +import json + +from litellm import ModelResponse + +from openhands.controller.state.state import State +from openhands.core.logger import openhands_logger as logger +from openhands.core.message import ImageContent, Message, TextContent +from openhands.core.schema import ActionType +from openhands.events.action import ( + Action, + AgentDelegateAction, + AgentFinishAction, + BrowseInteractiveAction, + BrowseURLAction, + CmdRunAction, + FileEditAction, + FileReadAction, + IPythonRunCellAction, + MessageAction, +) +from openhands.events.event import Event +from openhands.events.observation import ( + AgentCondensationObservation, + AgentDelegateObservation, + BrowserOutputObservation, + CmdOutputObservation, + FileEditObservation, + FileReadObservation, + IPythonRunCellObservation, + UserRejectObservation, +) +from openhands.events.observation.error import ErrorObservation +from openhands.events.observation.observation import Observation +from openhands.events.serialization.event import truncate_content +from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo + + +class ConversationMemory: + """Processes event history into a coherent conversation for the agent.""" + + def __init__(self, prompt_manager: PromptManager): + self.prompt_manager = prompt_manager + + def process_events( + self, + state: State, + condensed_history: list[Event], + initial_messages: list[Message], + max_message_chars: int | None = None, + vision_is_active: bool = False, + enable_som_visual_browsing: bool = False, + ) -> list[Message]: + """Process state history into a list of messages for the LLM. + + Ensures that tool call actions are processed correctly in function calling mode. + + Args: + state: The state containing the history of events to convert + condensed_history: The condensed list of events to process + initial_messages: The initial messages to include in the result + max_message_chars: The maximum number of characters in the content of an event included + in the prompt to the LLM. Larger observations are truncated. + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. + enable_som_visual_browsing: Whether to enable visual browsing for the SOM model. + """ + events = condensed_history + + logger.debug( + f'Processing {len(events)} events from a total of {len(state.history)} events' + ) + + # Process special events first (system prompts, etc.) + messages = initial_messages + + # Process regular events + pending_tool_call_action_messages: dict[str, Message] = {} + tool_call_id_to_message: dict[str, Message] = {} + + for event in events: + # create a regular message from an event + if isinstance(event, Action): + messages_to_add = self._process_action( + action=event, + pending_tool_call_action_messages=pending_tool_call_action_messages, + vision_is_active=vision_is_active, + ) + elif isinstance(event, Observation): + messages_to_add = self._process_observation( + obs=event, + tool_call_id_to_message=tool_call_id_to_message, + max_message_chars=max_message_chars, + vision_is_active=vision_is_active, + enable_som_visual_browsing=enable_som_visual_browsing, + ) + else: + raise ValueError(f'Unknown event type: {type(event)}') + + # Check pending tool call action messages and see if they are complete + _response_ids_to_remove = [] + for ( + response_id, + pending_message, + ) in pending_tool_call_action_messages.items(): + assert pending_message.tool_calls is not None, ( + 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' + f'Pending message: {pending_message}' + ) + if all( + tool_call.id in tool_call_id_to_message + for tool_call in pending_message.tool_calls + ): + # If complete: + # -- 1. Add the message that **initiated** the tool calls + messages_to_add.append(pending_message) + # -- 2. Add the tool calls **results*** + for tool_call in pending_message.tool_calls: + messages_to_add.append(tool_call_id_to_message[tool_call.id]) + tool_call_id_to_message.pop(tool_call.id) + _response_ids_to_remove.append(response_id) + # Cleanup the processed pending tool messages + for response_id in _response_ids_to_remove: + pending_tool_call_action_messages.pop(response_id) + + messages += messages_to_add + + return messages + + def process_initial_messages(self, with_caching: bool = False) -> list[Message]: + """Create the initial messages for the conversation.""" + return [ + Message( + role='system', + content=[ + TextContent( + text=self.prompt_manager.get_system_message(), + cache_prompt=with_caching, + ) + ], + ) + ] + + def _process_action( + self, + action: Action, + pending_tool_call_action_messages: dict[str, Message], + vision_is_active: bool = False, + ) -> list[Message]: + """Converts an action into a message format that can be sent to the LLM. + + This method handles different types of actions and formats them appropriately: + 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish: + - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages + - In non-function calling mode: Creates a message with the action string + 2. For MessageActions: Creates a message with the text content and optional image content + + Args: + action: The action to convert. Can be one of: + - CmdRunAction: For executing bash commands + - IPythonRunCellAction: For running IPython code + - FileEditAction: For editing files + - FileReadAction: For reading files using openhands-aci commands + - BrowseInteractiveAction: For browsing the web + - AgentFinishAction: For ending the interaction + - MessageAction: For sending messages + + pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages. + Used in function calling mode to track tool calls that are waiting for their results. + + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included + + Returns: + list[Message]: A list containing the formatted message(s) for the action. + May be empty if the action is handled as a tool call in function calling mode. + + Note: + In function calling mode, tool-based actions are stored in pending_tool_call_action_messages + rather than being returned immediately. They will be processed later when all corresponding + tool call results are available. + """ + # create a regular message from an event + if isinstance( + action, + ( + AgentDelegateAction, + IPythonRunCellAction, + FileEditAction, + FileReadAction, + BrowseInteractiveAction, + BrowseURLAction, + ), + ) or (isinstance(action, CmdRunAction) and action.source == 'agent'): + tool_metadata = action.tool_call_metadata + assert tool_metadata is not None, ( + 'Tool call metadata should NOT be None when function calling is enabled. Action: ' + + str(action) + ) + + llm_response: ModelResponse = tool_metadata.model_response + assistant_msg = getattr(llm_response.choices[0], 'message') + + # Add the LLM message (assistant) that initiated the tool calls + # (overwrites any previous message with the same response_id) + logger.debug( + f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}' + ) + pending_tool_call_action_messages[llm_response.id] = Message( + role=getattr(assistant_msg, 'role', 'assistant'), + # tool call content SHOULD BE a string + content=[TextContent(text=assistant_msg.content or '')] + if assistant_msg.content is not None + else [], + tool_calls=assistant_msg.tool_calls, + ) + return [] + elif isinstance(action, AgentFinishAction): + role = 'user' if action.source == 'user' else 'assistant' + + # when agent finishes, it has tool_metadata + # which has already been executed, and it doesn't have a response + # when the user finishes (/exit), we don't have tool_metadata + tool_metadata = action.tool_call_metadata + if tool_metadata is not None: + # take the response message from the tool call + assistant_msg = getattr( + tool_metadata.model_response.choices[0], 'message' + ) + content = assistant_msg.content or '' + + # save content if any, to thought + if action.thought: + if action.thought != content: + action.thought += '\n' + content + else: + action.thought = content + + # remove the tool call metadata + action.tool_call_metadata = None + if role not in ('user', 'system', 'assistant', 'tool'): + raise ValueError(f'Invalid role: {role}') + return [ + Message( + role=role, # type: ignore[arg-type] + content=[TextContent(text=action.thought)], + ) + ] + elif isinstance(action, MessageAction): + role = 'user' if action.source == 'user' else 'assistant' + content = [TextContent(text=action.content or '')] + if vision_is_active and action.image_urls: + content.append(ImageContent(image_urls=action.image_urls)) + if role not in ('user', 'system', 'assistant', 'tool'): + raise ValueError(f'Invalid role: {role}') + return [ + Message( + role=role, # type: ignore[arg-type] + content=content, + ) + ] + elif isinstance(action, CmdRunAction) and action.source == 'user': + content = [ + TextContent(text=f'User executed the command:\n{action.command}') + ] + return [ + Message( + role='user', # Always user for CmdRunAction + content=content, + ) + ] + return [] + + def _process_observation( + self, + obs: Observation, + tool_call_id_to_message: dict[str, Message], + max_message_chars: int | None = None, + vision_is_active: bool = False, + enable_som_visual_browsing: bool = False, + ) -> list[Message]: + """Converts an observation into a message format that can be sent to the LLM. + + This method handles different types of observations and formats them appropriately: + - CmdOutputObservation: Formats command execution results with exit codes + - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images + - FileEditObservation: Formats file editing results + - FileReadObservation: Formats file reading results from openhands-aci + - AgentDelegateObservation: Formats results from delegated agent tasks + - ErrorObservation: Formats error messages from failed actions + - UserRejectObservation: Formats user rejection messages + + In function calling mode, observations with tool_call_metadata are stored in + tool_call_id_to_message for later processing instead of being returned immediately. + + Args: + obs: The observation to convert + tool_call_id_to_message: Dictionary mapping tool call IDs to their corresponding messages (used in function calling mode) + max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM + vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included + enable_som_visual_browsing: Whether to enable visual browsing for the SOM model + + Returns: + list[Message]: A list containing the formatted message(s) for the observation. + May be empty if the observation is handled as a tool response in function calling mode. + + Raises: + ValueError: If the observation type is unknown + """ + message: Message + + if isinstance(obs, CmdOutputObservation): + # if it doesn't have tool call metadata, it was triggered by a user action + if obs.tool_call_metadata is None: + text = truncate_content( + f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}', + max_message_chars, + ) + else: + text = truncate_content(obs.to_agent_observation(), max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, IPythonRunCellObservation): + text = obs.content + # replace base64 images with a placeholder + splitted = text.split('\n') + for i, line in enumerate(splitted): + if '![image](data:image/png;base64,' in line: + splitted[i] = ( + '![image](data:image/png;base64, ...) already displayed to user' + ) + text = '\n'.join(splitted) + text = truncate_content(text, max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, FileEditObservation): + text = truncate_content(str(obs), max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, FileReadObservation): + message = Message( + role='user', content=[TextContent(text=obs.content)] + ) # Content is already truncated by openhands-aci + elif isinstance(obs, BrowserOutputObservation): + text = obs.get_agent_obs_text() + if ( + obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE + and obs.set_of_marks is not None + and len(obs.set_of_marks) > 0 + and enable_som_visual_browsing + and vision_is_active + ): + text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n' + message = Message( + role='user', + content=[ + TextContent(text=text), + ImageContent(image_urls=[obs.set_of_marks]), + ], + ) + else: + message = Message( + role='user', + content=[TextContent(text=text)], + ) + elif isinstance(obs, AgentDelegateObservation): + text = truncate_content( + obs.outputs['content'] if 'content' in obs.outputs else '', + max_message_chars, + ) + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, ErrorObservation): + text = truncate_content(obs.content, max_message_chars) + text += '\n[Error occurred in processing last action]' + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, UserRejectObservation): + text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) + text += '\n[Last action has been rejected by the user]' + message = Message(role='user', content=[TextContent(text=text)]) + elif isinstance(obs, AgentCondensationObservation): + text = truncate_content(obs.content, max_message_chars) + message = Message(role='user', content=[TextContent(text=text)]) + else: + # If an observation message is not returned, it will cause an error + # when the LLM tries to return the next message + raise ValueError(f'Unknown observation type: {type(obs)}') + + # Update the message as tool response properly + if (tool_call_metadata := getattr(obs, 'tool_call_metadata', None)) is not None: + tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( + role='tool', + content=message.content, + tool_call_id=tool_call_metadata.tool_call_id, + name=tool_call_metadata.function_name, + ) + # No need to return the observation message + # because it will be added by get_action_message when all the corresponding + # tool calls in the SAME request are processed + return [] + + return [message] + + def apply_prompt_caching(self, messages: list[Message]) -> None: + """Applies caching breakpoints to the messages. + + For new Anthropic API, we only need to mark the last user or tool message as cacheable. + """ + # NOTE: this is only needed for anthropic + for message in reversed(messages): + if message.role in ('user', 'tool'): + message.content[ + -1 + ].cache_prompt = True # Last item inside the message content + break \ No newline at end of file diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py new file mode 100644 index 000000000000..9414fabfdfc9 --- /dev/null +++ b/tests/unit/test_conversation_memory.py @@ -0,0 +1,105 @@ +import unittest +from unittest.mock import MagicMock, patch + +from openhands.controller.state.state import State +from openhands.core.message import Message, TextContent +from openhands.events.action import MessageAction +from openhands.events.observation import CmdOutputObservation +from openhands.memory.conversation_memory import ConversationMemory +from openhands.utils.prompt import PromptManager + + +class TestConversationMemory(unittest.TestCase): + def setUp(self): + self.prompt_manager = MagicMock(spec=PromptManager) + self.prompt_manager.get_system_message.return_value = "System message" + self.conversation_memory = ConversationMemory(self.prompt_manager) + self.state = MagicMock(spec=State) + self.state.history = [] + + def test_process_initial_messages(self): + messages = self.conversation_memory.process_initial_messages(with_caching=False) + self.assertEqual(len(messages), 1) + self.assertEqual(messages[0].role, "system") + self.assertEqual(messages[0].content[0].text, "System message") + self.assertEqual(messages[0].content[0].cache_prompt, False) + + messages = self.conversation_memory.process_initial_messages(with_caching=True) + self.assertEqual(messages[0].content[0].cache_prompt, True) + + def test_process_events_with_message_action(self): + user_message = MessageAction(content="Hello", source="user") + assistant_message = MessageAction(content="Hi there", source="assistant") + + initial_messages = [ + Message( + role="system", + content=[TextContent(text="System message")] + ) + ] + + messages = self.conversation_memory.process_events( + state=self.state, + condensed_history=[user_message, assistant_message], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False + ) + + self.assertEqual(len(messages), 3) + self.assertEqual(messages[0].role, "system") + self.assertEqual(messages[1].role, "user") + self.assertEqual(messages[1].content[0].text, "Hello") + self.assertEqual(messages[2].role, "assistant") + self.assertEqual(messages[2].content[0].text, "Hi there") + + def test_process_events_with_observation(self): + user_message = MessageAction(content="Hello", source="user") + cmd_output = CmdOutputObservation( + command="ls", + exit_code=0, + output="file1.txt\nfile2.txt", + tool_call_metadata=None + ) + + initial_messages = [ + Message( + role="system", + content=[TextContent(text="System message")] + ) + ] + + messages = self.conversation_memory.process_events( + state=self.state, + condensed_history=[user_message, cmd_output], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False + ) + + self.assertEqual(len(messages), 3) + self.assertEqual(messages[0].role, "system") + self.assertEqual(messages[1].role, "user") + self.assertEqual(messages[2].role, "user") + self.assertIn("Observed result of command executed by user", messages[2].content[0].text) + self.assertIn("file1.txt", messages[2].content[0].text) + + def test_apply_prompt_caching(self): + messages = [ + Message(role="system", content=[TextContent(text="System message")]), + Message(role="user", content=[TextContent(text="User message")]), + Message(role="assistant", content=[TextContent(text="Assistant message")]), + Message(role="user", content=[TextContent(text="Another user message")]), + ] + + self.conversation_memory.apply_prompt_caching(messages) + + # Only the last user message should have cache_prompt=True + self.assertFalse(messages[0].content[0].cache_prompt) + self.assertFalse(messages[1].content[0].cache_prompt) + self.assertFalse(messages[2].content[0].cache_prompt) + self.assertTrue(messages[3].content[0].cache_prompt) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 10098d9556c5c12e25d494e0c95e369232a222fb Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 19:32:23 +0000 Subject: [PATCH 02/15] Fix formatting issues --- .../agenthub/codeact_agent/codeact_agent.py | 2 +- openhands/memory/conversation_memory.py | 6 +- openhands/security/invariant/analyzer.py | 2 +- openhands/security/invariant/nodes.py | 7 +- pyproject.toml | 2 + tests/unit/test_conversation_memory.py | 84 +++++++++---------- 6 files changed, 51 insertions(+), 52 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 6b707aecb258..686a6f12b324 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -8,7 +8,7 @@ from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.core.logger import openhands_logger as logger -from openhands.core.message import Message, TextContent +from openhands.core.message import Message from openhands.events.action import ( Action, AgentFinishAction, diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 3653e7ba70c2..7fbd23c89e9f 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -1,5 +1,3 @@ -import json - from litellm import ModelResponse from openhands.controller.state.state import State @@ -32,7 +30,7 @@ from openhands.events.observation.error import ErrorObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import truncate_content -from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo +from openhands.utils.prompt import PromptManager class ConversationMemory: @@ -405,4 +403,4 @@ def apply_prompt_caching(self, messages: list[Message]) -> None: message.content[ -1 ].cache_prompt = True # Last item inside the message content - break \ No newline at end of file + break diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 540a9341b822..25afcbec5133 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -310,7 +310,7 @@ async def security_risk(self, event: Action) -> ActionSecurityRisk: check_result = self.monitor.check(self.input, input) self.input.extend(input) risk = ActionSecurityRisk.UNKNOWN - + if isinstance(check_result, tuple): result, err = check_result if err: diff --git a/openhands/security/invariant/nodes.py b/openhands/security/invariant/nodes.py index c3d7b9713bea..ac294622fb8f 100644 --- a/openhands/security/invariant/nodes.py +++ b/openhands/security/invariant/nodes.py @@ -1,4 +1,5 @@ -from typing import Any, Iterable, Tuple +from typing import Any, Iterable + from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass @@ -31,7 +32,9 @@ class Message(Event): content: str | None tool_calls: list[ToolCall] | None = None - def __rich_repr__(self) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]: + def __rich_repr__( + self, + ) -> Iterable[Any | tuple[Any] | tuple[str, Any] | tuple[str, Any, Any]]: # Print on separate line yield 'role', self.role yield 'content', self.content diff --git a/pyproject.toml b/pyproject.toml index 0b79dca0994a..b41f1571480e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] + [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" @@ -137,6 +138,7 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" + [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 9414fabfdfc9..cfdd368b2a96 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from openhands.controller.state.state import State from openhands.core.message import Message, TextContent @@ -12,7 +12,7 @@ class TestConversationMemory(unittest.TestCase): def setUp(self): self.prompt_manager = MagicMock(spec=PromptManager) - self.prompt_manager.get_system_message.return_value = "System message" + self.prompt_manager.get_system_message.return_value = 'System message' self.conversation_memory = ConversationMemory(self.prompt_manager) self.state = MagicMock(spec=State) self.state.history = [] @@ -20,80 +20,76 @@ def setUp(self): def test_process_initial_messages(self): messages = self.conversation_memory.process_initial_messages(with_caching=False) self.assertEqual(len(messages), 1) - self.assertEqual(messages[0].role, "system") - self.assertEqual(messages[0].content[0].text, "System message") + self.assertEqual(messages[0].role, 'system') + self.assertEqual(messages[0].content[0].text, 'System message') self.assertEqual(messages[0].content[0].cache_prompt, False) messages = self.conversation_memory.process_initial_messages(with_caching=True) self.assertEqual(messages[0].content[0].cache_prompt, True) def test_process_events_with_message_action(self): - user_message = MessageAction(content="Hello", source="user") - assistant_message = MessageAction(content="Hi there", source="assistant") - + user_message = MessageAction(content='Hello', source='user') + assistant_message = MessageAction(content='Hi there', source='assistant') + initial_messages = [ - Message( - role="system", - content=[TextContent(text="System message")] - ) + Message(role='system', content=[TextContent(text='System message')]) ] - + messages = self.conversation_memory.process_events( state=self.state, condensed_history=[user_message, assistant_message], initial_messages=initial_messages, max_message_chars=None, - vision_is_active=False + vision_is_active=False, ) - + self.assertEqual(len(messages), 3) - self.assertEqual(messages[0].role, "system") - self.assertEqual(messages[1].role, "user") - self.assertEqual(messages[1].content[0].text, "Hello") - self.assertEqual(messages[2].role, "assistant") - self.assertEqual(messages[2].content[0].text, "Hi there") + self.assertEqual(messages[0].role, 'system') + self.assertEqual(messages[1].role, 'user') + self.assertEqual(messages[1].content[0].text, 'Hello') + self.assertEqual(messages[2].role, 'assistant') + self.assertEqual(messages[2].content[0].text, 'Hi there') def test_process_events_with_observation(self): - user_message = MessageAction(content="Hello", source="user") + user_message = MessageAction(content='Hello', source='user') cmd_output = CmdOutputObservation( - command="ls", + command='ls', exit_code=0, - output="file1.txt\nfile2.txt", - tool_call_metadata=None + output='file1.txt\nfile2.txt', + tool_call_metadata=None, ) - + initial_messages = [ - Message( - role="system", - content=[TextContent(text="System message")] - ) + Message(role='system', content=[TextContent(text='System message')]) ] - + messages = self.conversation_memory.process_events( state=self.state, condensed_history=[user_message, cmd_output], initial_messages=initial_messages, max_message_chars=None, - vision_is_active=False + vision_is_active=False, ) - + self.assertEqual(len(messages), 3) - self.assertEqual(messages[0].role, "system") - self.assertEqual(messages[1].role, "user") - self.assertEqual(messages[2].role, "user") - self.assertIn("Observed result of command executed by user", messages[2].content[0].text) - self.assertIn("file1.txt", messages[2].content[0].text) + self.assertEqual(messages[0].role, 'system') + self.assertEqual(messages[1].role, 'user') + self.assertEqual(messages[2].role, 'user') + self.assertIn( + 'Observed result of command executed by user', messages[2].content[0].text + ) + self.assertIn('file1.txt', messages[2].content[0].text) def test_apply_prompt_caching(self): messages = [ - Message(role="system", content=[TextContent(text="System message")]), - Message(role="user", content=[TextContent(text="User message")]), - Message(role="assistant", content=[TextContent(text="Assistant message")]), - Message(role="user", content=[TextContent(text="Another user message")]), + Message(role='system', content=[TextContent(text='System message')]), + Message(role='user', content=[TextContent(text='User message')]), + Message(role='assistant', content=[TextContent(text='Assistant message')]), + Message(role='user', content=[TextContent(text='Another user message')]), ] - + self.conversation_memory.apply_prompt_caching(messages) - + # Only the last user message should have cache_prompt=True self.assertFalse(messages[0].content[0].cache_prompt) self.assertFalse(messages[1].content[0].cache_prompt) @@ -101,5 +97,5 @@ def test_apply_prompt_caching(self): self.assertTrue(messages[3].content[0].cache_prompt) -if __name__ == "__main__": - unittest.main() \ No newline at end of file +if __name__ == '__main__': + unittest.main() From d90bc15a8527d6e4d6900ad8ade0b5482da90057 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 19:36:24 +0000 Subject: [PATCH 03/15] Convert test_conversation_memory.py to use pytest instead of unittest --- tests/unit/test_conversation_memory.py | 186 +++++++++++++------------ 1 file changed, 95 insertions(+), 91 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index cfdd368b2a96..5decb141f7f5 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,6 +1,7 @@ -import unittest from unittest.mock import MagicMock +import pytest + from openhands.controller.state.state import State from openhands.core.message import Message, TextContent from openhands.events.action import MessageAction @@ -9,93 +10,96 @@ from openhands.utils.prompt import PromptManager -class TestConversationMemory(unittest.TestCase): - def setUp(self): - self.prompt_manager = MagicMock(spec=PromptManager) - self.prompt_manager.get_system_message.return_value = 'System message' - self.conversation_memory = ConversationMemory(self.prompt_manager) - self.state = MagicMock(spec=State) - self.state.history = [] - - def test_process_initial_messages(self): - messages = self.conversation_memory.process_initial_messages(with_caching=False) - self.assertEqual(len(messages), 1) - self.assertEqual(messages[0].role, 'system') - self.assertEqual(messages[0].content[0].text, 'System message') - self.assertEqual(messages[0].content[0].cache_prompt, False) - - messages = self.conversation_memory.process_initial_messages(with_caching=True) - self.assertEqual(messages[0].content[0].cache_prompt, True) - - def test_process_events_with_message_action(self): - user_message = MessageAction(content='Hello', source='user') - assistant_message = MessageAction(content='Hi there', source='assistant') - - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - - messages = self.conversation_memory.process_events( - state=self.state, - condensed_history=[user_message, assistant_message], - initial_messages=initial_messages, - max_message_chars=None, - vision_is_active=False, - ) - - self.assertEqual(len(messages), 3) - self.assertEqual(messages[0].role, 'system') - self.assertEqual(messages[1].role, 'user') - self.assertEqual(messages[1].content[0].text, 'Hello') - self.assertEqual(messages[2].role, 'assistant') - self.assertEqual(messages[2].content[0].text, 'Hi there') - - def test_process_events_with_observation(self): - user_message = MessageAction(content='Hello', source='user') - cmd_output = CmdOutputObservation( - command='ls', - exit_code=0, - output='file1.txt\nfile2.txt', - tool_call_metadata=None, - ) - - initial_messages = [ - Message(role='system', content=[TextContent(text='System message')]) - ] - - messages = self.conversation_memory.process_events( - state=self.state, - condensed_history=[user_message, cmd_output], - initial_messages=initial_messages, - max_message_chars=None, - vision_is_active=False, - ) - - self.assertEqual(len(messages), 3) - self.assertEqual(messages[0].role, 'system') - self.assertEqual(messages[1].role, 'user') - self.assertEqual(messages[2].role, 'user') - self.assertIn( - 'Observed result of command executed by user', messages[2].content[0].text - ) - self.assertIn('file1.txt', messages[2].content[0].text) - - def test_apply_prompt_caching(self): - messages = [ - Message(role='system', content=[TextContent(text='System message')]), - Message(role='user', content=[TextContent(text='User message')]), - Message(role='assistant', content=[TextContent(text='Assistant message')]), - Message(role='user', content=[TextContent(text='Another user message')]), - ] - - self.conversation_memory.apply_prompt_caching(messages) - - # Only the last user message should have cache_prompt=True - self.assertFalse(messages[0].content[0].cache_prompt) - self.assertFalse(messages[1].content[0].cache_prompt) - self.assertFalse(messages[2].content[0].cache_prompt) - self.assertTrue(messages[3].content[0].cache_prompt) - - -if __name__ == '__main__': - unittest.main() +@pytest.fixture +def conversation_memory(): + prompt_manager = MagicMock(spec=PromptManager) + prompt_manager.get_system_message.return_value = 'System message' + return ConversationMemory(prompt_manager) + + +@pytest.fixture +def mock_state(): + state = MagicMock(spec=State) + state.history = [] + return state + + +def test_process_initial_messages(conversation_memory): + messages = conversation_memory.process_initial_messages(with_caching=False) + assert len(messages) == 1 + assert messages[0].role == 'system' + assert messages[0].content[0].text == 'System message' + assert messages[0].content[0].cache_prompt is False + + messages = conversation_memory.process_initial_messages(with_caching=True) + assert messages[0].content[0].cache_prompt is True + + +def test_process_events_with_message_action(conversation_memory, mock_state): + user_message = MessageAction(content='Hello', source='user') + assistant_message = MessageAction(content='Hi there', source='assistant') + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[user_message, assistant_message], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 3 + assert messages[0].role == 'system' + assert messages[1].role == 'user' + assert messages[1].content[0].text == 'Hello' + assert messages[2].role == 'assistant' + assert messages[2].content[0].text == 'Hi there' + + +def test_process_events_with_observation(conversation_memory, mock_state): + user_message = MessageAction(content='Hello', source='user') + cmd_output = CmdOutputObservation( + command='ls', + exit_code=0, + output='file1.txt\nfile2.txt', + tool_call_metadata=None, + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[user_message, cmd_output], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 3 + assert messages[0].role == 'system' + assert messages[1].role == 'user' + assert messages[2].role == 'user' + assert 'Observed result of command executed by user' in messages[2].content[0].text + assert 'file1.txt' in messages[2].content[0].text + + +def test_apply_prompt_caching(conversation_memory): + messages = [ + Message(role='system', content=[TextContent(text='System message')]), + Message(role='user', content=[TextContent(text='User message')]), + Message(role='assistant', content=[TextContent(text='Assistant message')]), + Message(role='user', content=[TextContent(text='Another user message')]), + ] + + conversation_memory.apply_prompt_caching(messages) + + # Only the last user message should have cache_prompt=True + assert messages[0].content[0].cache_prompt is False + assert messages[1].content[0].cache_prompt is False + assert messages[2].content[0].cache_prompt is False + assert messages[3].content[0].cache_prompt is True From 9540757d98e22c913aafde989b9aaecd468c58f8 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 19:48:36 +0000 Subject: [PATCH 04/15] Expand test_conversation_memory.py with tests ported from test_message_utils.py --- tests/unit/test_conversation_memory.py | 394 ++++++++++++++++++++++++- 1 file changed, 378 insertions(+), 16 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 5decb141f7f5..dc3382bcaaa6 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,11 +1,26 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest from openhands.controller.state.state import State -from openhands.core.message import Message, TextContent -from openhands.events.action import MessageAction +from openhands.core.message import ImageContent, Message, TextContent +from openhands.events.action import ( + AgentFinishAction, + CmdRunAction, + MessageAction, +) +from openhands.events.event import EventSource, FileEditSource, FileReadSource from openhands.events.observation import CmdOutputObservation +from openhands.events.observation.browse import BrowserOutputObservation +from openhands.events.observation.commands import ( + CmdOutputMetadata, + IPythonRunCellObservation, +) +from openhands.events.observation.delegate import AgentDelegateObservation +from openhands.events.observation.error import ErrorObservation +from openhands.events.observation.files import FileEditObservation, FileReadObservation +from openhands.events.observation.reject import UserRejectObservation +from openhands.events.tool import ToolCallMetadata from openhands.memory.conversation_memory import ConversationMemory from openhands.utils.prompt import PromptManager @@ -59,13 +74,272 @@ def test_process_events_with_message_action(conversation_memory, mock_state): assert messages[2].content[0].text == 'Hi there' -def test_process_events_with_observation(conversation_memory, mock_state): - user_message = MessageAction(content='Hello', source='user') - cmd_output = CmdOutputObservation( - command='ls', +def test_process_events_with_cmd_output_observation(conversation_memory, mock_state): + obs = CmdOutputObservation( + command='echo hello', + content='Command output', + metadata=CmdOutputMetadata( + exit_code=0, + prefix='[THIS IS PREFIX]', + suffix='[THIS IS SUFFIX]', + ), + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Observed result of command executed by user:' in result.content[0].text + assert '[Command finished with exit code 0]' in result.content[0].text + assert '[THIS IS PREFIX]' in result.content[0].text + assert '[THIS IS SUFFIX]' in result.content[0].text + + +def test_process_events_with_ipython_run_cell_observation( + conversation_memory, mock_state +): + obs = IPythonRunCellObservation( + code='plt.plot()', + content='IPython output\n![image](data:image/png;base64,ABC123)', + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'IPython output' in result.content[0].text + assert ( + '![image](data:image/png;base64, ...) already displayed to user' + in result.content[0].text + ) + assert 'ABC123' not in result.content[0].text + + +def test_process_events_with_agent_delegate_observation( + conversation_memory, mock_state +): + obs = AgentDelegateObservation( + content='Content', outputs={'content': 'Delegated agent output'} + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Delegated agent output' in result.content[0].text + + +def test_process_events_with_error_observation(conversation_memory, mock_state): + obs = ErrorObservation('Error message') + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Error message' in result.content[0].text + assert 'Error occurred in processing last action' in result.content[0].text + + +def test_process_events_with_unknown_observation(conversation_memory, mock_state): + obs = Mock() + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + with pytest.raises(ValueError, match='Unknown observation type'): + conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + +def test_process_events_with_file_edit_observation(conversation_memory, mock_state): + obs = FileEditObservation( + path='/test/file.txt', + prev_exist=True, + old_content='old content', + new_content='new content', + content='diff content', + impl_source=FileEditSource.LLM_BASED_EDIT, + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert '[Existing file /test/file.txt is edited with' in result.content[0].text + + +def test_process_events_with_file_read_observation(conversation_memory, mock_state): + obs = FileReadObservation( + path='/test/file.txt', + content='File content', + impl_source=FileReadSource.DEFAULT, + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == 'File content' + + +def test_process_events_with_browser_output_observation( + conversation_memory, mock_state +): + obs = BrowserOutputObservation( + url='http://example.com', + trigger_by_action='browse', + screenshot='', + content='Page loaded', + error=False, + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert '[Current URL: http://example.com]' in result.content[0].text + + +def test_process_events_with_user_reject_observation(conversation_memory, mock_state): + obs = UserRejectObservation('Action rejected') + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[obs], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Action rejected' in result.content[0].text + assert '[Last action has been rejected by the user]' in result.content[0].text + + +def test_process_events_with_function_calling_observation( + conversation_memory, mock_state +): + mock_response = { + 'id': 'mock_id', + 'total_calls_in_response': 1, + 'choices': [{'message': {'content': 'Task completed'}}], + } + obs = CmdOutputObservation( + command='echo hello', + content='Command output', + command_id=1, exit_code=0, - output='file1.txt\nfile2.txt', - tool_call_metadata=None, + ) + obs.tool_call_metadata = ToolCallMetadata( + tool_call_id='123', + function_name='execute_bash', + model_response=mock_response, + total_calls_in_response=1, ) initial_messages = [ @@ -74,18 +348,106 @@ def test_process_events_with_observation(conversation_memory, mock_state): messages = conversation_memory.process_events( state=mock_state, - condensed_history=[user_message, cmd_output], + condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, vision_is_active=False, ) - assert len(messages) == 3 - assert messages[0].role == 'system' - assert messages[1].role == 'user' - assert messages[2].role == 'user' - assert 'Observed result of command executed by user' in messages[2].content[0].text - assert 'file1.txt' in messages[2].content[0].text + # No direct message when using function calling + assert len(messages) == 1 # Only the initial system message + + +def test_process_events_with_message_action_with_image(conversation_memory, mock_state): + action = MessageAction( + content='Message with image', + image_urls=['http://example.com/image.jpg'], + ) + action._source = EventSource.AGENT + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[action], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=True, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'assistant' + assert len(result.content) == 2 + assert isinstance(result.content[0], TextContent) + assert isinstance(result.content[1], ImageContent) + assert result.content[0].text == 'Message with image' + assert result.content[1].image_urls == ['http://example.com/image.jpg'] + + +def test_process_events_with_user_cmd_action(conversation_memory, mock_state): + action = CmdRunAction(command='ls -l') + action._source = EventSource.USER + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[action], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'user' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'User executed the command' in result.content[0].text + assert 'ls -l' in result.content[0].text + + +def test_process_events_with_agent_finish_action_with_tool_metadata( + conversation_memory, mock_state +): + mock_response = { + 'id': 'mock_id', + 'total_calls_in_response': 1, + 'choices': [{'message': {'content': 'Task completed'}}], + } + + action = AgentFinishAction(thought='Initial thought') + action._source = EventSource.AGENT + action.tool_call_metadata = ToolCallMetadata( + tool_call_id='123', + function_name='finish', + model_response=mock_response, + total_calls_in_response=1, + ) + + initial_messages = [ + Message(role='system', content=[TextContent(text='System message')]) + ] + + messages = conversation_memory.process_events( + state=mock_state, + condensed_history=[action], + initial_messages=initial_messages, + max_message_chars=None, + vision_is_active=False, + ) + + assert len(messages) == 2 + result = messages[1] + assert result.role == 'assistant' + assert len(result.content) == 1 + assert isinstance(result.content[0], TextContent) + assert 'Initial thought\nTask completed' in result.content[0].text def test_apply_prompt_caching(conversation_memory): From a574aa5e5a475e6029a34355206cc467dc8d5369 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 19:59:43 +0000 Subject: [PATCH 05/15] Clean up message_utils.py to keep only token usage functions --- openhands/core/message_utils.py | 367 ------------------------------- tests/unit/test_message_utils.py | 272 +---------------------- 2 files changed, 1 insertion(+), 638 deletions(-) diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py index edb8902f2c4d..4d0651250a81 100644 --- a/openhands/core/message_utils.py +++ b/openhands/core/message_utils.py @@ -1,374 +1,7 @@ -from litellm import ModelResponse - -from openhands.core.logger import openhands_logger as logger -from openhands.core.message import ImageContent, Message, TextContent -from openhands.core.schema import ActionType -from openhands.events.action import ( - Action, - AgentDelegateAction, - AgentFinishAction, - BrowseInteractiveAction, - BrowseURLAction, - CmdRunAction, - FileEditAction, - FileReadAction, - IPythonRunCellAction, - MessageAction, -) from openhands.events.event import Event -from openhands.events.observation import ( - AgentCondensationObservation, - AgentDelegateObservation, - BrowserOutputObservation, - CmdOutputObservation, - FileEditObservation, - FileReadObservation, - IPythonRunCellObservation, - UserRejectObservation, -) -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.observation import Observation -from openhands.events.serialization.event import truncate_content from openhands.llm.metrics import Metrics, TokenUsage -def events_to_messages( - events: list[Event], - max_message_chars: int | None = None, - vision_is_active: bool = False, - enable_som_visual_browsing: bool = False, -) -> list[Message]: - """Converts a list of events into a list of messages that can be sent to the LLM. - - Ensures that tool call actions are processed correctly in function calling mode. - - Args: - events: A list of events to convert. Each event can be an Action or Observation. - max_message_chars: The maximum number of characters in the content of an event included in the prompt to the LLM. - Larger observations are truncated. - vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included. - enable_som_visual_browsing: Whether to enable visual browsing for the SOM model. - """ - messages = [] - - pending_tool_call_action_messages: dict[str, Message] = {} - tool_call_id_to_message: dict[str, Message] = {} - - for event in events: - # create a regular message from an event - if isinstance(event, Action): - messages_to_add = get_action_message( - action=event, - pending_tool_call_action_messages=pending_tool_call_action_messages, - vision_is_active=vision_is_active, - ) - elif isinstance(event, Observation): - messages_to_add = get_observation_message( - obs=event, - tool_call_id_to_message=tool_call_id_to_message, - max_message_chars=max_message_chars, - vision_is_active=vision_is_active, - enable_som_visual_browsing=enable_som_visual_browsing, - ) - else: - raise ValueError(f'Unknown event type: {type(event)}') - - # Check pending tool call action messages and see if they are complete - _response_ids_to_remove = [] - for ( - response_id, - pending_message, - ) in pending_tool_call_action_messages.items(): - assert pending_message.tool_calls is not None, ( - 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' - f'Pending message: {pending_message}' - ) - if all( - tool_call.id in tool_call_id_to_message - for tool_call in pending_message.tool_calls - ): - # If complete: - # -- 1. Add the message that **initiated** the tool calls - messages_to_add.append(pending_message) - # -- 2. Add the tool calls **results*** - for tool_call in pending_message.tool_calls: - messages_to_add.append(tool_call_id_to_message[tool_call.id]) - tool_call_id_to_message.pop(tool_call.id) - _response_ids_to_remove.append(response_id) - # Cleanup the processed pending tool messages - for response_id in _response_ids_to_remove: - pending_tool_call_action_messages.pop(response_id) - - messages += messages_to_add - - return messages - - -def get_action_message( - action: Action, - pending_tool_call_action_messages: dict[str, Message], - vision_is_active: bool = False, -) -> list[Message]: - """Converts an action into a message format that can be sent to the LLM. - - This method handles different types of actions and formats them appropriately: - 1. For tool-based actions (AgentDelegate, CmdRun, IPythonRunCell, FileEdit) and agent-sourced AgentFinish: - - In function calling mode: Stores the LLM's response in pending_tool_call_action_messages - - In non-function calling mode: Creates a message with the action string - 2. For MessageActions: Creates a message with the text content and optional image content - - Args: - action: The action to convert. Can be one of: - - CmdRunAction: For executing bash commands - - IPythonRunCellAction: For running IPython code - - FileEditAction: For editing files - - FileReadAction: For reading files using openhands-aci commands - - BrowseInteractiveAction: For browsing the web - - AgentFinishAction: For ending the interaction - - MessageAction: For sending messages - - pending_tool_call_action_messages: Dictionary mapping response IDs to their corresponding messages. - Used in function calling mode to track tool calls that are waiting for their results. - - vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included - - Returns: - list[Message]: A list containing the formatted message(s) for the action. - May be empty if the action is handled as a tool call in function calling mode. - - Note: - In function calling mode, tool-based actions are stored in pending_tool_call_action_messages - rather than being returned immediately. They will be processed later when all corresponding - tool call results are available. - """ - # create a regular message from an event - if isinstance( - action, - ( - AgentDelegateAction, - IPythonRunCellAction, - FileEditAction, - FileReadAction, - BrowseInteractiveAction, - BrowseURLAction, - ), - ) or (isinstance(action, CmdRunAction) and action.source == 'agent'): - tool_metadata = action.tool_call_metadata - assert tool_metadata is not None, ( - 'Tool call metadata should NOT be None when function calling is enabled. Action: ' - + str(action) - ) - - llm_response: ModelResponse = tool_metadata.model_response - assistant_msg = getattr(llm_response.choices[0], 'message') - - # Add the LLM message (assistant) that initiated the tool calls - # (overwrites any previous message with the same response_id) - logger.debug( - f'Tool calls type: {type(assistant_msg.tool_calls)}, value: {assistant_msg.tool_calls}' - ) - pending_tool_call_action_messages[llm_response.id] = Message( - role=getattr(assistant_msg, 'role', 'assistant'), - # tool call content SHOULD BE a string - content=[TextContent(text=assistant_msg.content or '')] - if assistant_msg.content is not None - else [], - tool_calls=assistant_msg.tool_calls, - ) - return [] - elif isinstance(action, AgentFinishAction): - role = 'user' if action.source == 'user' else 'assistant' - - # when agent finishes, it has tool_metadata - # which has already been executed, and it doesn't have a response - # when the user finishes (/exit), we don't have tool_metadata - tool_metadata = action.tool_call_metadata - if tool_metadata is not None: - # take the response message from the tool call - assistant_msg = getattr(tool_metadata.model_response.choices[0], 'message') - content = assistant_msg.content or '' - - # save content if any, to thought - if action.thought: - if action.thought != content: - action.thought += '\n' + content - else: - action.thought = content - - # remove the tool call metadata - action.tool_call_metadata = None - if role not in ('user', 'system', 'assistant', 'tool'): - raise ValueError(f'Invalid role: {role}') - return [ - Message( - role=role, # type: ignore[arg-type] - content=[TextContent(text=action.thought)], - ) - ] - elif isinstance(action, MessageAction): - role = 'user' if action.source == 'user' else 'assistant' - content = [TextContent(text=action.content or '')] - if vision_is_active and action.image_urls: - content.append(ImageContent(image_urls=action.image_urls)) - if role not in ('user', 'system', 'assistant', 'tool'): - raise ValueError(f'Invalid role: {role}') - return [ - Message( - role=role, # type: ignore[arg-type] - content=content, - ) - ] - elif isinstance(action, CmdRunAction) and action.source == 'user': - content = [TextContent(text=f'User executed the command:\n{action.command}')] - return [ - Message( - role='user', # Always user for CmdRunAction - content=content, - ) - ] - return [] - - -def get_observation_message( - obs: Observation, - tool_call_id_to_message: dict[str, Message], - max_message_chars: int | None = None, - vision_is_active: bool = False, - enable_som_visual_browsing: bool = False, -) -> list[Message]: - """Converts an observation into a message format that can be sent to the LLM. - - This method handles different types of observations and formats them appropriately: - - CmdOutputObservation: Formats command execution results with exit codes - - IPythonRunCellObservation: Formats IPython cell execution results, replacing base64 images - - FileEditObservation: Formats file editing results - - FileReadObservation: Formats file reading results from openhands-aci - - AgentDelegateObservation: Formats results from delegated agent tasks - - ErrorObservation: Formats error messages from failed actions - - UserRejectObservation: Formats user rejection messages - - In function calling mode, observations with tool_call_metadata are stored in - tool_call_id_to_message for later processing instead of being returned immediately. - - Args: - obs: The observation to convert - tool_call_id_to_message: Dictionary mapping tool call IDs to their corresponding messages (used in function calling mode) - max_message_chars: The maximum number of characters in the content of an observation included in the prompt to the LLM - vision_is_active: Whether vision is active in the LLM. If True, image URLs will be included - enable_som_visual_browsing: Whether to enable visual browsing for the SOM model - - Returns: - list[Message]: A list containing the formatted message(s) for the observation. - May be empty if the observation is handled as a tool response in function calling mode. - - Raises: - ValueError: If the observation type is unknown - """ - message: Message - - if isinstance(obs, CmdOutputObservation): - # if it doesn't have tool call metadata, it was triggered by a user action - if obs.tool_call_metadata is None: - text = truncate_content( - f'\nObserved result of command executed by user:\n{obs.to_agent_observation()}', - max_message_chars, - ) - else: - text = truncate_content(obs.to_agent_observation(), max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, IPythonRunCellObservation): - text = obs.content - # replace base64 images with a placeholder - splitted = text.split('\n') - for i, line in enumerate(splitted): - if '![image](data:image/png;base64,' in line: - splitted[i] = ( - '![image](data:image/png;base64, ...) already displayed to user' - ) - text = '\n'.join(splitted) - text = truncate_content(text, max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, FileEditObservation): - text = truncate_content(str(obs), max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, FileReadObservation): - message = Message( - role='user', content=[TextContent(text=obs.content)] - ) # Content is already truncated by openhands-aci - elif isinstance(obs, BrowserOutputObservation): - text = obs.get_agent_obs_text() - if ( - obs.trigger_by_action == ActionType.BROWSE_INTERACTIVE - and obs.set_of_marks is not None - and len(obs.set_of_marks) > 0 - and enable_som_visual_browsing - and vision_is_active - ): - text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n' - message = Message( - role='user', - content=[ - TextContent(text=text), - ImageContent(image_urls=[obs.set_of_marks]), - ], - ) - else: - message = Message( - role='user', - content=[TextContent(text=text)], - ) - elif isinstance(obs, AgentDelegateObservation): - text = truncate_content( - obs.outputs['content'] if 'content' in obs.outputs else '', - max_message_chars, - ) - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, ErrorObservation): - text = truncate_content(obs.content, max_message_chars) - text += '\n[Error occurred in processing last action]' - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, UserRejectObservation): - text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) - text += '\n[Last action has been rejected by the user]' - message = Message(role='user', content=[TextContent(text=text)]) - elif isinstance(obs, AgentCondensationObservation): - text = truncate_content(obs.content, max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) - else: - # If an observation message is not returned, it will cause an error - # when the LLM tries to return the next message - raise ValueError(f'Unknown observation type: {type(obs)}') - - # Update the message as tool response properly - if (tool_call_metadata := obs.tool_call_metadata) is not None: - tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( - role='tool', - content=message.content, - tool_call_id=tool_call_metadata.tool_call_id, - name=tool_call_metadata.function_name, - ) - # No need to return the observation message - # because it will be added by get_action_message when all the corresponding - # tool calls in the SAME request are processed - return [] - - return [message] - - -def apply_prompt_caching(messages: list[Message]) -> None: - """Applies caching breakpoints to the messages. - - For new Anthropic API, we only need to mark the last user or tool message as cacheable. - """ - # NOTE: this is only needed for anthropic - for message in reversed(messages): - if message.role in ('user', 'tool'): - message.content[ - -1 - ].cache_prompt = True # Last item inside the message content - break - - def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None: """ Returns at most one token usage record for the `model_response.id` in this event's diff --git a/tests/unit/test_message_utils.py b/tests/unit/test_message_utils.py index 0f3a189a9cd3..38166d314777 100644 --- a/tests/unit/test_message_utils.py +++ b/tests/unit/test_message_utils.py @@ -1,282 +1,12 @@ -from unittest.mock import Mock - -import pytest - -from openhands.core.message import ImageContent, TextContent from openhands.core.message_utils import ( - get_action_message, - get_observation_message, get_token_usage_for_event, get_token_usage_for_event_id, ) -from openhands.events.action import ( - AgentFinishAction, - CmdRunAction, - MessageAction, -) -from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource -from openhands.events.observation.browse import BrowserOutputObservation -from openhands.events.observation.commands import ( - CmdOutputMetadata, - CmdOutputObservation, - IPythonRunCellObservation, -) -from openhands.events.observation.delegate import AgentDelegateObservation -from openhands.events.observation.error import ErrorObservation -from openhands.events.observation.files import FileEditObservation, FileReadObservation -from openhands.events.observation.reject import UserRejectObservation +from openhands.events.event import Event from openhands.events.tool import ToolCallMetadata from openhands.llm.metrics import Metrics, TokenUsage -def test_cmd_output_observation_message(): - obs = CmdOutputObservation( - command='echo hello', - content='Command output', - metadata=CmdOutputMetadata( - exit_code=0, - prefix='[THIS IS PREFIX]', - suffix='[THIS IS SUFFIX]', - ), - ) - - tool_call_id_to_message = {} - results = get_observation_message( - obs, tool_call_id_to_message=tool_call_id_to_message - ) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Observed result of command executed by user:' in result.content[0].text - assert '[Command finished with exit code 0]' in result.content[0].text - assert '[THIS IS PREFIX]' in result.content[0].text - assert '[THIS IS SUFFIX]' in result.content[0].text - - -def test_ipython_run_cell_observation_message(): - obs = IPythonRunCellObservation( - code='plt.plot()', - content='IPython output\n![image](data:image/png;base64,ABC123)', - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'IPython output' in result.content[0].text - assert ( - '![image](data:image/png;base64, ...) already displayed to user' - in result.content[0].text - ) - assert 'ABC123' not in result.content[0].text - - -def test_agent_delegate_observation_message(): - obs = AgentDelegateObservation( - content='Content', outputs={'content': 'Delegated agent output'} - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Delegated agent output' in result.content[0].text - - -def test_error_observation_message(): - obs = ErrorObservation('Error message') - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Error message' in result.content[0].text - assert 'Error occurred in processing last action' in result.content[0].text - - -def test_unknown_observation_message(): - obs = Mock() - - with pytest.raises(ValueError, match='Unknown observation type'): - get_observation_message(obs, tool_call_id_to_message={}) - - -def test_file_edit_observation_message(): - obs = FileEditObservation( - path='/test/file.txt', - prev_exist=True, - old_content='old content', - new_content='new content', - content='diff content', - impl_source=FileEditSource.LLM_BASED_EDIT, - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert '[Existing file /test/file.txt is edited with' in result.content[0].text - - -def test_file_read_observation_message(): - obs = FileReadObservation( - path='/test/file.txt', - content='File content', - impl_source=FileReadSource.DEFAULT, - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == 'File content' - - -def test_browser_output_observation_message(): - obs = BrowserOutputObservation( - url='http://example.com', - trigger_by_action='browse', - screenshot='', - content='Page loaded', - error=False, - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert '[Current URL: http://example.com]' in result.content[0].text - - -def test_user_reject_observation_message(): - obs = UserRejectObservation('Action rejected') - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Action rejected' in result.content[0].text - assert '[Last action has been rejected by the user]' in result.content[0].text - - -def test_function_calling_observation_message(): - mock_response = { - 'id': 'mock_id', - 'total_calls_in_response': 1, - 'choices': [{'message': {'content': 'Task completed'}}], - } - obs = CmdOutputObservation( - command='echo hello', - content='Command output', - command_id=1, - exit_code=0, - ) - obs.tool_call_metadata = ToolCallMetadata( - tool_call_id='123', - function_name='execute_bash', - model_response=mock_response, - total_calls_in_response=1, - ) - - results = get_observation_message(obs, tool_call_id_to_message={}) - assert len(results) == 0 # No direct message when using function calling - - -def test_message_action_with_image(): - action = MessageAction( - content='Message with image', - image_urls=['http://example.com/image.jpg'], - ) - action._source = EventSource.AGENT - - results = get_action_message(action, {}, vision_is_active=True) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'assistant' - assert len(result.content) == 2 - assert isinstance(result.content[0], TextContent) - assert isinstance(result.content[1], ImageContent) - assert result.content[0].text == 'Message with image' - assert result.content[1].image_urls == ['http://example.com/image.jpg'] - - -def test_user_cmd_action_message(): - action = CmdRunAction(command='ls -l') - action._source = EventSource.USER - - results = get_action_message(action, {}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'user' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'User executed the command' in result.content[0].text - assert 'ls -l' in result.content[0].text - - -def test_agent_finish_action_with_tool_metadata(): - mock_response = { - 'id': 'mock_id', - 'total_calls_in_response': 1, - 'choices': [{'message': {'content': 'Task completed'}}], - } - - action = AgentFinishAction(thought='Initial thought') - action._source = EventSource.AGENT - action.tool_call_metadata = ToolCallMetadata( - tool_call_id='123', - function_name='finish', - model_response=mock_response, - total_calls_in_response=1, - ) - - results = get_action_message(action, {}) - assert len(results) == 1 - - result = results[0] - assert result is not None - assert result.role == 'assistant' - assert len(result.content) == 1 - assert isinstance(result.content[0], TextContent) - assert 'Initial thought\nTask completed' in result.content[0].text - - def test_get_token_usage_for_event(): """Test that we get the single matching usage record (if any) based on the event's model_response.id.""" metrics = Metrics(model_name='test-model') From 5f49e125ad4c265acc091042c4299ff5f7c258d9 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 20:14:09 +0000 Subject: [PATCH 06/15] Fix test failures in test_conversation_memory.py --- tests/unit/test_conversation_memory.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index dc3382bcaaa6..2537b45c4995 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -9,7 +9,7 @@ CmdRunAction, MessageAction, ) -from openhands.events.event import EventSource, FileEditSource, FileReadSource +from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource from openhands.events.observation import CmdOutputObservation from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( @@ -51,8 +51,10 @@ def test_process_initial_messages(conversation_memory): def test_process_events_with_message_action(conversation_memory, mock_state): - user_message = MessageAction(content='Hello', source='user') - assistant_message = MessageAction(content='Hi there', source='assistant') + user_message = MessageAction(content='Hello') + user_message._source = EventSource.USER + assistant_message = MessageAction(content='Hi there') + assistant_message._source = EventSource.AGENT initial_messages = [ Message(role='system', content=[TextContent(text='System message')]) @@ -193,7 +195,8 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): def test_process_events_with_unknown_observation(conversation_memory, mock_state): - obs = Mock() + # Create a mock that inherits from Event but not Action or Observation + obs = Mock(spec=Event) initial_messages = [ Message(role='system', content=[TextContent(text='System message')]) From 9986f12831e0f4e03f9ffa947ff39e8c604e6297 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 20:15:35 +0000 Subject: [PATCH 07/15] Use pytest-mock instead of unittest.mock --- tests/unit/test_conversation_memory.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 2537b45c4995..fed713d9af8a 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,6 +1,5 @@ -from unittest.mock import MagicMock, Mock - import pytest +from pytest_mock import MockerFixture from openhands.controller.state.state import State from openhands.core.message import ImageContent, Message, TextContent @@ -26,15 +25,15 @@ @pytest.fixture -def conversation_memory(): - prompt_manager = MagicMock(spec=PromptManager) +def conversation_memory(mocker: MockerFixture): + prompt_manager = mocker.MagicMock(spec=PromptManager) prompt_manager.get_system_message.return_value = 'System message' return ConversationMemory(prompt_manager) @pytest.fixture -def mock_state(): - state = MagicMock(spec=State) +def mock_state(mocker: MockerFixture): + state = mocker.MagicMock(spec=State) state.history = [] return state @@ -194,9 +193,9 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): assert 'Error occurred in processing last action' in result.content[0].text -def test_process_events_with_unknown_observation(conversation_memory, mock_state): +def test_process_events_with_unknown_observation(conversation_memory, mock_state, mocker: MockerFixture): # Create a mock that inherits from Event but not Action or Observation - obs = Mock(spec=Event) + obs = mocker.MagicMock(spec=Event) initial_messages = [ Message(role='system', content=[TextContent(text='System message')]) From 8fa5aec8d605871e73ee163ca3ec649a640166ef Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 20:41:43 +0000 Subject: [PATCH 08/15] Apply formatting fixes from pre-commit --- tests/unit/test_conversation_memory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index fed713d9af8a..93dc5da143af 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -193,7 +193,9 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): assert 'Error occurred in processing last action' in result.content[0].text -def test_process_events_with_unknown_observation(conversation_memory, mock_state, mocker: MockerFixture): +def test_process_events_with_unknown_observation( + conversation_memory, mock_state, mocker: MockerFixture +): # Create a mock that inherits from Event but not Action or Observation obs = mocker.MagicMock(spec=Event) From 8fe7af761ddf8554c389a482c6402ade598adcdc Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 27 Feb 2025 21:45:58 +0100 Subject: [PATCH 09/15] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b41f1571480e..bce66e81df27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,6 @@ reportlab = "*" [tool.coverage.run] concurrency = ["gevent"] - [tool.poetry.group.runtime.dependencies] jupyterlab = "*" notebook = "*" From 9233dfd4f98e7b49104f5edf47ec1405903da648 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 27 Feb 2025 21:46:17 +0100 Subject: [PATCH 10/15] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bce66e81df27..0b79dca0994a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,6 @@ ignore = ["D1"] [tool.ruff.lint.pydocstyle] convention = "google" - [tool.poetry.group.evaluation.dependencies] streamlit = "*" whatthepatch = "*" From fccb1bffc6f0909ed205af37fe3f221ce870a12a Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 20:56:43 +0000 Subject: [PATCH 11/15] Revert to using unittest.mock for consistency with the rest of the codebase --- tests/unit/test_conversation_memory.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 93dc5da143af..2537b45c4995 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -1,5 +1,6 @@ +from unittest.mock import MagicMock, Mock + import pytest -from pytest_mock import MockerFixture from openhands.controller.state.state import State from openhands.core.message import ImageContent, Message, TextContent @@ -25,15 +26,15 @@ @pytest.fixture -def conversation_memory(mocker: MockerFixture): - prompt_manager = mocker.MagicMock(spec=PromptManager) +def conversation_memory(): + prompt_manager = MagicMock(spec=PromptManager) prompt_manager.get_system_message.return_value = 'System message' return ConversationMemory(prompt_manager) @pytest.fixture -def mock_state(mocker: MockerFixture): - state = mocker.MagicMock(spec=State) +def mock_state(): + state = MagicMock(spec=State) state.history = [] return state @@ -193,11 +194,9 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): assert 'Error occurred in processing last action' in result.content[0].text -def test_process_events_with_unknown_observation( - conversation_memory, mock_state, mocker: MockerFixture -): +def test_process_events_with_unknown_observation(conversation_memory, mock_state): # Create a mock that inherits from Event but not Action or Observation - obs = mocker.MagicMock(spec=Event) + obs = Mock(spec=Event) initial_messages = [ Message(role='system', content=[TextContent(text='System message')]) From 4334845996c3975c24877879117a81db0ab44bbf Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 27 Feb 2025 21:11:38 +0000 Subject: [PATCH 12/15] Fix test_process_events_with_unknown_observation to match the correct error message --- tests/unit/test_conversation_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 2537b45c4995..74062b5bc699 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -202,7 +202,7 @@ def test_process_events_with_unknown_observation(conversation_memory, mock_state Message(role='system', content=[TextContent(text='System message')]) ] - with pytest.raises(ValueError, match='Unknown observation type'): + with pytest.raises(ValueError, match='Unknown event type'): conversation_memory.process_events( state=mock_state, condensed_history=[obs], From b5e0166817e558487542e4eb2a6a7dfea57d1589 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 27 Feb 2025 23:26:44 +0100 Subject: [PATCH 13/15] Update openhands/agenthub/codeact_agent/codeact_agent.py Co-authored-by: Calvin Smith --- openhands/agenthub/codeact_agent/codeact_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 686a6f12b324..d701f48e2e16 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -188,7 +188,6 @@ def _get_messages(self, state: State) -> list[Message]: messages = self._enhance_messages(messages) if self.llm.is_caching_prompt_active(): - # Use conversation_memory to apply caching instead of calling apply_prompt_caching directly self.conversation_memory.apply_prompt_caching(messages) return messages From e4fd68d56bcce365bc50eeff08798f58f5e8788a Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Thu, 27 Feb 2025 23:39:52 +0100 Subject: [PATCH 14/15] remove State --- openhands/agenthub/codeact_agent/codeact_agent.py | 5 ++++- openhands/memory/conversation_memory.py | 6 ------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index d701f48e2e16..5ceb5bd92c31 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -176,8 +176,11 @@ def _get_messages(self, state: State) -> list[Message]: # Condense the events from the state. events = self.condenser.condensed_history(state) + logger.debug( + f'Processing {len(events)} events from a total of {len(state.history)} events' + ) + messages = self.conversation_memory.process_events( - state=state, condensed_history=events, initial_messages=messages, max_message_chars=self.llm.config.max_message_chars, diff --git a/openhands/memory/conversation_memory.py b/openhands/memory/conversation_memory.py index 7fbd23c89e9f..584356a2a80e 100644 --- a/openhands/memory/conversation_memory.py +++ b/openhands/memory/conversation_memory.py @@ -1,6 +1,5 @@ from litellm import ModelResponse -from openhands.controller.state.state import State from openhands.core.logger import openhands_logger as logger from openhands.core.message import ImageContent, Message, TextContent from openhands.core.schema import ActionType @@ -41,7 +40,6 @@ def __init__(self, prompt_manager: PromptManager): def process_events( self, - state: State, condensed_history: list[Event], initial_messages: list[Message], max_message_chars: int | None = None, @@ -63,10 +61,6 @@ def process_events( """ events = condensed_history - logger.debug( - f'Processing {len(events)} events from a total of {len(state.history)} events' - ) - # Process special events first (system prompts, etc.) messages = initial_messages From 647aaa76b501a673df26bfe02ace608a1345851f Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Fri, 28 Feb 2025 01:22:52 +0100 Subject: [PATCH 15/15] fix tests --- tests/unit/test_conversation_memory.py | 50 ++++++++------------------ 1 file changed, 14 insertions(+), 36 deletions(-) diff --git a/tests/unit/test_conversation_memory.py b/tests/unit/test_conversation_memory.py index 74062b5bc699..7721354bdb21 100644 --- a/tests/unit/test_conversation_memory.py +++ b/tests/unit/test_conversation_memory.py @@ -50,7 +50,7 @@ def test_process_initial_messages(conversation_memory): assert messages[0].content[0].cache_prompt is True -def test_process_events_with_message_action(conversation_memory, mock_state): +def test_process_events_with_message_action(conversation_memory): user_message = MessageAction(content='Hello') user_message._source = EventSource.USER assistant_message = MessageAction(content='Hi there') @@ -61,7 +61,6 @@ def test_process_events_with_message_action(conversation_memory, mock_state): ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[user_message, assistant_message], initial_messages=initial_messages, max_message_chars=None, @@ -76,7 +75,7 @@ def test_process_events_with_message_action(conversation_memory, mock_state): assert messages[2].content[0].text == 'Hi there' -def test_process_events_with_cmd_output_observation(conversation_memory, mock_state): +def test_process_events_with_cmd_output_observation(conversation_memory): obs = CmdOutputObservation( command='echo hello', content='Command output', @@ -92,7 +91,6 @@ def test_process_events_with_cmd_output_observation(conversation_memory, mock_st ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -110,9 +108,7 @@ def test_process_events_with_cmd_output_observation(conversation_memory, mock_st assert '[THIS IS SUFFIX]' in result.content[0].text -def test_process_events_with_ipython_run_cell_observation( - conversation_memory, mock_state -): +def test_process_events_with_ipython_run_cell_observation(conversation_memory): obs = IPythonRunCellObservation( code='plt.plot()', content='IPython output\n![image](data:image/png;base64,ABC123)', @@ -123,7 +119,6 @@ def test_process_events_with_ipython_run_cell_observation( ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -143,9 +138,7 @@ def test_process_events_with_ipython_run_cell_observation( assert 'ABC123' not in result.content[0].text -def test_process_events_with_agent_delegate_observation( - conversation_memory, mock_state -): +def test_process_events_with_agent_delegate_observation(conversation_memory): obs = AgentDelegateObservation( content='Content', outputs={'content': 'Delegated agent output'} ) @@ -155,7 +148,6 @@ def test_process_events_with_agent_delegate_observation( ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -170,7 +162,7 @@ def test_process_events_with_agent_delegate_observation( assert 'Delegated agent output' in result.content[0].text -def test_process_events_with_error_observation(conversation_memory, mock_state): +def test_process_events_with_error_observation(conversation_memory): obs = ErrorObservation('Error message') initial_messages = [ @@ -178,7 +170,6 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -194,7 +185,7 @@ def test_process_events_with_error_observation(conversation_memory, mock_state): assert 'Error occurred in processing last action' in result.content[0].text -def test_process_events_with_unknown_observation(conversation_memory, mock_state): +def test_process_events_with_unknown_observation(conversation_memory): # Create a mock that inherits from Event but not Action or Observation obs = Mock(spec=Event) @@ -204,7 +195,6 @@ def test_process_events_with_unknown_observation(conversation_memory, mock_state with pytest.raises(ValueError, match='Unknown event type'): conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -212,7 +202,7 @@ def test_process_events_with_unknown_observation(conversation_memory, mock_state ) -def test_process_events_with_file_edit_observation(conversation_memory, mock_state): +def test_process_events_with_file_edit_observation(conversation_memory): obs = FileEditObservation( path='/test/file.txt', prev_exist=True, @@ -227,7 +217,6 @@ def test_process_events_with_file_edit_observation(conversation_memory, mock_sta ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -242,7 +231,7 @@ def test_process_events_with_file_edit_observation(conversation_memory, mock_sta assert '[Existing file /test/file.txt is edited with' in result.content[0].text -def test_process_events_with_file_read_observation(conversation_memory, mock_state): +def test_process_events_with_file_read_observation(conversation_memory): obs = FileReadObservation( path='/test/file.txt', content='File content', @@ -254,7 +243,6 @@ def test_process_events_with_file_read_observation(conversation_memory, mock_sta ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -269,9 +257,7 @@ def test_process_events_with_file_read_observation(conversation_memory, mock_sta assert result.content[0].text == 'File content' -def test_process_events_with_browser_output_observation( - conversation_memory, mock_state -): +def test_process_events_with_browser_output_observation(conversation_memory): obs = BrowserOutputObservation( url='http://example.com', trigger_by_action='browse', @@ -285,7 +271,6 @@ def test_process_events_with_browser_output_observation( ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -300,7 +285,7 @@ def test_process_events_with_browser_output_observation( assert '[Current URL: http://example.com]' in result.content[0].text -def test_process_events_with_user_reject_observation(conversation_memory, mock_state): +def test_process_events_with_user_reject_observation(conversation_memory): obs = UserRejectObservation('Action rejected') initial_messages = [ @@ -308,7 +293,6 @@ def test_process_events_with_user_reject_observation(conversation_memory, mock_s ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -324,9 +308,7 @@ def test_process_events_with_user_reject_observation(conversation_memory, mock_s assert '[Last action has been rejected by the user]' in result.content[0].text -def test_process_events_with_function_calling_observation( - conversation_memory, mock_state -): +def test_process_events_with_function_calling_observation(conversation_memory): mock_response = { 'id': 'mock_id', 'total_calls_in_response': 1, @@ -350,7 +332,6 @@ def test_process_events_with_function_calling_observation( ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[obs], initial_messages=initial_messages, max_message_chars=None, @@ -361,7 +342,7 @@ def test_process_events_with_function_calling_observation( assert len(messages) == 1 # Only the initial system message -def test_process_events_with_message_action_with_image(conversation_memory, mock_state): +def test_process_events_with_message_action_with_image(conversation_memory): action = MessageAction( content='Message with image', image_urls=['http://example.com/image.jpg'], @@ -373,7 +354,6 @@ def test_process_events_with_message_action_with_image(conversation_memory, mock ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[action], initial_messages=initial_messages, max_message_chars=None, @@ -390,7 +370,7 @@ def test_process_events_with_message_action_with_image(conversation_memory, mock assert result.content[1].image_urls == ['http://example.com/image.jpg'] -def test_process_events_with_user_cmd_action(conversation_memory, mock_state): +def test_process_events_with_user_cmd_action(conversation_memory): action = CmdRunAction(command='ls -l') action._source = EventSource.USER @@ -399,7 +379,6 @@ def test_process_events_with_user_cmd_action(conversation_memory, mock_state): ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[action], initial_messages=initial_messages, max_message_chars=None, @@ -416,7 +395,7 @@ def test_process_events_with_user_cmd_action(conversation_memory, mock_state): def test_process_events_with_agent_finish_action_with_tool_metadata( - conversation_memory, mock_state + conversation_memory, ): mock_response = { 'id': 'mock_id', @@ -438,7 +417,6 @@ def test_process_events_with_agent_finish_action_with_tool_metadata( ] messages = conversation_memory.process_events( - state=mock_state, condensed_history=[action], initial_messages=initial_messages, max_message_chars=None,