From 89c7bf59a7e30f185c3da92b71748ce71d0d0bd2 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Mon, 27 Jan 2025 22:09:03 +0100 Subject: [PATCH] Fix first user message (#6471) --- .../agenthub/codeact_agent/codeact_agent.py | 44 ++++++-------- openhands/core/config/sandbox_config.py | 2 +- openhands/llm/fn_call_converter.py | 1 - .../runtime/impl/remote/remote_runtime.py | 4 ++ openhands/utils/prompt.py | 58 ++++++++++++------- poetry.lock | 10 ++-- tests/unit/test_prompt_manager.py | 28 ++++++--- 7 files changed, 86 insertions(+), 61 deletions(-) diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index ecb756781abe..d2b5b35a735d 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -433,26 +433,6 @@ def _get_messages(self, state: State) -> list[Message]: ], ) ] - example_message = self.prompt_manager.get_example_user_message() - if example_message: - messages.append( - Message( - role='user', - content=[TextContent(text=example_message)], - cache_prompt=self.llm.is_caching_prompt_active(), - ) - ) - - # Repository and runtime info - additional_info = self.prompt_manager.get_additional_info() - if self.config.enable_prompt_extensions and additional_info: - # only add these if prompt extension is enabled - messages.append( - Message( - role='user', - content=[TextContent(text=additional_info)], - ) - ) pending_tool_call_action_messages: dict[str, Message] = {} tool_call_id_to_message: dict[str, Message] = {} @@ -460,6 +440,7 @@ def _get_messages(self, state: State) -> list[Message]: # Condense the events from the state. events = self.condenser.condensed_history(state) + is_first_message_handled = False for event in events: # create a regular message from an event if isinstance(event, Action): @@ -501,11 +482,22 @@ def _get_messages(self, state: State) -> list[Message]: for response_id in _response_ids_to_remove: pending_tool_call_action_messages.pop(response_id) - for message in messages_to_add: - if message: - if message.role == 'user': - self.prompt_manager.enhance_message(message) - messages.append(message) + for msg in messages_to_add: + if msg: + if msg.role == 'user' and not is_first_message_handled: + is_first_message_handled = True + # compose the first user message with examples + self.prompt_manager.add_examples_to_initial_message(msg) + + # and/or repo/runtime info + if self.config.enable_prompt_extensions: + self.prompt_manager.add_info_to_initial_message(msg) + + # enhance the user message with additional context based on keywords matched + if msg.role == 'user': + self.prompt_manager.enhance_message(msg) + + messages.append(msg) if self.llm.is_caching_prompt_active(): # NOTE: this is only needed for anthropic @@ -513,7 +505,7 @@ def _get_messages(self, state: State) -> list[Message]: # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 breakpoints_remaining = 3 # remaining 1 for system/tool for message in reversed(messages): - if message.role == 'user' or message.role == 'tool': + if message.role in ('user', 'tool'): if breakpoints_remaining > 0: message.content[ -1 diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index f5b984fec0b9..bd3d81f559ae 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -37,7 +37,7 @@ class SandboxConfig(BaseModel): This should be a JSON string that will be parsed into a dictionary. """ - remote_runtime_api_url: str = Field(default='http://localhost:8000') + remote_runtime_api_url: str | None = Field(default='http://localhost:8000') local_runtime_url: str = Field(default='http://localhost') keep_runtime_alive: bool = Field(default=False) rm_all_containers: bool = Field(default=False) diff --git a/openhands/llm/fn_call_converter.py b/openhands/llm/fn_call_converter.py index 16e761bae5ba..fb5bbc9b1db5 100644 --- a/openhands/llm/fn_call_converter.py +++ b/openhands/llm/fn_call_converter.py @@ -200,7 +200,6 @@ def index(): Running the updated file: - python3 app.py > server.log 2>&1 & diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index f0a9a7fb359d..57d8e8def8b8 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -68,6 +68,10 @@ def __init__( 'debug', 'Setting workspace_base is not supported in the remote runtime.', ) + if self.config.sandbox.remote_runtime_api_url is None: + raise ValueError( + 'remote_runtime_api_url is required in the remote runtime.' + ) self.runtime_builder = RemoteRuntimeBuilder( self.config.sandbox.remote_runtime_api_url, diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 1ffd4b8f117b..7fc5d4638238 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -135,27 +135,6 @@ def _load_template(self, template_name: str) -> Template: def get_system_message(self) -> str: return self.system_template.render().strip() - def get_additional_info(self) -> str: - """Gets information about the repository and runtime. - - This is used to inject information about the repository and runtime into the initial user message. - """ - repo_instructions = '' - assert ( - len(self.repo_microagents) <= 1 - ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}' - for microagent in self.repo_microagents.values(): - # We assume these are the repo instructions - if repo_instructions: - repo_instructions += '\n\n' - repo_instructions += microagent.content - - return ADDITIONAL_INFO_TEMPLATE.render( - repository_instructions=repo_instructions, - repository_info=self.repository_info, - runtime_info=self.runtime_info, - ).strip() - def set_runtime_info(self, runtime: Runtime): self.runtime_info.available_hosts = runtime.web_hosts @@ -205,6 +184,43 @@ def enhance_message(self, message: Message) -> None: micro_text += '\n' message.content.append(TextContent(text=micro_text)) + def add_examples_to_initial_message(self, message: Message) -> None: + """Add example_message to the first user message.""" + example_message = self.get_example_user_message() or None + + # Insert it at the start of the TextContent list + if example_message: + message.content.insert(0, TextContent(text=example_message)) + + def add_info_to_initial_message( + self, + message: Message, + ) -> None: + """Adds information about the repository and runtime to the initial user message. + + Args: + message: The initial user message to add information to. + """ + repo_instructions = '' + assert ( + len(self.repo_microagents) <= 1 + ), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}' + for microagent in self.repo_microagents.values(): + # We assume these are the repo instructions + if repo_instructions: + repo_instructions += '\n\n' + repo_instructions += microagent.content + + additional_info = ADDITIONAL_INFO_TEMPLATE.render( + repository_instructions=repo_instructions, + repository_info=self.repository_info, + runtime_info=self.runtime_info, + ).strip() + + # Insert the new content at the start of the TextContent list + if additional_info: + message.content.insert(0, TextContent(text=additional_info)) + def add_turns_left_reminder(self, messages: list[Message], state: State) -> None: latest_user_message = next( islice( diff --git a/poetry.lock b/poetry.lock index 0083ee06ec6a..2edc47144275 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1312,6 +1312,7 @@ files = [ {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:761817a3377ef15ac23cd7834715081791d4ec77f9297ee694ca1ee9c2c7e5eb"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3c672a53c0fb4725a29c303be906d3c1fa99c32f58abe008a82705f9ee96f40b"}, {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:4ac4c9f37eba52cb6fbeaf5b59c152ea976726b865bd4cf87883a7e7006cc543"}, + {file = "cryptography-44.0.0-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:60eb32934076fa07e4316b7b2742fa52cbb190b42c2df2863dbc4230a0a9b385"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ed3534eb1090483c96178fcb0f8893719d96d5274dfde98aa6add34614e97c8e"}, {file = "cryptography-44.0.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f3f6fdfa89ee2d9d496e2c087cebef9d4fcbb0ad63c40e821b39f74bf48d9c5e"}, {file = "cryptography-44.0.0-cp37-abi3-win32.whl", hash = "sha256:eb33480f1bad5b78233b0ad3e1b0be21e8ef1da745d8d2aecbb20671658b9053"}, @@ -1322,6 +1323,7 @@ files = [ {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c5eb858beed7835e5ad1faba59e865109f3e52b3783b9ac21e7e47dc5554e289"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f53c2c87e0fb4b0c00fa9571082a057e37690a8f12233306161c8f4b819960b7"}, {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:9e6fc8a08e116fb7c7dd1f040074c9d7b51d74a8ea40d4df2fc7aa08b76b9e6c"}, + {file = "cryptography-44.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:9abcc2e083cbe8dde89124a47e5e53ec38751f0d7dfd36801008f316a127d7ba"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:d2436114e46b36d00f8b72ff57e598978b37399d2786fd39793c36c6d5cb1c64"}, {file = "cryptography-44.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a01956ddfa0a6790d594f5b34fc1bfa6098aca434696a03cfdbe469b8ed79285"}, {file = "cryptography-44.0.0-cp39-abi3-win32.whl", hash = "sha256:eca27345e1214d1b9f9490d200f9db5a874479be914199194e746c893788d417"}, @@ -3900,13 +3902,13 @@ types-tqdm = "*" [[package]] name = "litellm" -version = "1.59.0" +version = "1.59.8" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.59.0-py3-none-any.whl", hash = "sha256:b0c8bdee556d5dc2f9c703f7dc831574ea2e339d2e762dd626d014c170b8b587"}, - {file = "litellm-1.59.0.tar.gz", hash = "sha256:140eecb47952558414d00f7a259fe303fe5f0d073973a28f488fc6938cc45660"}, + {file = "litellm-1.59.8-py3-none-any.whl", hash = "sha256:2473914bd2343485a185dfe7eedb12ee5fda32da3c9d9a8b73f6966b9b20cf39"}, + {file = "litellm-1.59.8.tar.gz", hash = "sha256:9d645cc4460f6a9813061f07086648c4c3d22febc8e1f21c663f2b7750d90512"}, ] [package.dependencies] diff --git a/tests/unit/test_prompt_manager.py b/tests/unit/test_prompt_manager.py index 46f1f5a254a1..7277c025b8a8 100644 --- a/tests/unit/test_prompt_manager.py +++ b/tests/unit/test_prompt_manager.py @@ -59,10 +59,16 @@ def test_prompt_manager_with_microagent(prompt_dir): # Test with GitHub repo manager.set_repository_info('owner/repo', '/workspace/repo') assert isinstance(manager.get_system_message(), str) - additional_info = manager.get_additional_info() - assert '' in additional_info - assert 'owner/repo' in additional_info - assert '/workspace/repo' in additional_info + + # Adding things to the initial user message + initial_msg = Message( + role='user', content=[TextContent(text='Ask me what your task is.')] + ) + manager.add_info_to_initial_message(initial_msg) + msg_content: str = initial_msg.content[0].text + assert '' in msg_content + assert 'owner/repo' in msg_content + assert '/workspace/repo' in msg_content assert isinstance(manager.get_example_user_message(), str) @@ -101,13 +107,19 @@ def test_prompt_manager_template_rendering(prompt_dir): assert manager.repository_info.repo_name == 'owner/repo' system_msg = manager.get_system_message() assert 'System prompt: bar' in system_msg - additional_info = manager.get_additional_info() - assert '' in additional_info + + # Initial user message should have repo info + initial_msg = Message( + role='user', content=[TextContent(text='Ask me what your task is.')] + ) + manager.add_info_to_initial_message(initial_msg) + msg_content: str = initial_msg.content[0].text + assert '' in msg_content assert ( "At the user's request, repository owner/repo has been cloned to directory /workspace/repo." - in additional_info + in msg_content ) - assert '' in additional_info + assert '' in msg_content assert manager.get_example_user_message() == 'User prompt: foo' # Clean up temporary files