Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix race conditions in EventStream subscriber management #6904

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 81 additions & 45 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,36 +125,41 @@ def close(self):
self._queue.get()

def _clean_up_subscriber(self, subscriber_id: str, callback_id: str):
if subscriber_id not in self._subscribers:
"""Clean up a subscriber's callback, ensuring proper shutdown of resources.

This method must be called with self._lock held to ensure thread safety.
"""
# Remove from subscribers first to prevent new events from being processed
if subscriber_id in self._subscribers:
if callback_id in self._subscribers[subscriber_id]:
del self._subscribers[subscriber_id][callback_id]
else:
logger.warning(f'Callback not found during cleanup: {callback_id}')
return
else:
logger.warning(f'Subscriber not found during cleanup: {subscriber_id}')
return
if callback_id not in self._subscribers[subscriber_id]:
logger.warning(f'Callback not found during cleanup: {callback_id}')
return
if (
subscriber_id in self._thread_loops
and callback_id in self._thread_loops[subscriber_id]
):

# Clean up the event loop
if subscriber_id in self._thread_loops and callback_id in self._thread_loops[subscriber_id]:
loop = self._thread_loops[subscriber_id][callback_id]
try:
loop.stop()
loop.close()
except Exception as e:
logger.warning(
f'Error closing loop for {subscriber_id}/{callback_id}: {e}'
)
logger.warning(f'Error closing loop for {subscriber_id}/{callback_id}: {e}')
del self._thread_loops[subscriber_id][callback_id]

if (
subscriber_id in self._thread_pools
and callback_id in self._thread_pools[subscriber_id]
):
# Clean up the thread pool
if subscriber_id in self._thread_pools and callback_id in self._thread_pools[subscriber_id]:
pool = self._thread_pools[subscriber_id][callback_id]
pool.shutdown()
try:
# Use shutdown(wait=False) to allow existing tasks to complete
pool.shutdown(wait=False)
except Exception as e:
logger.warning(f'Error shutting down pool for {subscriber_id}/{callback_id}: {e}')
del self._thread_pools[subscriber_id][callback_id]

del self._subscribers[subscriber_id][callback_id]

def _get_filename_for_id(self, id: int) -> str:
return get_conversation_event_filename(self.sid, id)

Expand Down Expand Up @@ -236,30 +241,46 @@ def get_latest_event_id(self) -> int:
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}

if callback_id in self._subscribers[subscriber_id]:
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)
"""Subscribe a callback to receive events.

Args:
subscriber_id: The type of subscriber (e.g., RUNTIME, AGENT_CONTROLLER)
callback: The callback function to be invoked for each event
callback_id: A unique ID for this callback instance

Raises:
ValueError: If the callback_id is already in use for this subscriber
"""
with self._lock: # Use the same lock as unsubscribe to ensure thread safety
initializer = partial(self._init_thread_loop, subscriber_id, callback_id)
pool = ThreadPoolExecutor(max_workers=1, initializer=initializer)

if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}
self._thread_loops[subscriber_id] = {}

if callback_id in self._subscribers[subscriber_id]:
pool.shutdown(wait=False) # Clean up the pool we just created
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)

self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool
# Add the callback and its resources atomically
self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool

def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
return
with self._lock: # Use the same lock as add_event to ensure thread safety
if subscriber_id not in self._subscribers:
logger.warning(f'Subscriber not found during unsubscribe: {subscriber_id}')
return

if callback_id not in self._subscribers[subscriber_id]:
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
return
if callback_id not in self._subscribers[subscriber_id]:
logger.warning(f'Callback not found during unsubscribe: {callback_id}')
return

self._clean_up_subscriber(subscriber_id, callback_id)
self._clean_up_subscriber(subscriber_id, callback_id)

def add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
Expand Down Expand Up @@ -310,14 +331,29 @@ async def _process_queue(self):
except queue.Empty:
continue

# pass each event to each callback in order
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
future = pool.submit(callback, event)
future.add_done_callback(self._make_error_handler(callback_id, key))
# Take a snapshot of subscribers to prevent modification during iteration
with self._lock:
subscribers = {
key: {
cid: (cb, self._thread_pools[key][cid])
for cid, cb in callbacks.items()
}
for key, callbacks in self._subscribers.items()
}

# Process the event with the snapshot
for key in sorted(subscribers.keys()):
callbacks = subscribers[key]
for callback_id, (callback, pool) in callbacks.items():
try:
future = pool.submit(callback, event)
future.add_done_callback(self._make_error_handler(callback_id, key))
except RuntimeError as e:
if "cannot schedule new futures after shutdown" in str(e):
# Pool was shut down, skip this callback
logger.debug(f"Skipping callback {callback_id} for {key} - pool is shut down")
continue
raise

def _make_error_handler(self, callback_id: str, subscriber_id: str):
def _handle_callback_error(fut):
Expand Down
Loading
Loading