Skip to content

Commit

Permalink
Trajectory replay: Fix a few corner cases (#6380)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-boxuan authored Feb 2, 2025
1 parent 62402cd commit e487008
Show file tree
Hide file tree
Showing 6 changed files with 863 additions and 4 deletions.
3 changes: 2 additions & 1 deletion openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ async def _step(self) -> None:
action = self.agent.step(self.state)
if action is None:
raise LLMNoActionError('No action was returned')
action._source = EventSource.AGENT # type: ignore [attr-defined]
except (
LLMMalformedActionError,
LLMNoActionError,
Expand Down Expand Up @@ -720,7 +721,7 @@ async def _step(self) -> None:
== ActionConfirmationStatus.AWAITING_CONFIRMATION
):
await self.set_agent_state_to(AgentState.AWAITING_USER_CONFIRMATION)
self.event_stream.add_event(action, EventSource.AGENT)
self.event_stream.add_event(action, action._source) # type: ignore [attr-defined]

await self.update_state_after_step()

Expand Down
29 changes: 26 additions & 3 deletions openhands/controller/replay.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.action import Action
from openhands.events.action.message import MessageAction
from openhands.events.event import Event, EventSource
from openhands.events.observation.empty import NullObservation


class ReplayManager:
Expand All @@ -15,9 +17,31 @@ class ReplayManager:
initial state of the trajectory.
"""

def __init__(self, replay_events: list[Event] | None):
def __init__(self, events: list[Event] | None):
replay_events = []
for event in events or []:
if event.source == EventSource.ENVIRONMENT:
# ignore ENVIRONMENT events as they are not issued by
# the user or agent, and should not be replayed
continue
if isinstance(event, NullObservation):
# ignore NullObservation
continue
replay_events.append(event)

if replay_events:
logger.info(f'Replay logs loaded, events length = {len(replay_events)}')
logger.info(f'Replay events loaded, events length = {len(replay_events)}')
for index in range(len(replay_events) - 1):
event = replay_events[index]
if isinstance(event, MessageAction) and event.wait_for_response:
# For any message waiting for response that is not the last
# event, we override wait_for_response to False, as a response
# would have been included in the next event, and we don't
# want the user to interfere with the replay process
logger.info(
'Replay events contains wait_for_response message action, ignoring wait_for_response'
)
event.wait_for_response = False
self.replay_events = replay_events
self.replay_mode = bool(replay_events)
self.replay_index = 0
Expand All @@ -27,7 +51,6 @@ def _replayable(self) -> bool:
self.replay_events is not None
and self.replay_index < len(self.replay_events)
and isinstance(self.replay_events[self.replay_index], Action)
and self.replay_events[self.replay_index].source != EventSource.USER
)

def should_replay(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
events = []
for item in data:
event = event_from_dict(item)
if event.source == EventSource.ENVIRONMENT:
# ignore ENVIRONMENT events as they are not issued by
# the user or agent, and should not be replayed
continue
# cannot add an event with _id to event stream
event._id = None # type: ignore[attr-defined]
events.append(event)
Expand Down
72 changes: 72 additions & 0 deletions tests/runtime/test_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from openhands.core.main import run_controller
from openhands.core.schema.agent import AgentState
from openhands.events.action.empty import NullAction
from openhands.events.action.message import MessageAction
from openhands.events.event import EventSource
from openhands.events.observation.commands import CmdOutputObservation


Expand Down Expand Up @@ -46,6 +48,36 @@ def test_simple_replay(temp_dir, runtime_cls, run_as_openhands):
_close_test_runtime(runtime)


def test_simple_gui_replay(temp_dir, runtime_cls, run_as_openhands):
"""
A simple replay test that involves simple terminal operations and edits
(writing a Vue.js App), using the default agent
Note:
1. This trajectory is exported from GUI mode, meaning it has extra
environmental actions that don't appear in headless mode's trajectories
2. In GUI mode, agents typically don't finish; rather, they wait for the next
task from the user, so this exported trajectory ends with awaiting_user_input
"""
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)

config = _get_config('basic_gui_mode')

state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=NullAction(),
runtime=runtime,
# exit on message, otherwise this would be stuck on waiting for user input
exit_on_message=True,
)
)

assert state.agent_state == AgentState.FINISHED

_close_test_runtime(runtime)


def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
"""
Replay requires a consistent initial state to start with, otherwise it might
Expand Down Expand Up @@ -78,3 +110,43 @@ def test_replay_wrong_initial_state(temp_dir, runtime_cls, run_as_openhands):
assert has_error_in_action

_close_test_runtime(runtime)


def test_replay_basic_interactions(temp_dir, runtime_cls, run_as_openhands):
"""
Replay a trajectory that involves interactions, i.e. with user messages
in the middle. This tests two things:
1) The controller should be able to replay all actions without human
interference (no asking for user input).
2) The user messages in the trajectory should appear in the history.
"""
runtime = _load_runtime(temp_dir, runtime_cls, run_as_openhands)

config = _get_config('basic_interactions')

state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=NullAction(),
runtime=runtime,
)
)

assert state.agent_state == AgentState.FINISHED

# all user messages appear in the history, so that after a replay (assuming
# the trajectory doesn't end with `finish` action), LLM knows about all the
# context and can continue
user_messages = [
"what's 1+1?",
"No, I mean by Goldbach's conjecture!",
'Finish please',
]
i = 0
for event in state.history:
if isinstance(event, MessageAction) and event._source == EventSource.USER:
assert event.message == user_messages[i]
i += 1
assert i == len(user_messages)

_close_test_runtime(runtime)
Loading

0 comments on commit e487008

Please sign in to comment.