Skip to content

Commit

Permalink
Refactor I/O utils; allow 'task' command line parameter in cli.py (#6187
Browse files Browse the repository at this point in the history
)

Co-authored-by: OpenHands Bot <[email protected]>
  • Loading branch information
enyst and openhands-agent authored Feb 19, 2025
1 parent 663e361 commit eed7e2d
Show file tree
Hide file tree
Showing 13 changed files with 88 additions and 70 deletions.
2 changes: 1 addition & 1 deletion openhands/agenthub/micro/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from openhands.controller.state.state import State
from openhands.core.config import AgentConfig
from openhands.core.message import ImageContent, Message, TextContent
from openhands.core.utils import json
from openhands.events.action import Action
from openhands.events.event import Event
from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.event import event_to_memory
from openhands.io import json
from openhands.llm.llm import LLM


Expand Down
36 changes: 18 additions & 18 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
CmdOutputObservation,
FileEditObservation,
)
from openhands.io import read_input, read_task


def display_message(message: str):
Expand Down Expand Up @@ -82,29 +83,21 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)


def read_input(config: AppConfig) -> str:
"""Read input from user based on config settings."""
if config.cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()


async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode."""

args = parse_arguments()

logger.setLevel(logging.WARNING)

config = setup_config_from_args(args)
# Load config from toml and override with command line arguments
config: AppConfig = setup_config_from_args(args)

# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

# If we have a task, create initial user action
initial_user_action = MessageAction(content=task_str) if task_str else None

sid = str(uuid4())

Expand All @@ -117,7 +110,9 @@ async def main(loop: asyncio.AbstractEventLoop):

async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(None, read_input, config)
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
Expand Down Expand Up @@ -162,7 +157,12 @@ def on_event(event: Event) -> None:

await runtime.connect()

asyncio.create_task(prompt_for_next_task())
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task())

await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
Expand Down
52 changes: 10 additions & 42 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol

Expand Down Expand Up @@ -29,6 +28,7 @@
from openhands.events.observation import AgentStateChangedObservation
from openhands.events.serialization import event_from_dict
from openhands.events.serialization.event import event_to_trajectory
from openhands.io import read_input, read_task
from openhands.runtime.base import Runtime


Expand All @@ -41,32 +41,6 @@ def __call__(
) -> str: ...


def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()


def read_task_from_stdin() -> str:
"""Read task from stdin."""
return sys.stdin.read()


def read_input(config: AppConfig) -> str:
"""Read input from user based on config settings."""
if config.cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()


async def run_controller(
config: AppConfig,
initial_user_action: Action,
Expand Down Expand Up @@ -139,7 +113,6 @@ async def run_controller(
assert isinstance(
initial_user_action, Action
), f'initial user actions must be an Action, got {type(initial_user_action)}'
# Logging
logger.debug(
f'Agent Controller Initialized: Running agent {agent.name}, model '
f'{agent.llm.config.model}, with actions: {initial_user_action}'
Expand Down Expand Up @@ -167,7 +140,7 @@ def on_event(event: Event):
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = read_input(config)
message = read_input(config.cli_multiline_input)
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
Expand Down Expand Up @@ -268,28 +241,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
if __name__ == '__main__':
args = parse_arguments()

config = setup_config_from_args(args)
config: AppConfig = setup_config_from_args(args)

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:

if not task_str:
raise ValueError('No task provided. Please specify a task through -t, -f.')

# Create initial user action
initial_user_action: MessageAction = MessageAction(content=task_str)

# Set session name
session_name = args.name
sid = generate_sid(config, session_name)
Expand Down
Empty file removed openhands/core/utils/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from typing import Callable, Iterable

from openhands.core.logger import openhands_logger as logger
from openhands.core.utils import json
from openhands.events.event import Event, EventSource
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.io import json
from openhands.storage import FileStore
from openhands.storage.locations import (
get_conversation_dir,
Expand Down
10 changes: 10 additions & 0 deletions openhands/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from openhands.io.io import read_input, read_task, read_task_from_file
from openhands.io.json import dumps, loads

__all__ = [
'read_input',
'read_task_from_file',
'read_task',
'dumps',
'loads',
]
40 changes: 40 additions & 0 deletions openhands/io/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import sys


def read_input(cli_multiline_input: bool = False) -> str:
"""Read input from user based on config settings."""
if cli_multiline_input:
print('Enter your message (enter "/exit" on a new line to finish):')
lines = []
while True:
line = input('>> ').rstrip()
if line == '/exit': # finish input
break
lines.append(line)
return '\n'.join(lines)
else:
return input('>> ').rstrip()


def read_task_from_file(file_path: str) -> str:
"""Read task from the specified file."""
with open(file_path, 'r', encoding='utf-8') as file:
return file.read()


def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
"""
Read the task from the CLI args, file, or stdin.
"""

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_input(cli_multiline_input)

return task_str
File renamed without changes.
4 changes: 2 additions & 2 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(
)
def wrapper(*args, **kwargs):
"""Wrapper for the litellm completion function. Logs the input and output of the completion function."""
from openhands.core.utils import json
from openhands.io import json

messages: list[dict[str, Any]] | dict[str, Any] = []
mock_function_calling = not self.is_function_calling_active()
Expand Down Expand Up @@ -369,7 +369,7 @@ def init_model_info(self):
# noinspection PyBroadException
except Exception:
pass
from openhands.core.utils import json
from openhands.io import json

logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')

Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest.mock import patch

from openhands.core.cli import read_input
from openhands.core.config import AppConfig
from openhands.io import read_input


def test_single_line_input():
Expand All @@ -10,7 +10,7 @@ def test_single_line_input():
config.cli_multiline_input = False

with patch('builtins.input', return_value='hello world'):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'hello world'


Expand All @@ -23,5 +23,5 @@ def test_multiline_input():
mock_inputs = ['line 1', 'line 2', 'line 3', '/exit']

with patch('builtins.input', side_effect=mock_inputs):
result = read_input(config)
result = read_input(config.cli_multiline_input)
assert result == 'line 1\nline 2\nline 3'
2 changes: 1 addition & 1 deletion tests/unit/test_json.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime

from openhands.core.utils import json
from openhands.events.action import MessageAction
from openhands.io import json


def test_event_serialization_deserialization():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_json_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import psutil

from openhands.core.utils.json import dumps
from openhands.io.json import dumps


def get_memory_usage():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_response_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from openhands.agenthub.micro.agent import parse_response as parse_response_micro
from openhands.core.exceptions import LLMResponseError
from openhands.core.utils.json import loads as custom_loads
from openhands.events.action import (
FileWriteAction,
MessageAction,
)
from openhands.io import loads as custom_loads


@pytest.mark.parametrize(
Expand Down

0 comments on commit eed7e2d

Please sign in to comment.