From db8c9682f33afa1bef79116ecfde2a631444c7ed Mon Sep 17 00:00:00 2001 From: Stanislaw Malinowski Date: Mon, 14 Oct 2024 13:53:39 +0000 Subject: [PATCH] tests compile but times out --- tests/unit_tests/worker/test_task_worker.py | 66 ++++++++++----------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 386346af1..58066b52f 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -1,7 +1,7 @@ import asyncio import itertools import threading -from collections.abc import Callable, Iterable +from collections.abc import AsyncIterator, Callable, Iterable from concurrent.futures import Future from queue import Full from typing import Any, TypeVar @@ -362,10 +362,11 @@ def assert_running_count_plan_produces_ordered_worker_and_data_events( events = [] async def collect_events(): - async for event in take_events_from_streams( + events_iterator = take_events_from_streams( event_streams, lambda _: next(count) >= len(expected_events) - 1, - ): + ) + async for event in events_iterator: events.append(event) if len(events) >= len(expected_events): break @@ -378,7 +379,7 @@ async def collect_events(): except asyncio.TimeoutError: pytest.fail(f"Test timed out after {timeout} seconds while waiting for events.") - _compare_events(expected_events, task_id, results) + _compare_events(expected_events, task_id, events) def _compare_events( @@ -422,46 +423,43 @@ def on_event(event: E, event_id: str | None) -> None: return future -def take_events_from_streams( - streams: list[EventStream[Any, int]], +async def take_events_from_streams( + streams: list["EventStream[Any, Any]"], cutoff_predicate: Callable[[Any], bool], -) -> Future[list[Any]]: - """Returns a collated list of futures for events in numerous event streams. - - The support for generic and algebraic types doesn't appear to extend to - taking an arbitrary list of concrete types with single but differing - generic arguments while also maintaining the generality of the argument - types. - - The type for streams will be any combination of event streams each of a - given event type, where the event type is generic: - - List[ - Union[ - EventStream[WorkerEvent, int], - EventStream[DataEvent, int], - EventStream[ProgressEvent, int] - ] - ] +) -> AsyncIterator[Any]: + """Returns an async generator that yields events from multiple event streams.""" - """ - events: list[Any] = [] - future: Future[list[Any]] = Future() + event_queue = asyncio.Queue() # Queue to store events from the streams + cutoff_reached = asyncio.Event() # Event to signal when to stop listening def on_event(event: Any, event_id: str | None) -> None: - print(event) - events.append(event) + """Callback for events.""" + event_queue.put_nowait(event) # Add the event to the async queue if cutoff_predicate(event): - future.set_result(events) + cutoff_reached.set() # Signal the cutoff event + + subscriptions = [] + # Subscribe to all the event streams for stream in streams: sub = stream.subscribe(on_event) + subscriptions.append((stream, sub)) - def callback(unused: Future[list[Any]], stream=stream, sub=sub): - stream.unsubscribe(sub) + async def event_producer(): + """Asynchronously yield events from the queue.""" + while not cutoff_reached.is_set(): + event = await event_queue.get() # Wait for the next event + yield event - future.add_done_callback(callback) - return future + try: + # Yield events using the event_producer async generator + async for event in event_producer(): + yield event + + finally: + # Ensure we unsubscribe from all streams once done + for stream, sub in subscriptions: + stream.unsubscribe(sub) @pytest.mark.parametrize(