diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 938269822a7a..a6ac46600aa7 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -125,36 +125,41 @@ def close(self): self._queue.get() def _clean_up_subscriber(self, subscriber_id: str, callback_id: str): - if subscriber_id not in self._subscribers: + """Clean up a subscriber's callback, ensuring proper shutdown of resources. + + This method must be called with self._lock held to ensure thread safety. + """ + # Remove from subscribers first to prevent new events from being processed + if subscriber_id in self._subscribers: + if callback_id in self._subscribers[subscriber_id]: + del self._subscribers[subscriber_id][callback_id] + else: + logger.warning(f'Callback not found during cleanup: {callback_id}') + return + else: logger.warning(f'Subscriber not found during cleanup: {subscriber_id}') return - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during cleanup: {callback_id}') - return - if ( - subscriber_id in self._thread_loops - and callback_id in self._thread_loops[subscriber_id] - ): + + # Clean up the event loop + if subscriber_id in self._thread_loops and callback_id in self._thread_loops[subscriber_id]: loop = self._thread_loops[subscriber_id][callback_id] try: loop.stop() loop.close() except Exception as e: - logger.warning( - f'Error closing loop for {subscriber_id}/{callback_id}: {e}' - ) + logger.warning(f'Error closing loop for {subscriber_id}/{callback_id}: {e}') del self._thread_loops[subscriber_id][callback_id] - if ( - subscriber_id in self._thread_pools - and callback_id in self._thread_pools[subscriber_id] - ): + # Clean up the thread pool + if subscriber_id in self._thread_pools and callback_id in self._thread_pools[subscriber_id]: pool = self._thread_pools[subscriber_id][callback_id] - pool.shutdown() + try: + # Use shutdown(wait=False) to allow existing tasks to complete + pool.shutdown(wait=False) + except Exception as e: + logger.warning(f'Error shutting down pool for {subscriber_id}/{callback_id}: {e}') del self._thread_pools[subscriber_id][callback_id] - del self._subscribers[subscriber_id][callback_id] - def _get_filename_for_id(self, id: int) -> str: return get_conversation_event_filename(self.sid, id) @@ -236,30 +241,46 @@ def get_latest_event_id(self) -> int: def subscribe( self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str ): - initializer = partial(self._init_thread_loop, subscriber_id, callback_id) - pool = ThreadPoolExecutor(max_workers=1, initializer=initializer) - if subscriber_id not in self._subscribers: - self._subscribers[subscriber_id] = {} - self._thread_pools[subscriber_id] = {} - - if callback_id in self._subscribers[subscriber_id]: - raise ValueError( - f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}' - ) + """Subscribe a callback to receive events. + + Args: + subscriber_id: The type of subscriber (e.g., RUNTIME, AGENT_CONTROLLER) + callback: The callback function to be invoked for each event + callback_id: A unique ID for this callback instance + + Raises: + ValueError: If the callback_id is already in use for this subscriber + """ + with self._lock: # Use the same lock as unsubscribe to ensure thread safety + initializer = partial(self._init_thread_loop, subscriber_id, callback_id) + pool = ThreadPoolExecutor(max_workers=1, initializer=initializer) + + if subscriber_id not in self._subscribers: + self._subscribers[subscriber_id] = {} + self._thread_pools[subscriber_id] = {} + self._thread_loops[subscriber_id] = {} + + if callback_id in self._subscribers[subscriber_id]: + pool.shutdown(wait=False) # Clean up the pool we just created + raise ValueError( + f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}' + ) - self._subscribers[subscriber_id][callback_id] = callback - self._thread_pools[subscriber_id][callback_id] = pool + # Add the callback and its resources atomically + self._subscribers[subscriber_id][callback_id] = callback + self._thread_pools[subscriber_id][callback_id] = pool def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str): - if subscriber_id not in self._subscribers: - logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}') - return + with self._lock: # Use the same lock as add_event to ensure thread safety + if subscriber_id not in self._subscribers: + logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}') + return - if callback_id not in self._subscribers[subscriber_id]: - logger.warning(f'Callback not found during unsubscribe: {callback_id}') - return + if callback_id not in self._subscribers[subscriber_id]: + logger.warning(f'Callback not found during unsubscribe: {callback_id}') + return - self._clean_up_subscriber(subscriber_id, callback_id) + self._clean_up_subscriber(subscriber_id, callback_id) def add_event(self, event: Event, source: EventSource): if hasattr(event, '_id') and event.id is not None: @@ -310,14 +331,29 @@ async def _process_queue(self): except queue.Empty: continue - # pass each event to each callback in order - for key in sorted(self._subscribers.keys()): - callbacks = self._subscribers[key] - for callback_id in callbacks: - callback = callbacks[callback_id] - pool = self._thread_pools[key][callback_id] - future = pool.submit(callback, event) - future.add_done_callback(self._make_error_handler(callback_id, key)) + # Take a snapshot of subscribers to prevent modification during iteration + with self._lock: + subscribers = { + key: { + cid: (cb, self._thread_pools[key][cid]) + for cid, cb in callbacks.items() + } + for key, callbacks in self._subscribers.items() + } + + # Process the event with the snapshot + for key in sorted(subscribers.keys()): + callbacks = subscribers[key] + for callback_id, (callback, pool) in callbacks.items(): + try: + future = pool.submit(callback, event) + future.add_done_callback(self._make_error_handler(callback_id, key)) + except RuntimeError as e: + if "cannot schedule new futures after shutdown" in str(e): + # Pool was shut down, skip this callback + logger.debug(f"Skipping callback {callback_id} for {key} - pool is shut down") + continue + raise def _make_error_handler(self, callback_id: str, subscriber_id: str): def _handle_callback_error(fut): diff --git a/tests/unit/test_agent_controller.py b/tests/unit/test_agent_controller.py index a0a350d4d887..2715205628a5 100644 --- a/tests/unit/test_agent_controller.py +++ b/tests/unit/test_agent_controller.py @@ -1,4 +1,5 @@ import asyncio +import logging from unittest.mock import ANY, AsyncMock, MagicMock from uuid import uuid4 @@ -360,6 +361,338 @@ async def test_step_max_budget_headless(mock_agent, mock_event_stream): await controller.close() +@pytest.mark.asyncio +async def test_subscriber_behavior(temp_dir): + """Test the behavior of subscribers, especially around unsubscribe and resubscribe scenarios.""" + # Create a real event stream with a file store + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + # Create a list to track callback invocations + callback_invocations = [] + + # Create multiple callbacks + def callback1(event): + callback_invocations.append(('callback1', event)) + + def callback2(event): + callback_invocations.append(('callback2', event)) + + # Add multiple subscribers + callback1_id = str(uuid4()) + callback2_id = str(uuid4()) + + # Subscribe both callbacks + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback1, callback1_id) + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback2, callback2_id) + + # Add a test event + test_event = MessageAction(content='Test message') + event_stream.add_event(test_event, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Both callbacks should have received the event + assert len(callback_invocations) == 2 + assert ('callback1', test_event) in callback_invocations + assert ('callback2', test_event) in callback_invocations + + # Clear invocations + callback_invocations.clear() + + # Unsubscribe one callback + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + + # Add another test event + test_event2 = MessageAction(content='Test message 2') + event_stream.add_event(test_event2, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Only callback2 should have received the event + assert len(callback_invocations) == 1 + assert ('callback2', test_event2) in callback_invocations + + # Try to unsubscribe again - should log warning but not error + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + + # Resubscribe callback1 + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback1, callback1_id) + + # Clear invocations + callback_invocations.clear() + + # Add another test event + test_event3 = MessageAction(content='Test message 3') + event_stream.add_event(test_event3, EventSource.USER) + + # Give time for event processing + await asyncio.sleep(0.1) + + # Both callbacks should receive the event again + assert len(callback_invocations) == 2 + assert ('callback1', test_event3) in callback_invocations + assert ('callback2', test_event3) in callback_invocations + + # Unsubscribe a non-existent callback - should log warning + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, str(uuid4())) + + # Unsubscribe from a non-existent subscriber - should log warning + event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER, callback1_id) + + # Clean up + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback1_id) + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback2_id) + event_stream.close() + + +@pytest.mark.asyncio +async def test_subscriber_stress(temp_dir): + """Stress test subscriber behavior to try to replicate the bug where subscriber warning appears + even though the callback still exists and is being invoked.""" + # Create a real event stream with a file store + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + # Create a list to track callback invocations + callback_invocations = [] + + # Create a callback that simulates some processing time + async def async_callback(event): + await asyncio.sleep(0.01) # Simulate some async work + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + # Create multiple subscriber IDs + subscriber_ids = [str(uuid4()) for _ in range(5)] + active_callbacks = {sid: [] for sid in subscriber_ids} # Track active callback IDs for each subscriber + + try: + # Subscribe and unsubscribe in rapid succession + for i in range(10): # Do 10 rounds of subscribe/unsubscribe + for sid in subscriber_ids: + # Generate a new callback ID for each subscription + callback_id = str(uuid4()) + active_callbacks[sid].append(callback_id) + + # Subscribe with new callback ID + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, callback_id) + + # Add a test event + test_event = MessageAction(content=f'Test message {i}-{sid}') + event_stream.add_event(test_event, EventSource.USER) + + # Give minimal time for event processing + await asyncio.sleep(0.001) + + # Unsubscribe the previous callback if it exists + if len(active_callbacks[sid]) > 1: + old_callback_id = active_callbacks[sid].pop(0) + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, old_callback_id) + + # Add another test event + test_event2 = MessageAction(content=f'Test message {i}-{sid}-2') + event_stream.add_event(test_event2, EventSource.USER) + + # Occasionally try to unsubscribe a non-existent callback + if i % 3 == 0: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, str(uuid4())) + + # Give time for all events to be processed + await asyncio.sleep(1) + + # Verify that we received events + # We should have 2 events per iteration per subscriber + expected_events = 10 * len(subscriber_ids) * 2 + assert len(callback_invocations) >= expected_events * 0.9, f"Expected at least 90% of {expected_events} events, got {len(callback_invocations)}" + + finally: + # Clean up all callbacks + for sid in subscriber_ids: + for callback_id in active_callbacks[sid]: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback_id) + event_stream.close() + + +@pytest.mark.asyncio +async def test_subscriber_unsubscribe_bug(temp_dir): + """Test the specific case where a callback continues to be invoked even after + getting 'Subscriber not found during unsubscribe' warning.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + + # Create a callback that takes some time to complete + async def async_callback(event): + await asyncio.sleep(0.05) # Long enough to ensure we can test during execution + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + callback_id = str(uuid4()) + + # Subscribe the callback + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, callback_id) + + # Add an event to start the callback processing + test_event1 = MessageAction(content='Test message 1') + event_stream.add_event(test_event1, EventSource.USER) + + # Give a tiny bit of time for the event to be queued but not processed + await asyncio.sleep(0.01) + + # Try to unsubscribe while the callback is still processing + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, callback_id) + + # Add another event - this should not trigger the callback if unsubscribe worked + test_event2 = MessageAction(content='Test message 2') + event_stream.add_event(test_event2, EventSource.USER) + + # Wait for all processing to complete + await asyncio.sleep(0.2) + + # Clean up + event_stream.close() + + # Check the invocations + print(f"Callback invocations: {callback_invocations}") + assert len(callback_invocations) == 1, ( + f"Expected only 1 invocation (from first event), but got {len(callback_invocations)}. " + f"This indicates the callback was still being called after unsubscribe." + ) + assert callback_invocations[0][1] == test_event1, ( + "Expected only the first event to be processed" + ) + + +@pytest.mark.asyncio +async def test_subscriber_unsubscribe_concurrent(temp_dir): + """Test concurrent subscribe/unsubscribe operations to try to trigger the bug + where callbacks remain active after unsubscribe.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + unsubscribe_warnings = [] + + # Create a callback that takes some time to complete + async def async_callback(event): + await asyncio.sleep(0.05) # Long enough to create overlap + callback_invocations.append(('async_callback', event)) + + def callback(event): + asyncio.run(async_callback(event)) + + # Create a custom logger to capture warnings + class WarningCaptureHandler(logging.Handler): + def emit(self, record): + if "Subscriber not found during unsubscribe" in record.getMessage(): + unsubscribe_warnings.append(record.getMessage()) + + logger = logging.getLogger('openhands.events.stream') + handler = WarningCaptureHandler() + logger.addHandler(handler) + + try: + # Create multiple subscribers and callbacks + subscriber_ids = [EventStreamSubscriber.RUNTIME, EventStreamSubscriber.AGENT_CONTROLLER] + callback_ids = {sid: [] for sid in subscriber_ids} + + for i in range(5): # 5 rounds of subscribe/unsubscribe + for subscriber_id in subscriber_ids: + # Subscribe with a new callback ID + callback_id = str(uuid4()) + callback_ids[subscriber_id].append(callback_id) + event_stream.subscribe(subscriber_id, callback, callback_id) + + # Add an event + test_event = MessageAction(content=f'Test message {i}-{subscriber_id}') + event_stream.add_event(test_event, EventSource.USER) + + # Small delay to allow some overlap + await asyncio.sleep(0.01) + + # Try to unsubscribe the previous callback if it exists + if len(callback_ids[subscriber_id]) > 1: + old_id = callback_ids[subscriber_id][-2] + event_stream.unsubscribe(subscriber_id, old_id) + + # Add another event + test_event2 = MessageAction(content=f'Test message {i}-{subscriber_id}-2') + event_stream.add_event(test_event2, EventSource.USER) + + # Wait for all processing to complete + await asyncio.sleep(0.5) + + # Check for any cases where we got a warning but the callback was still invoked + print(f"Unsubscribe warnings: {unsubscribe_warnings}") + print(f"Callback invocations: {len(callback_invocations)}") + + # If we got any "Subscriber not found" warnings, we should check that those + # callbacks were not invoked after the warning + if unsubscribe_warnings: + # The number of invocations should be less than the total number of events + total_events = len(subscriber_ids) * 5 * 2 # subscribers * rounds * events per round + assert len(callback_invocations) < total_events, ( + f"Got {len(callback_invocations)} invocations for {total_events} events " + f"even though some callbacks were unsubscribed" + ) + + finally: + # Clean up + for subscriber_id in subscriber_ids: + for callback_id in callback_ids[subscriber_id]: + event_stream.unsubscribe(subscriber_id, callback_id) + event_stream.close() + logger.removeHandler(handler) + + +@pytest.mark.asyncio +async def test_subscriber_race_condition(temp_dir): + """Test specifically for race conditions in subscriber management.""" + file_store = InMemoryFileStore({}) + event_stream = EventStream(sid='test', file_store=file_store) + + callback_invocations = [] + + # Create a callback that takes some time to complete + async def slow_callback(event): + await asyncio.sleep(0.05) # Long enough to create potential race conditions + callback_invocations.append(('slow_callback', event)) + + def callback(event): + asyncio.run(slow_callback(event)) + + # Create tasks to subscribe/unsubscribe concurrently + async def subscribe_unsubscribe(sid: str): + for _ in range(5): + event_stream.subscribe(EventStreamSubscriber.RUNTIME, callback, sid) + test_event = MessageAction(content=f'Test message for {sid}') + event_stream.add_event(test_event, EventSource.USER) + await asyncio.sleep(0.01) # Small delay to increase chance of race condition + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, sid) + await asyncio.sleep(0.01) + + # Run multiple subscribe/unsubscribe operations concurrently + subscriber_ids = [str(uuid4()) for _ in range(3)] + tasks = [subscribe_unsubscribe(sid) for sid in subscriber_ids] + await asyncio.gather(*tasks) + + # Give time for all events to be processed + await asyncio.sleep(1) + + # Clean up + for sid in subscriber_ids: + event_stream.unsubscribe(EventStreamSubscriber.RUNTIME, sid) + event_stream.close() + + @pytest.mark.asyncio async def test_reset_with_pending_action_no_observation(mock_agent, mock_event_stream): """Test reset() when there's a pending action with tool call metadata but no observation."""