Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
li-boxuan committed Jan 13, 2025
1 parent 113f88f commit bd272be
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 49 deletions.
18 changes: 13 additions & 5 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,13 +244,20 @@ def should_replay(self) -> bool:
if not self.replay_logs:
# not in replay mode
return False
while self.replay_index < len(self.replay_logs) and not isinstance(
self.replay_logs[self.replay_index], Action

def replayable(index: int) -> bool:
return (
self.replay_logs is not None
and index < len(self.replay_logs)
and isinstance(self.replay_logs[index], Action)
and self.replay_logs[index].source != EventSource.USER
)

while self.replay_index < len(self.replay_logs) and not replayable(
self.replay_index
):
self.replay_index += 1
return self.replay_index < len(self.replay_logs) and isinstance(
self.replay_logs[self.replay_index], Action
)
return replayable(self.replay_index)

def should_step(self, event: Event) -> bool:
"""
Expand Down Expand Up @@ -591,6 +598,7 @@ async def _step(self) -> None:
event = self.replay_logs[self.replay_index]
assert isinstance(event, Action)
action = event
self.replay_index += 1
else:
try:
action = self.agent.step(self.state)
Expand Down
65 changes: 60 additions & 5 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
Expand All @@ -22,10 +23,11 @@
generate_sid,
)
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import MessageAction
from openhands.events.action import MessageAction, NullAction
from openhands.events.action.action import Action
from openhands.events.event import Event
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.runtime.base import Runtime

Expand Down Expand Up @@ -101,7 +103,17 @@ async def run_controller(
if agent is None:
agent = create_agent(runtime, config)

controller, initial_state = create_controller(agent, runtime, config)
replay_logs: list[Event] | None = None
if config.replay_trajectory_path:
logger.info('Trajectory replay is enabled')
assert isinstance(initial_user_action, NullAction)
replay_logs, initial_user_action = load_replay_log(
config.replay_trajectory_path
)

controller, initial_state = create_controller(
agent, runtime, config, replay_logs=replay_logs
)

assert isinstance(
initial_user_action, Action
Expand Down Expand Up @@ -194,21 +206,64 @@ def auto_continue_response(
return message


def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
"""
Load trajectory from given path, serialize it to a list of events, and return
two things:
1) A list of events except the first action
2) First action (user message, a.k.a. initial task)
"""
try:
path = Path(trajectory_path).resolve()

if not path.exists():
raise ValueError(f'Trajectory file not found: {path}')

if not path.is_file():
raise ValueError(f'Trajectory path is a directory, not a file: {path}')

with open(path, 'r', encoding='utf-8') as file:
data = json.load(file)
if not isinstance(data, list):
raise ValueError(
f'Expected a list in {path}, got {type(data).__name__}'
)
events = []
for item in data:
event = event_from_dict(item)
# cannot add an event with _id to event stream
event._id = None # type: ignore[attr-defined]
events.append(event)
assert isinstance(events[0], MessageAction)
return events[1:], events[0]
except json.JSONDecodeError as e:
raise ValueError(f'Invalid JSON format in {trajectory_path}: {e}')


if __name__ == '__main__':
args = parse_arguments()

config = setup_config_from_args(args)

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()

initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:
raise ValueError('No task provided. Please specify a task through -t, -f.')
initial_user_action: MessageAction = MessageAction(content=task_str)

config = setup_config_from_args(args)

# Set session name
session_name = args.name
Expand Down
42 changes: 5 additions & 37 deletions openhands/core/setup.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import hashlib
import json
import uuid
from pathlib import Path
from typing import Tuple, Type

import openhands.agenthub # noqa F401 (we import this to get the agents registered)
Expand All @@ -14,7 +12,6 @@
from openhands.core.logger import openhands_logger as logger
from openhands.events import EventStream
from openhands.events.event import Event
from openhands.events.serialization import event_from_dict
from openhands.llm.llm import LLM
from openhands.runtime import get_runtime_cls
from openhands.runtime.base import Runtime
Expand Down Expand Up @@ -82,7 +79,11 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:


def create_controller(
agent: Agent, runtime: Runtime, config: AppConfig, headless_mode: bool = True
agent: Agent,
runtime: Runtime,
config: AppConfig,
headless_mode: bool = True,
replay_logs: list[Event] | None = None,
) -> Tuple[AgentController, State | None]:
event_stream = runtime.event_stream
initial_state = None
Expand All @@ -96,10 +97,6 @@ def create_controller(
except Exception as e:
logger.debug(f'Cannot restore agent state: {e}')

replay_logs: list[Event] | None = None
if config.replay_trajectory_path:
replay_logs = load_replay_log(config.replay_trajectory_path)

controller = AgentController(
agent=agent,
max_iterations=config.max_iterations,
Expand All @@ -114,35 +111,6 @@ def create_controller(
return (controller, initial_state)


def load_replay_log(trajectory_path: str) -> list[Event] | None:
try:
path = Path(trajectory_path).resolve()

if not path.exists():
logger.error(f'Trajectory file not found: {path}')
return None

if not path.is_file():
logger.error(f'Trajectory path is a directory, not a file: {path}')
return None

with open(path, 'r', encoding='utf-8') as file:
data = json.load(file)
if not isinstance(data, list):
logger.error(f'Expected a list in {path}, got {type(data).__name__}')
return None
return [event_from_dict(item) for item in data]

except json.JSONDecodeError as e:
logger.error(f'Invalid JSON format in {trajectory_path}: {e}')
except ValueError as e:
logger.error(f'Invalid Event in {trajectory_path}: {e}')
except Exception as e:
logger.error(f'Unexpected error loading {trajectory_path}: {e}')

return None


def generate_sid(config: AppConfig, session_name: str | None = None) -> str:
"""Generate a session id based on the session name and the jwt secret."""
session_name = session_name or str(uuid.uuid4())
Expand Down
4 changes: 3 additions & 1 deletion openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class FileReadSource(str, Enum):

@dataclass
class Event:
INVALID_ID = -1

@property
def message(self) -> str | None:
if hasattr(self, '_message'):
Expand All @@ -34,7 +36,7 @@ def message(self) -> str | None:
def id(self) -> int:
if hasattr(self, '_id'):
return self._id # type: ignore[attr-defined]
return -1
return Event.INVALID_ID

@property
def timestamp(self):
Expand Down
2 changes: 1 addition & 1 deletion openhands/events/observation/browse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BrowserOutputObservation(Observation):

url: str
trigger_by_action: str
screenshot: str = field(repr=False) # don't show in repr
screenshot: str = field(repr=False, default='') # don't show in repr
error: bool = False
observation: str = ObservationType.BROWSE
# do not include in the memory
Expand Down

0 comments on commit bd272be

Please sign in to comment.