From eed7e2dd6e4b94a3e69e3f287068090978198ac4 Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Wed, 19 Feb 2025 22:10:14 +0100 Subject: [PATCH] Refactor I/O utils; allow 'task' command line parameter in cli.py (#6187) Co-authored-by: OpenHands Bot --- openhands/agenthub/micro/agent.py | 2 +- openhands/core/cli.py | 36 +++++++++---------- openhands/core/main.py | 52 ++++++---------------------- openhands/core/utils/__init__.py | 0 openhands/events/stream.py | 2 +- openhands/io/__init__.py | 10 ++++++ openhands/io/io.py | 40 +++++++++++++++++++++ openhands/{core/utils => io}/json.py | 0 openhands/llm/llm.py | 4 +-- tests/unit/test_cli.py | 6 ++-- tests/unit/test_json.py | 2 +- tests/unit/test_json_encoder.py | 2 +- tests/unit/test_response_parsing.py | 2 +- 13 files changed, 88 insertions(+), 70 deletions(-) delete mode 100644 openhands/core/utils/__init__.py create mode 100644 openhands/io/__init__.py create mode 100644 openhands/io/io.py rename openhands/{core/utils => io}/json.py (100%) diff --git a/openhands/agenthub/micro/agent.py b/openhands/agenthub/micro/agent.py index 2c22e3840a51..37de035c461d 100644 --- a/openhands/agenthub/micro/agent.py +++ b/openhands/agenthub/micro/agent.py @@ -6,11 +6,11 @@ from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.core.message import ImageContent, Message, TextContent -from openhands.core.utils import json from openhands.events.action import Action from openhands.events.event import Event from openhands.events.serialization.action import action_from_dict from openhands.events.serialization.event import event_to_memory +from openhands.io import json from openhands.llm.llm import LLM diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 34186cd5fc8a..05a390d8b815 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -30,6 +30,7 @@ CmdOutputObservation, FileEditObservation, ) +from openhands.io import read_input, read_task def display_message(message: str): @@ -82,21 +83,6 @@ def display_event(event: Event, config: AppConfig): display_confirmation(event.confirmation_state) -def read_input(config: AppConfig) -> str: - """Read input from user based on config settings.""" - if config.cli_multiline_input: - print('Enter your message (enter "/exit" on a new line to finish):') - lines = [] - while True: - line = input('>> ').rstrip() - if line == '/exit': # finish input - break - lines.append(line) - return '\n'.join(lines) - else: - return input('>> ').rstrip() - - async def main(loop: asyncio.AbstractEventLoop): """Runs the agent in CLI mode.""" @@ -104,7 +90,14 @@ async def main(loop: asyncio.AbstractEventLoop): logger.setLevel(logging.WARNING) - config = setup_config_from_args(args) + # Load config from toml and override with command line arguments + config: AppConfig = setup_config_from_args(args) + + # Read task from file, CLI args, or stdin + task_str = read_task(args, config.cli_multiline_input) + + # If we have a task, create initial user action + initial_user_action = MessageAction(content=task_str) if task_str else None sid = str(uuid4()) @@ -117,7 +110,9 @@ async def main(loop: asyncio.AbstractEventLoop): async def prompt_for_next_task(): # Run input() in a thread pool to avoid blocking the event loop - next_message = await loop.run_in_executor(None, read_input, config) + next_message = await loop.run_in_executor( + None, read_input, config.cli_multiline_input + ) if not next_message.strip(): await prompt_for_next_task() if next_message == 'exit': @@ -162,7 +157,12 @@ def on_event(event: Event) -> None: await runtime.connect() - asyncio.create_task(prompt_for_next_task()) + if initial_user_action: + # If there's an initial user action, enqueue it and do not prompt again + event_stream.add_event(initial_user_action, EventSource.USER) + else: + # Otherwise prompt for the user's first message right away + asyncio.create_task(prompt_for_next_task()) await run_agent_until_done( controller, runtime, [AgentState.STOPPED, AgentState.ERROR] diff --git a/openhands/core/main.py b/openhands/core/main.py index 2652931cce7a..12e0c4e7876c 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -1,7 +1,6 @@ import asyncio import json import os -import sys from pathlib import Path from typing import Callable, Protocol @@ -29,6 +28,7 @@ from openhands.events.observation import AgentStateChangedObservation from openhands.events.serialization import event_from_dict from openhands.events.serialization.event import event_to_trajectory +from openhands.io import read_input, read_task from openhands.runtime.base import Runtime @@ -41,32 +41,6 @@ def __call__( ) -> str: ... -def read_task_from_file(file_path: str) -> str: - """Read task from the specified file.""" - with open(file_path, 'r', encoding='utf-8') as file: - return file.read() - - -def read_task_from_stdin() -> str: - """Read task from stdin.""" - return sys.stdin.read() - - -def read_input(config: AppConfig) -> str: - """Read input from user based on config settings.""" - if config.cli_multiline_input: - print('Enter your message (enter "/exit" on a new line to finish):') - lines = [] - while True: - line = input('>> ').rstrip() - if line == '/exit': # finish input - break - lines.append(line) - return '\n'.join(lines) - else: - return input('>> ').rstrip() - - async def run_controller( config: AppConfig, initial_user_action: Action, @@ -139,7 +113,6 @@ async def run_controller( assert isinstance( initial_user_action, Action ), f'initial user actions must be an Action, got {type(initial_user_action)}' - # Logging logger.debug( f'Agent Controller Initialized: Running agent {agent.name}, model ' f'{agent.llm.config.model}, with actions: {initial_user_action}' @@ -167,7 +140,7 @@ def on_event(event: Event): if exit_on_message: message = '/exit' elif fake_user_response_fn is None: - message = read_input(config) + message = read_input(config.cli_multiline_input) else: message = fake_user_response_fn(controller.get_state()) action = MessageAction(content=message) @@ -268,28 +241,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]: if __name__ == '__main__': args = parse_arguments() - config = setup_config_from_args(args) + config: AppConfig = setup_config_from_args(args) - # Determine the task - task_str = '' - if args.file: - task_str = read_task_from_file(args.file) - elif args.task: - task_str = args.task - elif not sys.stdin.isatty(): - task_str = read_task_from_stdin() + # Read task from file, CLI args, or stdin + task_str = read_task(args, config.cli_multiline_input) - initial_user_action: Action = NullAction() if config.replay_trajectory_path: if task_str: raise ValueError( 'User-specified task is not supported under trajectory replay mode' ) - elif task_str: - initial_user_action = MessageAction(content=task_str) - else: + + if not task_str: raise ValueError('No task provided. Please specify a task through -t, -f.') + # Create initial user action + initial_user_action: MessageAction = MessageAction(content=task_str) + # Set session name session_name = args.name sid = generate_sid(config, session_name) diff --git a/openhands/core/utils/__init__.py b/openhands/core/utils/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 0fc547803f6d..938269822a7a 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -8,9 +8,9 @@ from typing import Callable, Iterable from openhands.core.logger import openhands_logger as logger -from openhands.core.utils import json from openhands.events.event import Event, EventSource from openhands.events.serialization.event import event_from_dict, event_to_dict +from openhands.io import json from openhands.storage import FileStore from openhands.storage.locations import ( get_conversation_dir, diff --git a/openhands/io/__init__.py b/openhands/io/__init__.py new file mode 100644 index 000000000000..bf1a054356c1 --- /dev/null +++ b/openhands/io/__init__.py @@ -0,0 +1,10 @@ +from openhands.io.io import read_input, read_task, read_task_from_file +from openhands.io.json import dumps, loads + +__all__ = [ + 'read_input', + 'read_task_from_file', + 'read_task', + 'dumps', + 'loads', +] diff --git a/openhands/io/io.py b/openhands/io/io.py new file mode 100644 index 000000000000..2e42df912b77 --- /dev/null +++ b/openhands/io/io.py @@ -0,0 +1,40 @@ +import argparse +import sys + + +def read_input(cli_multiline_input: bool = False) -> str: + """Read input from user based on config settings.""" + if cli_multiline_input: + print('Enter your message (enter "/exit" on a new line to finish):') + lines = [] + while True: + line = input('>> ').rstrip() + if line == '/exit': # finish input + break + lines.append(line) + return '\n'.join(lines) + else: + return input('>> ').rstrip() + + +def read_task_from_file(file_path: str) -> str: + """Read task from the specified file.""" + with open(file_path, 'r', encoding='utf-8') as file: + return file.read() + + +def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str: + """ + Read the task from the CLI args, file, or stdin. + """ + + # Determine the task + task_str = '' + if args.file: + task_str = read_task_from_file(args.file) + elif args.task: + task_str = args.task + elif not sys.stdin.isatty(): + task_str = read_input(cli_multiline_input) + + return task_str diff --git a/openhands/core/utils/json.py b/openhands/io/json.py similarity index 100% rename from openhands/core/utils/json.py rename to openhands/io/json.py diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index a9071b43bed3..b40f11ca8396 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -172,7 +172,7 @@ def __init__( ) def wrapper(*args, **kwargs): """Wrapper for the litellm completion function. Logs the input and output of the completion function.""" - from openhands.core.utils import json + from openhands.io import json messages: list[dict[str, Any]] | dict[str, Any] = [] mock_function_calling = not self.is_function_calling_active() @@ -369,7 +369,7 @@ def init_model_info(self): # noinspection PyBroadException except Exception: pass - from openhands.core.utils import json + from openhands.io import json logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}') diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 520d85d2aa7d..3931f2fdd713 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -1,7 +1,7 @@ from unittest.mock import patch -from openhands.core.cli import read_input from openhands.core.config import AppConfig +from openhands.io import read_input def test_single_line_input(): @@ -10,7 +10,7 @@ def test_single_line_input(): config.cli_multiline_input = False with patch('builtins.input', return_value='hello world'): - result = read_input(config) + result = read_input(config.cli_multiline_input) assert result == 'hello world' @@ -23,5 +23,5 @@ def test_multiline_input(): mock_inputs = ['line 1', 'line 2', 'line 3', '/exit'] with patch('builtins.input', side_effect=mock_inputs): - result = read_input(config) + result = read_input(config.cli_multiline_input) assert result == 'line 1\nline 2\nline 3' diff --git a/tests/unit/test_json.py b/tests/unit/test_json.py index 883efdfe4cfb..85ab265a536d 100644 --- a/tests/unit/test_json.py +++ b/tests/unit/test_json.py @@ -1,7 +1,7 @@ from datetime import datetime -from openhands.core.utils import json from openhands.events.action import MessageAction +from openhands.io import json def test_event_serialization_deserialization(): diff --git a/tests/unit/test_json_encoder.py b/tests/unit/test_json_encoder.py index daa2708a6256..10058c8c2ba3 100644 --- a/tests/unit/test_json_encoder.py +++ b/tests/unit/test_json_encoder.py @@ -3,7 +3,7 @@ import psutil -from openhands.core.utils.json import dumps +from openhands.io.json import dumps def get_memory_usage(): diff --git a/tests/unit/test_response_parsing.py b/tests/unit/test_response_parsing.py index fd588d4c6edf..dc51dee3abe4 100644 --- a/tests/unit/test_response_parsing.py +++ b/tests/unit/test_response_parsing.py @@ -2,11 +2,11 @@ from openhands.agenthub.micro.agent import parse_response as parse_response_micro from openhands.core.exceptions import LLMResponseError -from openhands.core.utils.json import loads as custom_loads from openhands.events.action import ( FileWriteAction, MessageAction, ) +from openhands.io import loads as custom_loads @pytest.mark.parametrize(