Skip to content

Commit

Permalink
refactor: replace ThreadPoolExecutor with simple Thread in EventStream
Browse files Browse the repository at this point in the history
  • Loading branch information
openhands-agent committed Dec 30, 2024
1 parent e79cf23 commit c9ecef9
Showing 1 changed file with 3 additions and 10 deletions.
13 changes: 3 additions & 10 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import field
from datetime import datetime
from enum import Enum
Expand Down Expand Up @@ -67,7 +66,6 @@ def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
self.sid = sid
self.file_store = file_store
self._queue: Queue[Event] = Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
self._queue_thread.start()
Expand All @@ -89,9 +87,6 @@ def __post_init__(self) -> None:
if id >= self._cur_id:
self._cur_id = id + 1

def _init_thread_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

def _get_filename_for_id(self, id: int) -> str:
return get_conversation_event_filename(self.sid, id)
Expand Down Expand Up @@ -174,18 +169,15 @@ def get_latest_event_id(self) -> int:
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}

if callback_id in self._subscribers[subscriber_id]:
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)

self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool

def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
if subscriber_id not in self._subscribers:
Expand Down Expand Up @@ -226,8 +218,9 @@ async def _process_queue(self):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
pool.submit(callback, event)
thread = threading.Thread(target=callback, args=(event,))
thread.daemon = True
thread.start()

def _callback(self, callback: Callable, event: Event):
asyncio.run(callback(event))
Expand Down

0 comments on commit c9ecef9

Please sign in to comment.