Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trajectory replay: Fix a few corner cases #6380

Merged
merged 8 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ async def _step(self) -> None:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
action._source = EventSource.AGENT # type: ignore [attr-defined]
except (
LLMMalformedActionError,
LLMNoActionError,
Expand Down Expand Up @@ -720,7 +721,7 @@ async def _step(self) -> None:
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
self.event_stream.add_event(action, EventSource.AGENT)
self.event_stream.add_event(action, action._source) # type: ignore [attr-defined]

await self.update_state_after_step()

Expand Down
29 changes: 26 additions & 3 deletions openhands/controller/replay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.empty import NullObservation


class ReplayManager:
Expand All @@ -15,9 +17,31 @@ class ReplayManager:
initial state of the trajectory.
"""

def __init__(self, replay_events: list[Event] | None):
def __init__(self, events: list[Event] | None):
replay_events = []
for event in events or []:
if event.source == EventSource.ENVIRONMENT:
# ignore ENVIRONMENT events as they are not issued by
# the user or agent, and should not be replayed
continue
if isinstance(event, NullObservation):
# ignore NullObservation
continue
replay_events.append(event)

if replay_events:
logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
logger.info(f'Replay events loaded, events length = {len(replay_events)}')
for index in range(len(replay_events) - 1):
event = replay_events[index]
if isinstance(event, MessageAction) and event.wait_for_response:
# For any message waiting for response that is not the last
# event, we override wait_for_response to False, as a response
# would have been included in the next event, and we don't
# want the user to interfere with the replay process
logger.info(
'Replay events contains wait_for_response message action, ignoring wait_for_response'
)
event.wait_for_response = False
self.replay_events = replay_events
self.replay_mode = bool(replay_events)
self.replay_index = 0
Expand All @@ -27,7 +51,6 @@ def _replayable(self) -> bool:
self.replay_events is not None
and self.replay_index < len(self.replay_events)
and isinstance(self.replay_events[self.replay_index], Action)
and self.replay_events[self.replay_index].source != EventSource.USER
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(not a 'review' comment, and this is cool anyway)
Just the other day, I was playing with Gemini-2.0-thinking, and it's been a lot of fun for coding-adjacent tasks! Among other things, it explored a lot of the openhands repo, tracked down every occurrence of oh_action, and followed the execution flow upstream into the frontend and downstream into the backend, until it had figured out everything about them. It makes itself mini-plans on the fly and follows up on them, very cool!

Anyway, in the server, all of those are set with source USER, but they're quite different, e.g. agent change actions, prompt confirmations, CmdRunActions (run by the user in the terminal), and MessageActions. I think none of them should be a problem, and cmd run actions are good for replay! We do want to replay those if we want to achieve a similar state (hopefully), and of course, they'd be in context.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think I added this check as a hack at the beginning - probably just to work around the wait_for_confirmation thing. It has become increasingly clear that source USER events should be replayed too.

)

def should_replay(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
events = []
for item in data:
event = event_from_dict(item)
if event.source == EventSource.ENVIRONMENT:
# ignore ENVIRONMENT events as they are not issued by
# the user or agent, and should not be replayed
continue
# cannot add an event with _id to event stream
event._id = None # type: ignore[attr-defined]
events.append(event)
Expand Down
72 changes: 72 additions & 0 deletions tests/runtime/test_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.commands import CmdOutputObservation


Expand Down Expand Up @@ -46,6 +48,36 @@ def test_simple_replay(temp_dir, runtime_cls, run_as_openhands):
_close_test_runtime(runtime)


def test_simple_gui_replay(temp_dir, runtime_cls, run_as_openhands):
    """
    A simple replay test that involves simple terminal operations and edits
    (writing a Vue.js App), using the default agent

    Note:
    1. This trajectory is exported from GUI mode, meaning it has extra
       environmental actions that don't appear in headless mode's trajectories
    2. In GUI mode, agents typically don't finish; rather, they wait for the next
       task from the user, so this exported trajectory ends with awaiting_user_input
    """
    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)

    # Always release the runtime, even when the controller raises or an
    # assertion fails, so a failing test does not leak it.
    try:
        config = _get_config('basic_gui_mode')

        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=NullAction(),
                runtime=runtime,
                # exit on message, otherwise this would be stuck on waiting for user input
                exit_on_message=True,
            )
        )

        # The annotation allows None; fail with a clear message rather than an
        # AttributeError on the attribute access below.
        assert state is not None, 'run_controller returned no final state'
        assert state.agent_state == AgentState.FINISHED
    finally:
        _close_test_runtime(runtime)


def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
"""
Replay requires a consistent initial state to start with, otherwise it might
Expand Down Expand Up @@ -78,3 +110,43 @@ def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
assert has_error_in_action

_close_test_runtime(runtime)


def test_replay_basic_interactions(temp_dir, runtime_cls, run_as_openhands):
    """
    Replay a trajectory that involves interactions, i.e. with user messages
    in the middle. This tests two things:
    1) The controller should be able to replay all actions without human
       interference (no asking for user input).
    2) The user messages in the trajectory should appear in the history.
    """
    runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)

    # Always release the runtime, even when the controller raises or an
    # assertion fails, so a failing test does not leak it.
    try:
        config = _get_config('basic_interactions')

        state: State | None = asyncio.run(
            run_controller(
                config=config,
                initial_user_action=NullAction(),
                runtime=runtime,
            )
        )

        # The annotation allows None; fail with a clear message rather than an
        # AttributeError on the attribute access below.
        assert state is not None, 'run_controller returned no final state'
        assert state.agent_state == AgentState.FINISHED

        # all user messages appear in the history, so that after a replay (assuming
        # the trajectory doesn't end with `finish` action), LLM knows about all the
        # context and can continue
        expected_user_messages = [
            "what's 1+1?",
            "No, I mean by Goldbach's conjecture!",
            'Finish please',
        ]
        replayed_user_messages = [
            event.message
            for event in state.history
            if isinstance(event, MessageAction) and event.source == EventSource.USER
        ]
        # Comparing whole lists checks content, order and count at once, and
        # reports extra or missing messages as a readable diff instead of an
        # IndexError from manual indexing.
        assert replayed_user_messages == expected_user_messages
    finally:
        _close_test_runtime(runtime)
Loading