Skip to content

Commit

Permalink
Fix Docker runtimes not stopping (All-Hands-AI#6470)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
2 people authored and Kevin Chen committed Feb 4, 2025
1 parent f81e01a commit 6778ec5
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 16 deletions.
20 changes: 8 additions & 12 deletions openhands/runtime/impl/docker/docker_runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import atexit
from functools import lru_cache
from typing import Callable
from uuid import UUID

import docker
import requests
Expand All @@ -26,6 +26,7 @@
from openhands.runtime.utils.log_streamer import LogStreamer
from openhands.runtime.utils.runtime_build import build_runtime_image
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.shutdown_listener import add_shutdown_listener
from openhands.utils.tenacity_stop import stop_if_should_exit

CONTAINER_NAME_PREFIX = 'openhands-runtime-'
Expand All @@ -36,13 +37,6 @@
APP_PORT_RANGE_2 = (55000, 59999)


def stop_all_runtime_containers():
stop_all_containers(CONTAINER_NAME_PREFIX)


_atexit_registered = False


class DockerRuntime(ActionExecutionClient):
"""This runtime will subscribe the event stream.
When receive an event, it will send the event to runtime-client which run inside the docker environment.
Expand All @@ -55,6 +49,8 @@ class DockerRuntime(ActionExecutionClient):
env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None.
"""

_shutdown_listener_id: UUID | None = None

def __init__(
self,
config: AppConfig,
Expand All @@ -66,10 +62,10 @@ def __init__(
attach_to_existing: bool = False,
headless_mode: bool = True,
):
global _atexit_registered
if not _atexit_registered:
_atexit_registered = True
atexit.register(stop_all_runtime_containers)
if not DockerRuntime._shutdown_listener_id:
DockerRuntime._shutdown_listener_id = add_shutdown_listener(
lambda: stop_all_containers(CONTAINER_NAME_PREFIX)
)

self.config = config
self._runtime_initialized: bool = False
Expand Down
29 changes: 25 additions & 4 deletions openhands/utils/shutdown_listener.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
"""
This module monitors the app for shutdown signals
This module monitors the app for shutdown signals. This exists because the atexit module
does not play nocely with stareltte / uvicorn shutdown signals.
"""

import asyncio
import signal
import threading
import time
from types import FrameType
from typing import Callable
from uuid import UUID, uuid4

from uvicorn.server import HANDLED_SIGNALS

from openhands.core.logger import openhands_logger as logger

_should_exit = None
_shutdown_listeners: dict[UUID, Callable] = {}


def _register_signal_handler(sig: signal.Signals):
Expand All @@ -21,9 +25,16 @@ def _register_signal_handler(sig: signal.Signals):
def handler(sig_: int, frame: FrameType | None):
logger.debug(f'shutdown_signal:{sig_}')
global _should_exit
_should_exit = True
if original_handler:
original_handler(sig_, frame) # type: ignore[unreachable]
if not _should_exit:
_should_exit = True
listeners = list(_shutdown_listeners.values())
for callable in listeners:
try:
callable()
except Exception:
logger.exception('Error calling shutdown listener')
if original_handler:
original_handler(sig_, frame) # type: ignore[unreachable]

original_handler = signal.signal(sig, handler)

Expand Down Expand Up @@ -71,3 +82,13 @@ async def async_sleep_if_should_continue(timeout: float):
start_time = time.time()
while time.time() - start_time < timeout and should_continue():
await asyncio.sleep(1)


def add_shutdown_listener(callable: Callable) -> UUID:
id_ = uuid4()
_shutdown_listeners[id_] = callable
return id_


def remove_shutdown_listener(id_: UUID) -> bool:
return _shutdown_listeners.pop(id_, None) is not None
116 changes: 116 additions & 0 deletions tests/unit/test_shutdown_listener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import signal
from dataclasses import dataclass, field
from signal import Signals
from typing import Callable
from unittest.mock import MagicMock, patch
from uuid import UUID

import pytest

from openhands.utils import shutdown_listener
from openhands.utils.shutdown_listener import (
add_shutdown_listener,
remove_shutdown_listener,
should_continue,
)


@pytest.fixture(autouse=True)
def cleanup_listeners():
shutdown_listener._shutdown_listeners.clear()
shutdown_listener._should_exit = False


@dataclass
class MockSignal:
handlers: dict[Signals, Callable] = field(default_factory=dict)

def signal(self, signalnum: Signals, handler: Callable):
result = self.handlers.get(signalnum)
self.handlers[signalnum] = handler
return result

def trigger(self, signalnum: Signals):
handler = self.handlers.get(signalnum)
if handler:
handler(signalnum.value, None)


def test_add_shutdown_listener():
mock_callable = MagicMock()
listener_id = add_shutdown_listener(mock_callable)

assert isinstance(listener_id, UUID)
assert listener_id in shutdown_listener._shutdown_listeners
assert shutdown_listener._shutdown_listeners[listener_id] == mock_callable


def test_remove_shutdown_listener():
mock_callable = MagicMock()
listener_id = add_shutdown_listener(mock_callable)

# Test successful removal
assert remove_shutdown_listener(listener_id) is True
assert listener_id not in shutdown_listener._shutdown_listeners

# Test removing non-existent listener
assert remove_shutdown_listener(listener_id) is False


def test_signal_handler_calls_listeners():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable1 = MagicMock()
mock_callable2 = MagicMock()
add_shutdown_listener(mock_callable1)
add_shutdown_listener(mock_callable2)

# Register and trigger signal handler
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)

# Verify both listeners were called
mock_callable1.assert_called_once()
mock_callable2.assert_called_once()

# Verify should_continue returns False after shutdown
assert should_continue() is False


def test_listeners_called_only_once():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable = MagicMock()
add_shutdown_listener(mock_callable)

# Register and trigger signal handler multiple times
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)

# Verify listener was called only once
assert mock_callable.call_count == 1


def test_remove_listener_during_shutdown():
mock_signal = MockSignal()
with patch('openhands.utils.shutdown_listener.signal', mock_signal):
mock_callable1 = MagicMock()
mock_callable2 = MagicMock()

# Second listener removes the first listener when called
listener1_id = add_shutdown_listener(mock_callable1)

def remove_other_listener():
remove_shutdown_listener(listener1_id)
mock_callable2()

add_shutdown_listener(remove_other_listener)

# Register and trigger signal handler
shutdown_listener._register_signal_handler(signal.SIGTERM)
mock_signal.trigger(signal.SIGTERM)

# Both listeners should still be called
assert mock_callable1.call_count == 1
assert mock_callable2.call_count == 1

0 comments on commit 6778ec5

Please sign in to comment.