1
1
import asyncio
2
2
import itertools
3
3
import threading
4
- from collections .abc import Callable , Iterable
4
+ from collections .abc import AsyncIterator , Callable , Iterable
5
5
from concurrent .futures import Future
6
6
from queue import Full
7
7
from typing import Any , TypeVar
@@ -362,10 +362,11 @@ def assert_running_count_plan_produces_ordered_worker_and_data_events(
362
362
events = []
363
363
364
364
async def collect_events ():
365
- async for event in take_events_from_streams (
365
+ events_iterator = take_events_from_streams (
366
366
event_streams ,
367
367
lambda _ : next (count ) >= len (expected_events ) - 1 ,
368
- ):
368
+ )
369
+ async for event in events_iterator :
369
370
events .append (event )
370
371
if len (events ) >= len (expected_events ):
371
372
break
@@ -378,7 +379,7 @@ async def collect_events():
378
379
except asyncio .TimeoutError :
379
380
pytest .fail (f"Test timed out after { timeout } seconds while waiting for events." )
380
381
381
- _compare_events (expected_events , task_id , results )
382
+ _compare_events (expected_events , task_id , events )
382
383
383
384
384
385
def _compare_events (
@@ -422,46 +423,43 @@ def on_event(event: E, event_id: str | None) -> None:
422
423
return future
423
424
424
425
425
- def take_events_from_streams (
426
- streams : list [EventStream [Any , int ] ],
426
+ async def take_events_from_streams (
427
+ streams : list [" EventStream[Any, Any]" ],
427
428
cutoff_predicate : Callable [[Any ], bool ],
428
- ) -> Future [list [Any ]]:
429
- """Returns a collated list of futures for events in numerous event streams.
430
-
431
- The support for generic and algebraic types doesn't appear to extend to
432
- taking an arbitrary list of concrete types with single but differing
433
- generic arguments while also maintaining the generality of the argument
434
- types.
435
-
436
- The type for streams will be any combination of event streams each of a
437
- given event type, where the event type is generic:
438
-
439
- List[
440
- Union[
441
- EventStream[WorkerEvent, int],
442
- EventStream[DataEvent, int],
443
- EventStream[ProgressEvent, int]
444
- ]
445
- ]
429
+ ) -> AsyncIterator [Any ]:
430
+ """Returns an async generator that yields events from multiple event streams."""
446
431
447
- """
448
- events : list [Any ] = []
449
- future : Future [list [Any ]] = Future ()
432
+ event_queue = asyncio .Queue () # Queue to store events from the streams
433
+ cutoff_reached = asyncio .Event () # Event to signal when to stop listening
450
434
451
435
def on_event (event : Any , event_id : str | None ) -> None :
452
- print ( event )
453
- events . append (event )
436
+ """Callback for events."""
437
+ event_queue . put_nowait (event ) # Add the event to the async queue
454
438
if cutoff_predicate (event ):
455
- future .set_result (events )
439
+ cutoff_reached .set () # Signal the cutoff event
440
+
441
+ subscriptions = []
456
442
443
+ # Subscribe to all the event streams
457
444
for stream in streams :
458
445
sub = stream .subscribe (on_event )
446
+ subscriptions .append ((stream , sub ))
459
447
460
- def callback (unused : Future [list [Any ]], stream = stream , sub = sub ):
461
- stream .unsubscribe (sub )
448
+ async def event_producer ():
449
+ """Asynchronously yield events from the queue."""
450
+ while not cutoff_reached .is_set ():
451
+ event = await event_queue .get () # Wait for the next event
452
+ yield event
462
453
463
- future .add_done_callback (callback )
464
- return future
454
+ try :
455
+ # Yield events using the event_producer async generator
456
+ async for event in event_producer ():
457
+ yield event
458
+
459
+ finally :
460
+ # Ensure we unsubscribe from all streams once done
461
+ for stream , sub in subscriptions :
462
+ stream .unsubscribe (sub )
465
463
466
464
467
465
@pytest .mark .parametrize (
0 commit comments