From bd272beeb901e12550161e163af11c5a610b5350 Mon Sep 17 00:00:00 2001 From: Boxuan Li Date: Sun, 12 Jan 2025 22:15:11 -0800 Subject: [PATCH] Fix bugs --- openhands/controller/agent_controller.py | 18 +++++-- openhands/core/main.py | 65 ++++++++++++++++++++++-- openhands/core/setup.py | 42 ++------------- openhands/events/event.py | 4 +- openhands/events/observation/browse.py | 2 +- 5 files changed, 82 insertions(+), 49 deletions(-) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index d5e384c661dc..1fae77e71498 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -244,13 +244,20 @@ def should_replay(self) -> bool: if not self.replay_logs: # not in replay mode return False - while self.replay_index < len(self.replay_logs) and not isinstance( - self.replay_logs[self.replay_index], Action + + def replayable(index: int) -> bool: + return ( + self.replay_logs is not None + and index < len(self.replay_logs) + and isinstance(self.replay_logs[index], Action) + and self.replay_logs[index].source != EventSource.USER + ) + + while self.replay_index < len(self.replay_logs) and not replayable( + self.replay_index ): self.replay_index += 1 - return self.replay_index < len(self.replay_logs) and isinstance( - self.replay_logs[self.replay_index], Action - ) + return replayable(self.replay_index) def should_step(self, event: Event) -> bool: """ @@ -591,6 +598,7 @@ async def _step(self) -> None: event = self.replay_logs[self.replay_index] assert isinstance(event, Action) action = event + self.replay_index += 1 else: try: action = self.agent.step(self.state) diff --git a/openhands/core/main.py b/openhands/core/main.py index b27cac1e586d..65c00fb7e3db 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -2,6 +2,7 @@ import json import os import sys +from pathlib import Path from typing import Callable, Protocol import openhands.agenthub # noqa F401 (we import this to get the agents registered) @@ -22,10 +23,11 @@ generate_sid, ) from openhands.events import EventSource, EventStreamSubscriber -from openhands.events.action import MessageAction +from openhands.events.action import MessageAction, NullAction from openhands.events.action.action import Action from openhands.events.event import Event from openhands.events.observation import AgentStateChangedObservation +from openhands.events.serialization import event_from_dict from openhands.events.serialization.event import event_to_trajectory from openhands.runtime.base import Runtime @@ -101,7 +103,17 @@ async def run_controller( if agent is None: agent = create_agent(runtime, config) - controller, initial_state = create_controller(agent, runtime, config) + replay_logs: list[Event] | None = None + if config.replay_trajectory_path: + logger.info('Trajectory replay is enabled') + assert isinstance(initial_user_action, NullAction) + replay_logs, initial_user_action = load_replay_log( + config.replay_trajectory_path + ) + + controller, initial_state = create_controller( + agent, runtime, config, replay_logs=replay_logs + ) assert isinstance( initial_user_action, Action @@ -194,21 +206,64 @@ def auto_continue_response( return message +def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]: + """ + Load trajectory from given path, serialize it to a list of events, and return + two things: + 1) A list of events except the first action + 2) First action (user message, a.k.a. initial task) + """ + try: + path = Path(trajectory_path).resolve() + + if not path.exists(): + raise ValueError(f'Trajectory file not found: {path}') + + if not path.is_file(): + raise ValueError(f'Trajectory path is a directory, not a file: {path}') + + with open(path, 'r', encoding='utf-8') as file: + data = json.load(file) + if not isinstance(data, list): + raise ValueError( + f'Expected a list in {path}, got {type(data).__name__}' + ) + events = [] + for item in data: + event = event_from_dict(item) + # cannot add an event with _id to event stream + event._id = None # type: ignore[attr-defined] + events.append(event) + assert isinstance(events[0], MessageAction) + return events[1:], events[0] + except json.JSONDecodeError as e: + raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}') + + if __name__ == '__main__': args = parse_arguments() + config = setup_config_from_args(args) + # Determine the task + task_str = '' if args.file: task_str = read_task_from_file(args.file) elif args.task: task_str = args.task elif not sys.stdin.isatty(): task_str = read_task_from_stdin() + + initial_user_action: Action = NullAction() + if config.replay_trajectory_path: + if task_str: + raise ValueError( + 'User-specified task is not supported under trajectory replay mode' + ) + elif task_str: + initial_user_action = MessageAction(content=task_str) else: raise ValueError('No task provided. Please specify a task through -t, -f.') - initial_user_action: MessageAction = MessageAction(content=task_str) - - config = setup_config_from_args(args) # Set session name session_name = args.name diff --git a/openhands/core/setup.py b/openhands/core/setup.py index 50fd73a10e1a..007998782b56 100644 --- a/openhands/core/setup.py +++ b/openhands/core/setup.py @@ -1,7 +1,5 @@ import hashlib -import json import uuid -from pathlib import Path from typing import Tuple, Type import openhands.agenthub # noqa F401 (we import this to get the agents registered) @@ -14,7 +12,6 @@ from openhands.core.logger import openhands_logger as logger from openhands.events import EventStream from openhands.events.event import Event -from openhands.events.serialization import event_from_dict from openhands.llm.llm import LLM from openhands.runtime import get_runtime_cls from openhands.runtime.base import Runtime @@ -82,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent: def create_controller( - agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True + agent: Agent, + runtime: Runtime, + config: AppConfig, + headless_mode: bool = True, + replay_logs: list[Event] | None = None, ) -> Tuple[AgentController, State | None]: event_stream = runtime.event_stream initial_state = None @@ -96,10 +97,6 @@ def create_controller( except Exception as e: logger.debug(f'Cannot restore agent state: {e}') - replay_logs: list[Event] | None = None - if config.replay_trajectory_path: - replay_logs = load_replay_log(config.replay_trajectory_path) - controller = AgentController( agent=agent, max_iterations=config.max_iterations, @@ -114,35 +111,6 @@ def create_controller( return (controller, initial_state) -def load_replay_log(trajectory_path: str) -> list[Event] | None: - try: - path = Path(trajectory_path).resolve() - - if not path.exists(): - logger.error(f'Trajectory file not found: {path}') - return None - - if not path.is_file(): - logger.error(f'Trajectory path is a directory, not a file: {path}') - return None - - with open(path, 'r', encoding='utf-8') as file: - data = json.load(file) - if not isinstance(data, list): - logger.error(f'Expected a list in {path}, got {type(data).__name__}') - return None - return [event_from_dict(item) for item in data] - - except json.JSONDecodeError as e: - logger.error(f'Invalid JSON format in {trajectory_path}: {e}') - except ValueError as e: - logger.error(f'Invalid Event in {trajectory_path}: {e}') - except Exception as e: - logger.error(f'Unexpected error loading {trajectory_path}: {e}') - - return None - - def generate_sid(config: AppConfig, session_name: str | None = None) -> str: """Generate a session id based on the session name and the jwt secret.""" session_name = session_name or str(uuid.uuid4()) diff --git a/openhands/events/event.py b/openhands/events/event.py index 6c7a2d8a3ac1..1859b8ee80da 100644 --- a/openhands/events/event.py +++ b/openhands/events/event.py @@ -24,6 +24,8 @@ class FileReadSource(str, Enum): @dataclass class Event: + INVALID_ID = -1 + @property def message(self) -> str | None: if hasattr(self, '_message'): @@ -34,7 +36,7 @@ def message(self) -> str | None: def id(self) -> int: if hasattr(self, '_id'): return self._id # type: ignore[attr-defined] - return -1 + return Event.INVALID_ID @property def timestamp(self): diff --git a/openhands/events/observation/browse.py b/openhands/events/observation/browse.py index 1052aaf17a91..c0bf2e643db6 100644 --- a/openhands/events/observation/browse.py +++ b/openhands/events/observation/browse.py @@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation): url: str trigger_by_action: str - screenshot: str = field(repr=False) # don't show in repr + screenshot: str = field(repr=False, default='') # don't show in repr error: bool = False observation: str = ObservationType.BROWSE # do not include in the memory