From 612c6eee82148ae966f757fc0a34ef3d31dced7d Mon Sep 17 00:00:00 2001 From: openhands Date: Sun, 23 Feb 2025 21:58:04 +0000 Subject: [PATCH] fix: Fix race conditions in EventStream subscriber management This commit fixes several issues with subscriber management in EventStream: 1. Adds proper thread synchronization to prevent race conditions 2. Changes cleanup order to prevent events after unsubscribe 3. Improves thread pool shutdown behavior 4. Adds snapshot mechanism for safe iteration 5. Adds better error handling and logging The changes ensure that: - Subscribers are properly unsubscribed and don't receive events after - Thread pools are shut down gracefully - Race conditions are handled correctly - Edge cases are handled properly Added tests: - test_subscriber_behavior: Basic subscriber functionality - test_subscriber_stress: Stress testing with rapid subscribe/unsubscribe - test_subscriber_unsubscribe_bug: Specific bug reproduction - test_subscriber_unsubscribe_concurrent: Concurrent operations - test_subscriber_race_condition: Race condition testing --- openhands/events/stream.py | 126 +++++++---- tests/unit/test_agent_controller.py | 333 ++++++++++++++++++++++++++++ 2 files changed, 414 insertions(+), 45 deletions(-) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 938269822a7a..a6ac46600aa7 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -125,36 +125,41 @@ def close(self): self._queue.get() def _clean_up_subscriber(self, subscriber_id: str, callback_id: str): - if subscriber_id not in self._subscribers: + """Clean up a subscriber's callback, ensuring proper shutdown of resources. + + This method must be called with self._lock held to ensure thread safety. + """ + # Remove from subscribers first to prevent new events from being processed + if subscriber_id in self._subscribers: + if callback_id in self._subscribers[subscriber_id]: + del self._subscribers[subscriber_id][callback_id] + else: + logger.warning(f'Callback not found during cleanup: {callback_id}') + return + else: logger.warning(f'Subscriber not found during cleanup: {subscriber_id}') return - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during cleanup: {callback_id}') - return - if ( - subscriber_id in self._thread_loops - and callback_id in self._thread_loops[subscriber_id] - ): + + # Clean up the event loop + if subscriber_id in self._thread_loops and callback_id in self._thread_loops[subscriber_id]: loop = self._thread_loops[subscriber_id][callback_id] try: loop.stop() loop.close() except Exception as e: - logger.warning( - f'Error closing loop for {subscriber_id}/{callback_id}: {e}' - ) + logger.warning(f'Error closing loop for {subscriber_id}/{callback_id}: {e}') del self._thread_loops[subscriber_id][callback_id] - if ( - subscriber_id in self._thread_pools - and callback_id in self._thread_pools[subscriber_id] - ): + # Clean up the thread pool + if subscriber_id in self._thread_pools and callback_id in self._thread_pools[subscriber_id]: pool = self._thread_pools[subscriber_id][callback_id] - pool.shutdown() + try: + # Use shutdown(wait=False) to allow existing tasks to complete + pool.shutdown(wait=False) + except Exception as e: + logger.warning(f'Error shutting down pool for {subscriber_id}/{callback_id}: {e}') del self._thread_pools[subscriber_id][callback_id] - del self._subscribers[subscriber_id][callback_id] - def _get_filename_for_id(self, id: int) -> str: return get_conversation_event_filename(self.sid, id) @@ -236,30 +241,46 @@ def get_latest_event_id(self) -> int: def subscribe( self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str ): - initializer = partial(self._init_thread_loop, subscriber_id, callback_id) - pool = ThreadPoolExecutor(max_workers=1, initializer=initializer) - if subscriber_id not in self._subscribers: - self._subscribers[subscriber_id] = {} - self._thread_pools[subscriber_id] = {} - - if callback_id in self._subscribers[subscriber_id]: - raise ValueError( - f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}' - ) + """Subscribe a callback to receive events. + + Args: + subscriber_id: The type of subscriber (e.g., RUNTIME, AGENT_CONTROLLER) + callback: The callback function to be invoked for each event + callback_id: A unique ID for this callback instance + + Raises: + ValueError: If the callback_id is already in use for this subscriber + """ + with self._lock: # Use the same lock as unsubscribe to ensure thread safety + initializer = partial(self._init_thread_loop, subscriber_id, callback_id) + pool = ThreadPoolExecutor(max_workers=1, initializer=initializer) + + if subscriber_id not in self._subscribers: + self._subscribers[subscriber_id] = {} + self._thread_pools[subscriber_id] = {} + self._thread_loops[subscriber_id] = {} + + if callback_id in self._subscribers[subscriber_id]: + pool.shutdown(wait=False) # Clean up the pool we just created + raise ValueError( + f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}' + ) - self._subscribers[subscriber_id][callback_id] = callback - self._thread_pools[subscriber_id][callback_id] = pool + # Add the callback and its resources atomically + self._subscribers[subscriber_id][callback_id] = callback + self._thread_pools[subscriber_id][callback_id] = pool def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str): - if subscriber_id not in self._subscribers: - logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}') - return + with self._lock: # Use the same lock as add_event to ensure thread safety + if subscriber_id not in self._subscribers: + logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}') + return - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during unsubscribe: {callback_id}') - return + if callback_id not in self._subscribers[subscriber_id]: + logger.warning(f'Callback not found during unsubscribe: {callback_id}') + return - self._clean_up_subscriber(subscriber_id, callback_id) + self._clean_up_subscriber(subscriber_id, callback_id) def add_event(self, event: Event, source: EventSource): if hasattr(event, '_id') and event.id is not None: @@ -310,14 +331,29 @@ async def _process_queue(self): except queue.Empty: continue - # pass each event to each callback in order - for key in sorted(self._subscribers.keys()): - callbacks = self._subscribers[key] - for callback_id in callbacks: - callback = callbacks[callback_id] - pool = self._thread_pools[key][callback_id] - future = pool.submit(callback, event) - future.add_done_callback(self._make_error_handler(callback_id, key)) + # Take a snapshot of subscribers to prevent modification during iteration + with self._lock: + subscribers = { + key: { + cid: (cb, self._thread_pools[key][cid]) + for cid, cb in callbacks.items() + } + for key, callbacks in self._subscribers.items() + } + + # Process the event with the snapshot + for key in sorted(subscribers.keys()): + callbacks = subscribers[key] + for callback_id, (callback, pool) in callbacks.items(): + try: + future = pool.submit(callback, event) + future.add_done_callback(self._make_error_handler(callback_id, key)) + except RuntimeError as e: + if "cannot schedule new futures after shutdown" in str(e): + # Pool was shut down, skip this callback + logger.debug(f"Skipping callback {callback_id} for {key} - pool is shut down") + continue + raise def _make_error_handler(self, callback_id: str, subscriber_id: str): def _handle_callback_error(fut): diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index a0a350d4d887..2715205628a5 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -1,4 +1,5 @@ import asyncio +import logging from unittest.mock import ANY, AsyncMock, MagicMock from uuid import uuid4 @@ -360,6 +361,338 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream): await controller.close() +@pytest.mark.asyncio +async def test_subscriber_behavior(temp_dir): + """Test the behavior of subscribers, especially around unsubscribe and resubscribe scenarios.""" + # Create a real event stream with a file store + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + # Create a list to track callback invocations + callback_invocations = [] + + # Create multiple callbacks + def callback1(event): + callback_invocations.append(('callback1', event)) + + def callback2(event): + callback_invocations.append(('callback2', event)) + + # Add multiple subscribers + callback1_id = str(uuid4()) + callback2_id = str(uuid4()) + + # Subscribe both callbacks + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback1, callback1_id) + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback2, callback2_id) + + # Add a test event + test_event = MessageAction(content='Test message') + event_stream.add_event(test_event, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Both callbacks should have received the event + assert len(callback_invocations) == 2 + assert ('callback1', test_event) in callback_invocations + assert ('callback2', test_event) in callback_invocations + + # Clear invocations + callback_invocations.clear() + + # Unsubscribe one callback + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + + # Add another test event + test_event2 = MessageAction(content='Test message 2') + event_stream.add_event(test_event2, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Only callback2 should have received the event + assert len(callback_invocations) == 1 + assert ('callback2', test_event2) in callback_invocations + + # Try to unsubscribe again - should log warning but not error + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + + # Resubscribe callback1 + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback1, callback1_id) + + # Clear invocations + callback_invocations.clear() + + # Add another test event + test_event3 = MessageAction(content='Test message 3') + event_stream.add_event(test_event3, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Both callbacks should receive the event again + assert len(callback_invocations) == 2 + assert ('callback1', test_event3) in callback_invocations + assert ('callback2', test_event3) in callback_invocations + + # Unsubscribe a non-existent callback - should log warning + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, str(uuid4())) + + # Unsubscribe from a non-existent subscriber - should log warning + event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, callback1_id) + + # Clean up + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback2_id) + event_stream.close() + + +@pytest.mark.asyncio +async def test_subscriber_stress(temp_dir): + """Stress test subscriber behavior to try to replicate the bug where subscriber warning appears + even though the callback still exists and is being invoked.""" + # Create a real event stream with a file store + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + # Create a list to track callback invocations + callback_invocations = [] + + # Create a callback that simulates some processing time + async def async_callback(event): + await asyncio.sleep(0.01) # Simulate some async work + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + # Create multiple subscriber IDs + subscriber_ids = [str(uuid4()) for _ in range(5)] + active_callbacks = {sid: [] for sid in subscriber_ids} # Track active callback IDs for each subscriber + + try: + # Subscribe and unsubscribe in rapid succession + for i in range(10): # Do 10 rounds of subscribe/unsubscribe + for sid in subscriber_ids: + # Generate a new callback ID for each subscription + callback_id = str(uuid4()) + active_callbacks[sid].append(callback_id) + + # Subscribe with new callback ID + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, callback_id) + + # Add a test event + test_event = MessageAction(content=f'Test message {i}-{sid}') + event_stream.add_event(test_event, EventSource.USER) + + # Give minimal time for event processing + await asyncio.sleep(0.001) + + # Unsubscribe the previous callback if it exists + if len(active_callbacks[sid]) > 1: + old_callback_id = active_callbacks[sid].pop(0) + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, old_callback_id) + + # Add another test event + test_event2 = MessageAction(content=f'Test message {i}-{sid}-2') + event_stream.add_event(test_event2, EventSource.USER) + + # Occasionally try to unsubscribe a non-existent callback + if i % 3 == 0: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, str(uuid4())) + + # Give time for all events to be processed + await asyncio.sleep(1) + + # Verify that we received events + # We should have 2 events per iteration per subscriber + expected_events = 10 * len(subscriber_ids) * 2 + assert len(callback_invocations) >= expected_events * 0.9, f"Expected at least 90% of {expected_events} events, got {len(callback_invocations)}" + + finally: + # Clean up all callbacks + for sid in subscriber_ids: + for callback_id in active_callbacks[sid]: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback_id) + event_stream.close() + + +@pytest.mark.asyncio +async def test_subscriber_unsubscribe_bug(temp_dir): + """Test the specific case where a callback continues to be invoked even after + getting 'Subscriber not found during unsubscribe' warning.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + + # Create a callback that takes some time to complete + async def async_callback(event): + await asyncio.sleep(0.05) # Long enough to ensure we can test during execution + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + callback_id = str(uuid4()) + + # Subscribe the callback + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, callback_id) + + # Add an event to start the callback processing + test_event1 = MessageAction(content='Test message 1') + event_stream.add_event(test_event1, EventSource.USER) + + # Give a tiny bit of time for the event to be queued but not processed + await asyncio.sleep(0.01) + + # Try to unsubscribe while the callback is still processing + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback_id) + + # Add another event - this should not trigger the callback if unsubscribe worked + test_event2 = MessageAction(content='Test message 2') + event_stream.add_event(test_event2, EventSource.USER) + + # Wait for all processing to complete + await asyncio.sleep(0.2) + + # Clean up + event_stream.close() + + # Check the invocations + print(f"Callback invocations: {callback_invocations}") + assert len(callback_invocations) == 1, ( + f"Expected only 1 invocation (from first event), but got {len(callback_invocations)}. " + f"This indicates the callback was still being called after unsubscribe." + ) + assert callback_invocations[0][1] == test_event1, ( + "Expected only the first event to be processed" + ) + + +@pytest.mark.asyncio +async def test_subscriber_unsubscribe_concurrent(temp_dir): + """Test concurrent subscribe/unsubscribe operations to try to trigger the bug + where callbacks remain active after unsubscribe.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + unsubscribe_warnings = [] + + # Create a callback that takes some time to complete + async def async_callback(event): + await asyncio.sleep(0.05) # Long enough to create overlap + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + # Create a custom logger to capture warnings + class WarningCaptureHandler(logging.Handler): + def emit(self, record): + if "Subscriber not found during unsubscribe" in record.getMessage(): + unsubscribe_warnings.append(record.getMessage()) + + logger = logging.getLogger('openhands.events.stream') + handler = WarningCaptureHandler() + logger.addHandler(handler) + + try: + # Create multiple subscribers and callbacks + subscriber_ids = [EventStreamSubscriber.RUNTIME, EventStreamSubscriber.AGENT_CONTROLLER] + callback_ids = {sid: [] for sid in subscriber_ids} + + for i in range(5): # 5 rounds of subscribe/unsubscribe + for subscriber_id in subscriber_ids: + # Subscribe with a new callback ID + callback_id = str(uuid4()) + callback_ids[subscriber_id].append(callback_id) + event_stream.subscribe(subscriber_id, callback, callback_id) + + # Add an event + test_event = MessageAction(content=f'Test message {i}-{subscriber_id}') + event_stream.add_event(test_event, EventSource.USER) + + # Small delay to allow some overlap + await asyncio.sleep(0.01) + + # Try to unsubscribe the previous callback if it exists + if len(callback_ids[subscriber_id]) > 1: + old_id = callback_ids[subscriber_id][-2] + event_stream.unsubscribe(subscriber_id, old_id) + + # Add another event + test_event2 = MessageAction(content=f'Test message {i}-{subscriber_id}-2') + event_stream.add_event(test_event2, EventSource.USER) + + # Wait for all processing to complete + await asyncio.sleep(0.5) + + # Check for any cases where we got a warning but the callback was still invoked + print(f"Unsubscribe warnings: {unsubscribe_warnings}") + print(f"Callback invocations: {len(callback_invocations)}") + + # If we got any "Subscriber not found" warnings, we should check that those + # callbacks were not invoked after the warning + if unsubscribe_warnings: + # The number of invocations should be less than the total number of events + total_events = len(subscriber_ids) * 5 * 2 # subscribers * rounds * events per round + assert len(callback_invocations) < total_events, ( + f"Got {len(callback_invocations)} invocations for {total_events} events " + f"even though some callbacks were unsubscribed" + ) + + finally: + # Clean up + for subscriber_id in subscriber_ids: + for callback_id in callback_ids[subscriber_id]: + event_stream.unsubscribe(subscriber_id, callback_id) + event_stream.close() + logger.removeHandler(handler) + + +@pytest.mark.asyncio +async def test_subscriber_race_condition(temp_dir): + """Test specifically for race conditions in subscriber management.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + + # Create a callback that takes some time to complete + async def slow_callback(event): + await asyncio.sleep(0.05) # Long enough to create potential race conditions + callback_invocations.append(('slow_callback', event)) + + def callback(event): + asyncio.run(slow_callback(event)) + + # Create tasks to subscribe/unsubscribe concurrently + async def subscribe_unsubscribe(sid: str): + for _ in range(5): + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, sid) + test_event = MessageAction(content=f'Test message for {sid}') + event_stream.add_event(test_event, EventSource.USER) + await asyncio.sleep(0.01) # Small delay to increase chance of race condition + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, sid) + await asyncio.sleep(0.01) + + # Run multiple subscribe/unsubscribe operations concurrently + subscriber_ids = [str(uuid4()) for _ in range(3)] + tasks = [subscribe_unsubscribe(sid) for sid in subscriber_ids] + await asyncio.gather(*tasks) + + # Give time for all events to be processed + await asyncio.sleep(1) + + # Clean up + for sid in subscriber_ids: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, sid) + event_stream.close() + + @pytest.mark.asyncio async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream): """Test reset() when there's a pending action with tool call metadata but no observation."""