Skip to content

Commit

Permalink
implemented wait manually
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-sanche committed Sep 9, 2024
1 parent 11ab1d0 commit 13abfd4
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 3 deletions.
27 changes: 25 additions & 2 deletions google/cloud/bigtable/data/_sync/cross_sync/cross_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ class CrossSync(metaclass=MappingMeta):

# provide aliases for common async functions and types
sleep = asyncio.sleep
wait = asyncio.wait
retry_target = retries.retry_target_async
retry_target_stream = retries.retry_target_stream_async
Retry = retries.AsyncRetry
Expand Down Expand Up @@ -147,6 +146,20 @@ async def gather_partials(
*awaitable_list, return_exceptions=return_exceptions
)

@staticmethod
async def wait(
futures: Sequence[CrossSync.Future[T]], timeout: float | None = None
) -> tuple[set[CrossSync.Future[T]], set[CrossSync.Future[T]]]:
"""
abstraction over asyncio.wait
Return:
- a tuple of (done, pending) sets of futures
"""
if not futures:
return set(), set()
return await asyncio.wait(futures, timeout=timeout)

@staticmethod
async def event_wait(
event: CrossSync.Event,
Expand Down Expand Up @@ -224,7 +237,6 @@ class _Sync_Impl(metaclass=MappingMeta):
is_async = False

sleep = time.sleep
wait = concurrent.futures.wait
next = next
retry_target = retries.retry_target
retry_target_stream = retries.retry_target_stream
Expand Down Expand Up @@ -279,6 +291,17 @@ def gather_partials(
results_list.append(future.result())
return results_list

@staticmethod
def wait(
futures: Sequence[CrossSync._Sync_Impl.Future[T]],
timeout: float | None = None,
) -> tuple[
set[CrossSync._Sync_Impl.Future[T]], set[CrossSync._Sync_Impl.Future[T]]
]:
if not futures:
return set(), set()
return concurrent.futures.wait(futures, timeout=timeout)

@staticmethod
def create_task(
fn: Callable[..., T],
Expand Down
93 changes: 92 additions & 1 deletion tests/unit/data/_sync/test_cross_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def cs_async(self):
[
("is_async", True, False),
("sleep", asyncio.sleep, time.sleep),
("wait", asyncio.wait, concurrent.futures.wait),
(
"retry_target",
api_core.retry.retry_target_async,
Expand Down Expand Up @@ -237,6 +236,98 @@ async def coro(i):
for coro in found_args:
await coro

def test_wait(self, cs_sync):
"""
Test sync version of CrossSync.wait()
If future is complete, it should be in the first (complete) set
"""
future = concurrent.futures.Future()
future.set_result(1)
s1, s2 = cs_sync.wait([future])
assert s1 == {future}
assert s2 == set()

def test_wait_timeout(self, cs_sync):
"""
If timeout occurs, future should be in the second (incomplete) set
"""
future = concurrent.futures.Future()
timeout = 0.1
start_time = time.monotonic()
s1, s2 = cs_sync.wait([future], timeout)
end_time = time.monotonic()
assert abs((end_time - start_time) - timeout) < 0.01
assert s1 == set()
assert s2 == {future}

def test_wait_passthrough(self, cs_sync):
"""
sync version of CrossSync.wait() should pass through to concurrent.futures.wait()
"""
future = object()
timeout = object()
with mock.patch.object(concurrent.futures, "wait", mock.Mock()) as wait:
result = cs_sync.wait([future], timeout)
assert wait.call_count == 1
assert wait.call_args == (([future],), {"timeout": timeout})
assert result == wait.return_value

def test_wait_empty_input(self, cs_sync):
"""
If no futures are provided, return empty sets
"""
s1, s2 = cs_sync.wait([])
assert s1 == set()
assert s2 == set()

@pytest.mark.asyncio
async def test_wait_async(self, cs_async):
"""
Test async version of CrossSync.wait()
"""
future = asyncio.Future()
future.set_result(1)
s1, s2 = await cs_async.wait([future])
assert s1 == {future}
assert s2 == set()

@pytest.mark.asyncio
async def test_wait_async_timeout(self, cs_async):
"""
If timeout occurs, future should be in the second (incomplete) set
"""
future = asyncio.Future()
timeout = 0.1
start_time = time.monotonic()
s1, s2 = await cs_async.wait([future], timeout)
end_time = time.monotonic()
assert abs((end_time - start_time) - timeout) < 0.01
assert s1 == set()
assert s2 == {future}

@pytest.mark.asyncio
async def test_wait_async_passthrough(self, cs_async):
"""
async version of CrossSync.wait() should pass through to asyncio.wait()
"""
future = object()
timeout = object()
with mock.patch.object(asyncio, "wait", AsyncMock()) as wait:
result = await cs_async.wait([future], timeout)
assert wait.call_count == 1
assert wait.call_args == (([future],), {"timeout": timeout})
assert result == wait.return_value

@pytest.mark.asyncio
async def test_wait_async_empty_input(self, cs_async):
"""
If no futures are provided, return empty sets
"""
s1, s2 = await cs_async.wait([])
assert s1 == set()
assert s2 == set()

def test_event_wait_passthrough(self, cs_sync):
"""
Test sync version of CrossSync.event_wait()
Expand Down

0 comments on commit 13abfd4

Please sign in to comment.