Skip to content

Commit

Permalink
Fix server lock up on session init (All-Hands-AI#4007)
Browse files (browse the repository at this point in the history)
  • Loading branch information
tofarr authored Sep 24, 2024
1 parent 1b1d8f0 commit ee284ba
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
7 changes: 2 additions & 5 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class AgentController:
confirmation_mode: bool
agent_to_llm_config: dict[str, LLMConfig]
agent_configs: dict[str, AgentConfig]
agent_task: asyncio.Task | None = None
agent_task: asyncio.Future | None = None
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
Expand Down Expand Up @@ -115,9 +115,6 @@ def __init__(
# stuck helper
self._stuck_detector = StuckDetector(self.state)

if not is_delegate:
self.agent_task = asyncio.create_task(self._start_step_loop())

async def close(self):
"""Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream."""
if self.agent_task is not None:
Expand Down Expand Up @@ -149,7 +146,7 @@ async def report_error(self, message: str, exception: Exception | None = None):
self.state.last_error += f': {exception}'
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)

async def _start_step_loop(self):
async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""

logger.info(f'[Agent Controller {self.id}] Starting step loop...')
Expand Down
3 changes: 3 additions & 0 deletions openhands/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ async def main():
event_stream=event_stream,
)

if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())

async def prompt_for_next_task():
next_message = input('How can I help? >> ')
if next_message == 'exit':
Expand Down
3 changes: 3 additions & 0 deletions openhands/core/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ async def run_controller(
headless_mode=headless_mode,
)

if controller is not None:
controller.agent_task = asyncio.create_task(controller.start_step_loop())

assert isinstance(task_str, str), f'task_str must be a string, got {type(task_str)}'
# Logging
logger.info(
Expand Down
33 changes: 25 additions & 8 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio

from threading import Thread
from typing import Callable, Optional

from openhands.controller import AgentController
Expand Down Expand Up @@ -65,16 +67,28 @@ async def start(
raise RuntimeError(
'Session already started. You need to close this session and start a new one.'
)
await self._create_security_analyzer(config.security.security_analyzer)
await self._create_runtime(runtime_name, config, agent, status_message_callback)
await self._create_controller(

self.loop = asyncio.new_event_loop()
self.thread = Thread(target=self._run, daemon=True)
self.thread.start()

self._create_security_analyzer(config.security.security_analyzer)
self._create_runtime(runtime_name, config, agent, status_message_callback)
self._create_controller(
agent,
config.security.confirmation_mode,
max_iterations,
max_budget_per_task=max_budget_per_task,
agent_to_llm_config=agent_to_llm_config,
agent_configs=agent_configs,
)

if self.controller is not None:
self.controller.agent_task = asyncio.run_coroutine_threadsafe(self.controller.start_step_loop(), self.loop) # type: ignore

def _run(self):
    """Run this session's private event loop until it is stopped.

    Used as the ``target`` of the daemon thread created in ``start()``.
    The loop is first installed as the current event loop for the thread
    executing this method, then run until ``loop.stop()`` is requested
    (``close()`` schedules that via ``call_soon_threadsafe``).
    """
    # Bind the loop to this thread before running it.
    asyncio.set_event_loop(self.loop)
    # Blocks the thread until the loop is stopped.
    self.loop.run_forever()

async def close(self):
"""Closes the Agent session"""
Expand All @@ -89,9 +103,13 @@ async def close(self):
self.runtime.close()
if self.security_analyzer is not None:
await self.security_analyzer.close()

self.loop.call_soon_threadsafe(self.loop.stop)
self.thread.join()

self._closed = True

async def _create_security_analyzer(self, security_analyzer: str | None):
def _create_security_analyzer(self, security_analyzer: str | None):
"""Creates a SecurityAnalyzer instance that will be used to analyze the agent actions
Parameters:
Expand All @@ -104,7 +122,7 @@ async def _create_security_analyzer(self, security_analyzer: str | None):
security_analyzer, SecurityAnalyzer
)(self.event_stream)

async def _create_runtime(
def _create_runtime(
self,
runtime_name: str,
config: AppConfig,
Expand All @@ -125,8 +143,7 @@ async def _create_runtime(
logger.info(f'Initializing runtime `{runtime_name}` now...')
runtime_cls = get_runtime_cls(runtime_name)

self.runtime = await asyncio.to_thread(
runtime_cls,
self.runtime = runtime_cls(
config=config,
event_stream=self.event_stream,
sid=self.sid,
Expand All @@ -141,7 +158,7 @@ async def _create_runtime(
else:
logger.warning('Runtime initialization failed')

async def _create_controller(
def _create_controller(
self,
agent: Agent,
confirmation_mode: bool,
Expand Down

0 comments on commit ee284ba

Please sign in to comment.