Skip to content

Commit

Permalink
refactor info in the first user message to a recalled observation
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst committed Feb 23, 2025
1 parent 21c2253 commit d596fd2
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 52 deletions.
4 changes: 1 addition & 3 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CodeActAgent(Agent):
### Overview
This agent implements the CodeAct idea ([paper](https://arxiv.org/abs/2402.01030), [tweet](https://twitter.com/xingyaow_/status/1754556835703751087)) that consolidates LLM agents **act**ions into a unified **code** action space for both *simplicity* and *performance* (see paper for more details).
This agent implements the CodeAct idea ([paper](https://arxiv.org/abs/2402.01030), [tweet](https://twitter.com/xingyaow_/status/1754556835703751087)) that consolidates LLM agents' **act**ions into a unified **code** action space for both *simplicity* and *performance* (see paper for more details).
The conceptual idea is illustrated below. At each turn, the agent can:
Expand Down Expand Up @@ -214,8 +214,6 @@ def _enhance_messages(self, messages: list[Message]) -> list[Message]:
is_first_message_handled = True
# compose the first user message with examples
self.prompt_manager.add_examples_to_initial_message(msg)
if self.config.enable_prompt_extensions:
self.prompt_manager.add_info_to_initial_message(msg)

results.append(msg)

Expand Down
6 changes: 2 additions & 4 deletions openhands/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_memory(

if agent.prompt_manager and runtime:
# sets available hosts
agent.prompt_manager.set_runtime_info(runtime)
mem.set_runtime_info(runtime.web_hosts)

# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroAgent] = runtime.get_microagents_from_selected_repo(
Expand All @@ -89,9 +89,7 @@ def create_memory(
if selected_repository:
repo_directory = selected_repository.split('/')[1]
if repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)
mem.set_repository_info(selected_repository, repo_directory)
return mem


Expand Down
58 changes: 52 additions & 6 deletions openhands/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
RepoMicroAgent,
load_microagents_from_dir,
)
from openhands.utils.prompt import PromptManager
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo


class Memory:
Expand Down Expand Up @@ -43,12 +43,21 @@ def __init__(
self.repo_microagents: dict[str, RepoMicroAgent] = {}
self.knowledge_microagents: dict[str, KnowledgeMicroAgent] = {}

# Track whether we've seen the first user message
self._first_user_message_seen = False

# Store repository / runtime info to send them to the templating later
self.repository_info: RepositoryInfo | None = None
self.runtime_info: RuntimeInfo | None = None

# TODO: enable_prompt_extensions

def _load_global_microagents(self) -> None:
"""
Loads microagents from the global microagents_dir.
This is effectively what used to happen in PromptManager.
"""
repo_agents, knowledge_agents, _task_agents = load_microagents_from_dir(
repo_agents, knowledge_agents, _ = load_microagents_from_dir(
self.microagents_dir
)
for name, agent in knowledge_agents.items():
Expand All @@ -62,14 +71,51 @@ def _load_global_microagents(self) -> None:
if isinstance(agent, RepoMicroAgent):
self.repo_microagents[name] = agent

def set_repository_info(self, repo_name: str, repo_directory: str) -> None:
"""Store repository info so we can reference it in an observation."""
self.repository_info = RepositoryInfo(repo_name, repo_directory)
self.prompt_manager.set_repository_info(self.repository_info)

def set_runtime_info(self, runtime_hosts: dict[str, int]) -> None:
"""Store runtime info (web hosts, ports, etc.)."""
# e.g. { '127.0.0.1': 8080 }
self.runtime_info = RuntimeInfo(available_hosts=runtime_hosts)
self.prompt_manager.set_runtime_info(self.runtime_info)

def on_event(self, event: Event):
"""Handle an event from the event stream."""
if isinstance(event, MessageAction):
self.on_user_message_action(event)
if event.source == 'user':
# If this is the first user message, create and add a RecallObservation
# with info about repo and runtime.
if not self._first_user_message_seen:
self._first_user_message_seen = True
self._on_first_user_message(event)
# continue with the next handler, to include microagents if suitable for this user message
self._on_user_message_action(event)
elif isinstance(event, RecallAction):
self.on_recall_action(event)
self._on_recall_action(event)

def _on_first_user_message(self, event: MessageAction):
"""Create and add to the stream a RecallObservation carrying info about repo and runtime."""
# Build the same text that used to be appended to the first user message
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
for microagent in self.repo_microagents.values():
# We assume these are the repo instructions
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content

# Now wrap it in a RecallObservation, rather than altering the user message:
obs = RecallObservation(
content=self.prompt_manager.build_additional_info_text(repo_instructions)
)
self.event_stream.add_event(obs, EventSource.ENVIRONMENT)

def on_user_message_action(self, event: MessageAction):
def _on_user_message_action(self, event: MessageAction):
"""Replicates old microagent logic: if a microagent triggers on user text,
we embed it in an <extra_info> block and post a RecallObservation."""
if event.source != 'user':
Expand Down Expand Up @@ -102,7 +148,7 @@ def on_user_message_action(self, event: MessageAction):
obs, event.source if event.source else EventSource.ENVIRONMENT
)

def on_recall_action(self, event: RecallAction):
def _on_recall_action(self, event: RecallAction):
"""If a RecallAction explicitly arrives, handle it."""
assert isinstance(event, RecallAction)

Expand Down
6 changes: 2 additions & 4 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ async def _create_memory(

if agent.prompt_manager and self.runtime:
# sets available hosts
agent.prompt_manager.set_runtime_info(self.runtime)
mem.set_runtime_info(self.runtime.web_hosts)

# loads microagents from repo/.openhands/microagents
microagents: list[BaseMicroAgent] = await call_sync_from_async(
Expand All @@ -342,9 +342,7 @@ async def _create_memory(
if selected_repository:
repo_directory = selected_repository.split('/')[1]
if repo_directory:
agent.prompt_manager.set_repository_info(
selected_repository, repo_directory
)
mem.set_repository_info(selected_repository, repo_directory)
return mem

def _maybe_restore_state(self) -> State | None:
Expand Down
61 changes: 26 additions & 35 deletions openhands/utils/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from openhands.controller.state.state import State
from openhands.core.message import Message, TextContent
from openhands.microagent.microagent import RepoMicroAgent
from openhands.runtime.base import Runtime


@dataclass
Expand Down Expand Up @@ -58,7 +57,7 @@ class PromptManager:
This class is dedicated toloading and rendering prompts (system prompt, user prompt).
Attributes:
prompt_dir (str): Directory containing prompt templates.
prompt_dir: Directory containing prompt templates.
"""

def __init__(
Expand All @@ -84,23 +83,17 @@ def _load_template(self, template_name: str) -> Template:
def get_system_message(self) -> str:
return self.system_template.render().strip()

def set_runtime_info(self, runtime: Runtime) -> None:
self.runtime_info.available_hosts = runtime.web_hosts
def set_runtime_info(self, runtime_info: RuntimeInfo) -> None:
self.runtime_info = runtime_info

def set_repository_info(
self,
repo_name: str,
repo_directory: str,
) -> None:
"""Sets information about the GitHub repository that has been cloned.
def set_repository_info(self, repository_info: RepositoryInfo) -> None:
"""Stores info about a cloned repository for rendering the template.
Args:
repo_name: The name of the GitHub repository (e.g. 'owner/repo')
repo_directory: The directory where the repository has been cloned
repo_name: The name of the repository.
repo_directory: The directory of the repository.
"""
self.repository_info = RepositoryInfo(
repo_name=repo_name, repo_directory=repo_directory
)
self.repository_info = repository_info

def get_example_user_message(self) -> str:
"""This is the initial user message provided to the agent
Expand All @@ -127,31 +120,29 @@ def add_info_to_initial_message(
self,
message: Message,
) -> None:
"""Adds information about the repository and runtime to the initial user message.
Args:
message: The initial user message to add information to.
"""
repo_instructions = ''
assert (
len(self.repo_microagents) <= 1
), f'Expecting at most one repo microagent, but found {len(self.repo_microagents)}: {self.repo_microagents.keys()}'
for microagent in self.repo_microagents.values():
# We assume these are the repo instructions
if repo_instructions:
repo_instructions += '\n\n'
repo_instructions += microagent.content

additional_info = ADDITIONAL_INFO_TEMPLATE.render(
repository_instructions=repo_instructions,
Previously inserted the rendered template at the start of the user's first message.
If we've switched to using a separate RecallObservation in Memory, we can safely remove
or comment out the direct insertion code below—but we still keep the method for
scenarios where we want to read or manipulate the template output.
"""
# Old code that forcibly modified the user message:
#
# info_block = self.build_additional_info_text(repo_instructions)
# if info_block:
# message.content.insert(0, TextContent(text=info_block))
#
# Now we comment it out or remove to avoid "injecting" directly.
pass

def build_additional_info_text(self, repo_instructions: str = '') -> str:
"""Renders the ADDITIONAL_INFO_TEMPLATE with the stored repository/runtime info."""
return ADDITIONAL_INFO_TEMPLATE.render(
repository_info=self.repository_info,
repository_instructions=repo_instructions,
runtime_info=self.runtime_info,
).strip()

# Insert the new content at the start of the TextContent list
if additional_info:
message.content.insert(0, TextContent(text=additional_info))

def add_turns_left_reminder(self, messages: list[Message], state: State) -> None:
latest_user_message = next(
islice(
Expand Down

0 comments on commit d596fd2

Please sign in to comment.