Skip to content

Commit db8c968

Browse files
committed
tests compile but times out
1 parent 98204c8 commit db8c968

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

tests/unit_tests/worker/test_task_worker.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import itertools
33
import threading
4-
from collections.abc import Callable, Iterable
4+
from collections.abc import AsyncIterator, Callable, Iterable
55
from concurrent.futures import Future
66
from queue import Full
77
from typing import Any, TypeVar
@@ -362,10 +362,11 @@ def assert_running_count_plan_produces_ordered_worker_and_data_events(
362362
events = []
363363

364364
async def collect_events():
365-
async for event in take_events_from_streams(
365+
events_iterator = take_events_from_streams(
366366
event_streams,
367367
lambda _: next(count) >= len(expected_events) - 1,
368-
):
368+
)
369+
async for event in events_iterator:
369370
events.append(event)
370371
if len(events) >= len(expected_events):
371372
break
@@ -378,7 +379,7 @@ async def collect_events():
378379
except asyncio.TimeoutError:
379380
pytest.fail(f"Test timed out after {timeout} seconds while waiting for events.")
380381

381-
_compare_events(expected_events, task_id, results)
382+
_compare_events(expected_events, task_id, events)
382383

383384

384385
def _compare_events(
@@ -422,46 +423,43 @@ def on_event(event: E, event_id: str | None) -> None:
422423
return future
423424

424425

425-
def take_events_from_streams(
426-
streams: list[EventStream[Any, int]],
426+
async def take_events_from_streams(
427+
streams: list["EventStream[Any, Any]"],
427428
cutoff_predicate: Callable[[Any], bool],
428-
) -> Future[list[Any]]:
429-
"""Returns a collated list of futures for events in numerous event streams.
430-
431-
The support for generic and algebraic types doesn't appear to extend to
432-
taking an arbitrary list of concrete types with single but differing
433-
generic arguments while also maintaining the generality of the argument
434-
types.
435-
436-
The type for streams will be any combination of event streams each of a
437-
given event type, where the event type is generic:
438-
439-
List[
440-
Union[
441-
EventStream[WorkerEvent, int],
442-
EventStream[DataEvent, int],
443-
EventStream[ProgressEvent, int]
444-
]
445-
]
429+
) -> AsyncIterator[Any]:
430+
"""Returns an async generator that yields events from multiple event streams."""
446431

447-
"""
448-
events: list[Any] = []
449-
future: Future[list[Any]] = Future()
432+
event_queue = asyncio.Queue() # Queue to store events from the streams
433+
cutoff_reached = asyncio.Event() # Event to signal when to stop listening
450434

451435
def on_event(event: Any, event_id: str | None) -> None:
452-
print(event)
453-
events.append(event)
436+
"""Callback for events."""
437+
event_queue.put_nowait(event) # Add the event to the async queue
454438
if cutoff_predicate(event):
455-
future.set_result(events)
439+
cutoff_reached.set() # Signal the cutoff event
440+
441+
subscriptions = []
456442

443+
# Subscribe to all the event streams
457444
for stream in streams:
458445
sub = stream.subscribe(on_event)
446+
subscriptions.append((stream, sub))
459447

460-
def callback(unused: Future[list[Any]], stream=stream, sub=sub):
461-
stream.unsubscribe(sub)
448+
async def event_producer():
449+
"""Asynchronously yield events from the queue."""
450+
while not cutoff_reached.is_set():
451+
event = await event_queue.get() # Wait for the next event
452+
yield event
462453

463-
future.add_done_callback(callback)
464-
return future
454+
try:
455+
# Yield events using the event_producer async generator
456+
async for event in event_producer():
457+
yield event
458+
459+
finally:
460+
# Ensure we unsubscribe from all streams once done
461+
for stream, sub in subscriptions:
462+
stream.unsubscribe(sub)
465463

466464

467465
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)