diff --git a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py index 24776aabd..dceff5a62 100644 --- a/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py +++ b/google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py @@ -80,7 +80,6 @@ class CrossSync(metaclass=MappingMeta): # provide aliases for common async functions and types sleep = asyncio.sleep - wait = asyncio.wait retry_target = retries.retry_target_async retry_target_stream = retries.retry_target_stream_async Retry = retries.AsyncRetry @@ -147,6 +146,20 @@ async def gather_partials( *awaitable_list, return_exceptions=return_exceptions ) + @staticmethod + async def wait( + futures: Sequence[CrossSync.Future[T]], timeout: float | None = None + ) -> tuple[set[CrossSync.Future[T]], set[CrossSync.Future[T]]]: + """ + abstraction over asyncio.wait + + Return: + - a tuple of (done, pending) sets of futures + """ + if not futures: + return set(), set() + return await asyncio.wait(futures, timeout=timeout) + @staticmethod async def event_wait( event: CrossSync.Event, @@ -224,7 +237,6 @@ class _Sync_Impl(metaclass=MappingMeta): is_async = False sleep = time.sleep - wait = concurrent.futures.wait next = next retry_target = retries.retry_target retry_target_stream = retries.retry_target_stream @@ -279,6 +291,17 @@ def gather_partials( results_list.append(future.result()) return results_list + @staticmethod + def wait( + futures: Sequence[CrossSync._Sync_Impl.Future[T]], + timeout: float | None = None, + ) -> tuple[ + set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]] + ]: + if not futures: + return set(), set() + return concurrent.futures.wait(futures, timeout=timeout) + @staticmethod def create_task( fn: Callable[..., T], diff --git a/tests/unit/data/_sync/test_cross_sync.py b/tests/unit/data/_sync/test_cross_sync.py index 9c3022f6c..903207694 100644 --- a/tests/unit/data/_sync/test_cross_sync.py +++ b/tests/unit/data/_sync/test_cross_sync.py @@ -51,7 +51,6 @@ def cs_async(self): [ ("is_async", True, False), ("sleep", asyncio.sleep, time.sleep), - ("wait", asyncio.wait, concurrent.futures.wait), ( "retry_target", api_core.retry.retry_target_async, @@ -237,6 +236,98 @@ async def coro(i): for coro in found_args: await coro + def test_wait(self, cs_sync): + """ + Test sync version of CrossSync.wait() + + If future is complete, it should be in the first (complete) set + """ + future = concurrent.futures.Future() + future.set_result(1) + s1, s2 = cs_sync.wait([future]) + assert s1 == {future} + assert s2 == set() + + def test_wait_timeout(self, cs_sync): + """ + If timeout occurs, future should be in the second (incomplete) set + """ + future = concurrent.futures.Future() + timeout = 0.1 + start_time = time.monotonic() + s1, s2 = cs_sync.wait([future], timeout) + end_time = time.monotonic() + assert abs((end_time - start_time) - timeout) < 0.01 + assert s1 == set() + assert s2 == {future} + + def test_wait_passthrough(self, cs_sync): + """ + sync version of CrossSync.wait() should pass through to concurrent.futures.wait() + """ + future = object() + timeout = object() + with mock.patch.object(concurrent.futures, "wait", mock.Mock()) as wait: + result = cs_sync.wait([future], timeout) + assert wait.call_count == 1 + assert wait.call_args == (([future],), {"timeout": timeout}) + assert result == wait.return_value + + def test_wait_empty_input(self, cs_sync): + """ + If no futures are provided, return empty sets + """ + s1, s2 = cs_sync.wait([]) + assert s1 == set() + assert s2 == set() + + @pytest.mark.asyncio + async def test_wait_async(self, cs_async): + """ + Test async version of CrossSync.wait() + """ + future = asyncio.Future() + future.set_result(1) + s1, s2 = await cs_async.wait([future]) + assert s1 == {future} + assert s2 == set() + + @pytest.mark.asyncio + async def test_wait_async_timeout(self, cs_async): + """ + If timeout occurs, future should be in the second (incomplete) set + """ + future = asyncio.Future() + timeout = 0.1 + start_time = time.monotonic() + s1, s2 = await cs_async.wait([future], timeout) + end_time = time.monotonic() + assert abs((end_time - start_time) - timeout) < 0.01 + assert s1 == set() + assert s2 == {future} + + @pytest.mark.asyncio + async def test_wait_async_passthrough(self, cs_async): + """ + async version of CrossSync.wait() should pass through to asyncio.wait() + """ + future = object() + timeout = object() + with mock.patch.object(asyncio, "wait", AsyncMock()) as wait: + result = await cs_async.wait([future], timeout) + assert wait.call_count == 1 + assert wait.call_args == (([future],), {"timeout": timeout}) + assert result == wait.return_value + + @pytest.mark.asyncio + async def test_wait_async_empty_input(self, cs_async): + """ + If no futures are provided, return empty sets + """ + s1, s2 = await cs_async.wait([]) + assert s1 == set() + assert s2 == set() + def test_event_wait_passthrough(self, cs_sync): """ Test sync version of CrossSync.event_wait()