Skip to content

Commit

Permalink
Fix first user message (#6471)
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst authored Jan 27, 2025
1 parent 6045349 commit 89c7bf5
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 61 deletions.
44 changes: 18 additions & 26 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,33 +433,14 @@ def _get_messages(self, state: State) -> list[Message]:
],
)
]
example_message = self.prompt_manager.get_example_user_message()
if example_message:
messages.append(
Message(
role='user',
content=[TextContent(text=example_message)],
cache_prompt=self.llm.is_caching_prompt_active(),
)
)

# Repository and runtime info
additional_info = self.prompt_manager.get_additional_info()
if self.config.enable_prompt_extensions and additional_info:
# only add these if prompt extension is enabled
messages.append(
Message(
role='user',
content=[TextContent(text=additional_info)],
)
)

pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}

# Condense the events from the state.
events = self.condenser.condensed_history(state)

is_first_message_handled = False
for event in events:
# create a regular message from an event
if isinstance(event, Action):
Expand Down Expand Up @@ -501,19 +482,30 @@ def _get_messages(self, state: State) -> list[Message]:
for response_id in _response_ids_to_remove:
pending_tool_call_action_messages.pop(response_id)

for message in messages_to_add:
if message:
if message.role == 'user':
self.prompt_manager.enhance_message(message)
messages.append(message)
for msg in messages_to_add:
if msg:
if msg.role == 'user' and not is_first_message_handled:
is_first_message_handled = True
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)

# and/or repo/runtime info
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)

# enhance the user message with additional context based on keywords matched
if msg.role == 'user':
self.prompt_manager.enhance_message(msg)

messages.append(msg)

if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
# following logic here:
# https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
breakpoints_remaining = 3 # remaining 1 for system/tool
for message in reversed(messages):
if message.role == 'user' or message.role == 'tool':
if message.role in ('user', 'tool'):
if breakpoints_remaining > 0:
message.content[
-1
Expand Down
2 changes: 1 addition & 1 deletion openhands/core/config/sandbox_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class SandboxConfig(BaseModel):
This should be a JSON string that will be parsed into a dictionary.
"""

remote_runtime_api_url: str = Field(default='http://localhost:8000')
remote_runtime_api_url: str | None = Field(default='http://localhost:8000')
local_runtime_url: str = Field(default='http://localhost')
keep_runtime_alive: bool = Field(default=False)
rm_all_containers: bool = Field(default=False)
Expand Down
1 change: 0 additions & 1 deletion openhands/llm/fn_call_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def index():
Running the updated file:
<function=execute_bash>
<parameter=command>
<parameter=command>
python3 app.py > server.log 2>&1 &
</parameter>
</function>
Expand Down
4 changes: 4 additions & 0 deletions openhands/runtime/impl/remote/remote_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def __init__(
'debug',
'Setting workspace_base is not supported in the remote runtime.',
)
if self.config.sandbox.remote_runtime_api_url is None:
raise ValueError(
'remote_runtime_api_url is required in the remote runtime.'
)

self.runtime_builder = RemoteRuntimeBuilder(
self.config.sandbox.remote_runtime_api_url,
Expand Down
58 changes: 37 additions & 21 deletions openhands/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,27 +135,6 @@ def _load_template(self, template_name: str) -> Template:
def get_system_message(self) -> str:
return self.system_template.render().strip()

def get_additional_info(self) -> str:
"""Gets information about the repository and runtime.
This is used to inject information about the repository and runtime into the initial user message.
"""
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
for microagent in self.repo_microagents.values():
# We assume these are the repo instructions
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content

return ADDITIONAL_INFO_TEMPLATE.render(
repository_instructions=repo_instructions,
repository_info=self.repository_info,
runtime_info=self.runtime_info,
).strip()

def set_runtime_info(self, runtime: Runtime):
self.runtime_info.available_hosts = runtime.web_hosts

Expand Down Expand Up @@ -205,6 +184,43 @@ def enhance_message(self, message: Message) -> None:
micro_text += '\n</extra_info>'
message.content.append(TextContent(text=micro_text))

def add_examples_to_initial_message(self, message: Message) -> None:
"""Add example_message to the first user message."""
example_message = self.get_example_user_message() or None

# Insert it at the start of the TextContent list
if example_message:
message.content.insert(0, TextContent(text=example_message))

def add_info_to_initial_message(
self,
message: Message,
) -> None:
"""Adds information about the repository and runtime to the initial user message.
Args:
message: The initial user message to add information to.
"""
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
for microagent in self.repo_microagents.values():
# We assume these are the repo instructions
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content

additional_info = ADDITIONAL_INFO_TEMPLATE.render(
repository_instructions=repo_instructions,
repository_info=self.repository_info,
runtime_info=self.runtime_info,
).strip()

# Insert the new content at the start of the TextContent list
if additional_info:
message.content.insert(0, TextContent(text=additional_info))

def add_turns_left_reminder(self, messages: list[Message], state: State) -> None:
latest_user_message = next(
islice(
Expand Down
10 changes: 6 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 20 additions & 8 deletions tests/unit/test_prompt_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,16 @@ def test_prompt_manager_with_microagent(prompt_dir):
# Test with GitHub repo
manager.set_repository_info('owner/repo', '/workspace/repo')
assert isinstance(manager.get_system_message(), str)
additional_info = manager.get_additional_info()
assert '<REPOSITORY_INFO>' in additional_info
assert 'owner/repo' in additional_info
assert '/workspace/repo' in additional_info

# Adding things to the initial user message
initial_msg = Message(
role='user', content=[TextContent(text='Ask me what your task is.')]
)
manager.add_info_to_initial_message(initial_msg)
msg_content: str = initial_msg.content[0].text
assert '<REPOSITORY_INFO>' in msg_content
assert 'owner/repo' in msg_content
assert '/workspace/repo' in msg_content

assert isinstance(manager.get_example_user_message(), str)

Expand Down Expand Up @@ -101,13 +107,19 @@ def test_prompt_manager_template_rendering(prompt_dir):
assert manager.repository_info.repo_name == 'owner/repo'
system_msg = manager.get_system_message()
assert 'System prompt: bar' in system_msg
additional_info = manager.get_additional_info()
assert '<REPOSITORY_INFO>' in additional_info

# Initial user message should have repo info
initial_msg = Message(
role='user', content=[TextContent(text='Ask me what your task is.')]
)
manager.add_info_to_initial_message(initial_msg)
msg_content: str = initial_msg.content[0].text
assert '<REPOSITORY_INFO>' in msg_content
assert (
"At the user's request, repository owner/repo has been cloned to directory /workspace/repo."
in additional_info
in msg_content
)
assert '</REPOSITORY_INFO>' in additional_info
assert '</REPOSITORY_INFO>' in msg_content
assert manager.get_example_user_message() == 'User prompt: foo'

# Clean up temporary files
Expand Down

0 comments on commit 89c7bf5

Please sign in to comment.