Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor I/O utils; allow 'task' command line parameter in cli.py #6187

Merged
merged 24 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from openhands.core.loop import run_agent_until_done
from openhands.core.schema import AgentState
from openhands.core.setup import create_agent, create_controller, create_runtime
from openhands.core.utils.io import read_input, read_task
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import (
Action,
Expand Down Expand Up @@ -83,29 +84,21 @@ def display_event(event: Event, config: AppConfig):
display_confirmation(event.confirmation_state)


def read_input(config: AppConfig) -> str:
    """Prompt on stdin and return the user's message.

    When ``config.cli_multiline_input`` is truthy, lines are collected until
    one equals ``/exit`` (after trailing whitespace is removed) and returned
    joined by newlines; the terminator itself is not included. Otherwise a
    single line is read. Trailing whitespace is stripped from every line.
    """
    # Single-line mode: one prompt, one answer.
    if not config.cli_multiline_input:
        return input('>> ').rstrip()

    print('Enter your message (enter "/exit" on a new line to finish):')
    collected = []
    # Keep prompting until the user types the '/exit' terminator.
    while (entry := input('>> ').rstrip()) != '/exit':
        collected.append(entry)
    return '\n'.join(collected)


async def main(loop: asyncio.AbstractEventLoop):
"""Runs the agent in CLI mode"""

args = parse_arguments()

logger.setLevel(logging.WARNING)

config = setup_config_from_args(args)
# Load config from toml and override with command line arguments
config: AppConfig = setup_config_from_args(args)

# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

# If we have a task, create initial user action
initial_user_action = MessageAction(content=task_str) if task_str else None

sid = str(uuid4())

Expand All @@ -118,7 +111,9 @@ async def main(loop: asyncio.AbstractEventLoop):

async def prompt_for_next_task():
# Run input() in a thread pool to avoid blocking the event loop
next_message = await loop.run_in_executor(None, read_input, config)
next_message = await loop.run_in_executor(
None, read_input, config.cli_multiline_input
)
if not next_message.strip():
await prompt_for_next_task()
if next_message == 'exit':
Expand Down Expand Up @@ -164,7 +159,12 @@ def on_event(event: Event) -> None:

await runtime.connect()

asyncio.create_task(prompt_for_next_task())
if initial_user_action:
# If there's an initial user action, enqueue it and do not prompt again
event_stream.add_event(initial_user_action, EventSource.USER)
else:
# Otherwise prompt for the user's first message right away
asyncio.create_task(prompt_for_next_task())

await run_agent_until_done(
controller, runtime, [AgentState.STOPPED, AgentState.ERROR]
Expand Down
51 changes: 10 additions & 41 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import os
import sys
from pathlib import Path
from typing import Callable, Protocol

Expand All @@ -22,6 +21,7 @@
create_runtime,
generate_sid,
)
from openhands.core.utils.io import read_input, read_task
from openhands.events import EventSource, EventStreamSubscriber
from openhands.events.action import MessageAction, NullAction
from openhands.events.action.action import Action
Expand All @@ -41,32 +41,6 @@ def __call__(
) -> str: ...


def read_task_from_file(file_path: str) -> str:
    """Return the full text of the file at ``file_path``, decoded as UTF-8."""
    # Text read mode is open()'s default, so only the encoding is spelled out.
    with open(file_path, encoding='utf-8') as task_file:
        contents = task_file.read()
    return contents


def read_task_from_stdin() -> str:
    """Consume standard input until EOF and return everything that was read."""
    piped_text = sys.stdin.read()
    return piped_text


def read_input(config: AppConfig) -> str:
    """Read a message from the user according to ``config.cli_multiline_input``.

    In multiline mode the user is prompted repeatedly; input ends at a line
    that equals ``/exit`` once trailing whitespace is stripped, and the
    preceding lines are returned joined with newlines. In single-line mode
    one stripped line is returned.
    """

    def prompt_line() -> str:
        # One prompt/response round trip with trailing whitespace removed.
        return input('>> ').rstrip()

    if config.cli_multiline_input:
        print('Enter your message (enter "/exit" on a new line to finish):')
        # Two-argument iter() calls prompt_line until it returns the sentinel.
        return '\n'.join(iter(prompt_line, '/exit'))
    return prompt_line()


async def run_controller(
config: AppConfig,
initial_user_action: Action,
Expand Down Expand Up @@ -146,7 +120,7 @@ def on_event(event: Event):
if exit_on_message:
message = '/exit'
elif fake_user_response_fn is None:
message = read_input(config)
message = read_input(config.cli_multiline_input)
else:
message = fake_user_response_fn(controller.get_state())
action = MessageAction(content=message)
Expand Down Expand Up @@ -243,28 +217,23 @@ def load_replay_log(trajectory_path: str) -> tuple[list[Event] | None, Action]:
if __name__ == '__main__':
args = parse_arguments()

config = setup_config_from_args(args)
config: AppConfig = setup_config_from_args(args)

# Determine the task
task_str = ''
if args.file:
task_str = read_task_from_file(args.file)
elif args.task:
task_str = args.task
elif not sys.stdin.isatty():
task_str = read_task_from_stdin()
# Read task from file, CLI args, or stdin
task_str = read_task(args, config.cli_multiline_input)

initial_user_action: Action = NullAction()
if config.replay_trajectory_path:
if task_str:
raise ValueError(
'User-specified task is not supported under trajectory replay mode'
)
elif task_str:
initial_user_action = MessageAction(content=task_str)
else:

if not task_str:
raise ValueError('No task provided. Please specify a task through -t, -f.')

# Create initial user action
initial_user_action: MessageAction = MessageAction(content=task_str)

# Set session name
session_name = args.name
sid = generate_sid(config, session_name)
Expand Down
40 changes: 40 additions & 0 deletions openhands/core/utils/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import argparse
import sys


def read_input(cli_multiline_input: bool = False) -> str:
    """Read one user message from standard input.

    Args:
        cli_multiline_input: When true, keep prompting for lines until a line
            equal to ``/exit`` (after stripping trailing whitespace) is
            entered, and return the collected lines joined with newlines.
            When false, read and return a single line.

    Returns:
        The message, with trailing whitespace stripped from every line.

    Raises:
        EOFError: If input ends before any line has been read.
    """
    if not cli_multiline_input:
        return input('>> ').rstrip()

    print('Enter your message (enter "/exit" on a new line to finish):')
    lines: list[str] = []
    while True:
        try:
            line = input('>> ').rstrip()
        except EOFError:
            # Input ended without the '/exit' terminator (e.g. piped stdin).
            # Keep what was already entered instead of discarding it; with
            # nothing collected, propagate so callers still see end-of-input.
            if not lines:
                raise
            break
        if line == '/exit':  # finish input
            break
        lines.append(line)
    return '\n'.join(lines)


def read_task_from_file(file_path: str) -> str:
    """Load the task description stored at ``file_path`` as UTF-8 text."""
    with open(file_path, mode='rt', encoding='utf-8') as source:
        task_text = source.read()
    return task_text


def read_task(args: argparse.Namespace, cli_multiline_input: bool) -> str:
    """Resolve the initial task text from the parsed command-line arguments.

    Sources are tried in this order and the first match wins:

    1. ``args.file`` - the task is read from that file.
    2. ``args.task`` - the task string given on the command line.
    3. Standard input, only when it is not a TTY - read via ``read_input``
       using ``cli_multiline_input``.

    Returns an empty string when none of the sources applies.
    """
    if args.file:
        return read_task_from_file(args.file)
    if args.task:
        return args.task
    # Only consume stdin when it is not an interactive terminal.
    if not sys.stdin.isatty():
        return read_input(cli_multiline_input)
    return ''
Loading