Skip to content

Commit

Permalink
adjust
Browse files Browse the repository at this point in the history
  • Loading branch information
dmulcahey committed Apr 3, 2024
1 parent f55df34 commit 1a56365
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 104 deletions.
100 changes: 22 additions & 78 deletions tests/test_async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio
import functools
import time
from typing import Any
from unittest.mock import MagicMock, Mock, patch

import pytest
Expand Down Expand Up @@ -564,56 +563,6 @@ async def test_task():
assert result.result() == "Foo"


async def test_shutdown_does_not_block_on_normal_tasks(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown does not block on normal tasks."""
result = asyncio.Future()
unshielded_task = asyncio.sleep(10)

async def test_task():
try:
await unshielded_task
except asyncio.CancelledError:
result.set_result("Foo")

start = time.monotonic()
task = zha_gateway.async_create_task(test_task())
await asyncio.sleep(0)
await zha_gateway.shutdown()
await asyncio.sleep(0)
assert result.done()
assert task.done()
assert time.monotonic() - start < 0.5


async def test_shutdown_does_not_block_on_shielded_tasks(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown does not block on shielded tasks."""
result = asyncio.Future()
sleep_task = asyncio.ensure_future(asyncio.sleep(10))
shielded_task = asyncio.shield(sleep_task)

async def test_task():
try:
await shielded_task
except asyncio.CancelledError:
result.set_result("Foo")

start = time.monotonic()
task = zha_gateway.async_create_task(test_task())
await asyncio.sleep(0)
await zha_gateway.shutdown()
await asyncio.sleep(0)
assert result.done()
assert task.done()
assert time.monotonic() - start < 0.5

# Cleanup lingering task after test is done
sleep_task.cancel()


@pytest.mark.parametrize("eager_start", [True, False])
async def test_cancellable_ZHAJob(zha_gateway: Gateway, eager_start: bool) -> None:
"""Simulate a shutdown, ensure cancellable jobs are cancelled."""
Expand Down Expand Up @@ -773,6 +722,28 @@ def callback_fn():
assert it_ran is True


async def test_run_callback_threadsafe_exception(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe runs code in the event loop."""
it_ran = False

def callback_fn():
nonlocal it_ran
it_ran = True
raise ValueError("Test")

future = zha_async.run_callback_threadsafe(zha_gateway.loop, callback_fn)
assert future
assert it_ran is False

# Verify that async_block_till_done will flush
# out the callback
await zha_gateway.async_block_till_done()
assert it_ran is True

with pytest.raises(ValueError):
future.result()


async def test_callback_is_always_scheduled(zha_gateway: Gateway) -> None:
"""Test run_callback_threadsafe always calls call_soon_threadsafe before checking for shutdown."""
# We have to check the shutdown state AFTER the callback is scheduled otherwise
Expand Down Expand Up @@ -815,30 +786,3 @@ async def _eager_task():
assert events == ["eager", "normal"]
await task1
await task2


async def test_shutdown_calls_block_till_done_after_shutdown_run_callback_threadsafe(
zha_gateway: Gateway,
) -> None:
"""Ensure shutdown_run_callback_threadsafe is called before the final async_block_till_done."""
stop_calls: list[Any] = []

async def _record_block_till_done(wait_background_tasks: bool = False): # pylint: disable=unused-argument
nonlocal stop_calls
stop_calls.append("async_block_till_done")

def _record_shutdown_run_callback_threadsafe(loop):
nonlocal stop_calls
stop_calls.append(("shutdown_run_callback_threadsafe", loop))

with (
patch.object(zha_gateway, "async_block_till_done", _record_block_till_done),
patch(
"zha.async_.shutdown_run_callback_threadsafe",
_record_shutdown_run_callback_threadsafe,
),
):
await zha_gateway.shutdown()

assert stop_calls[-2] == ("shutdown_run_callback_threadsafe", zha_gateway.loop)
assert stop_calls[-1] == "async_block_till_done"
35 changes: 9 additions & 26 deletions zha/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@
from __future__ import annotations

import asyncio
from asyncio import (
AbstractEventLoop,
Future,
Semaphore,
Task,
gather,
get_running_loop,
timeout as async_timeout,
)
from asyncio import AbstractEventLoop, Future, Semaphore, Task, gather, get_running_loop
from collections.abc import Awaitable, Callable, Collection, Coroutine, Iterable
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -386,6 +378,13 @@ def __init__(self, *args, **kw_args) -> None:
async def shutdown(self) -> None:
"""Shutdown the executor."""

# Prevent run_callback_threadsafe from scheduling any additional
# callbacks in the event loop as callbacks created on the futures
# it returns will never run after the final `self.async_block_till_done`
# which will cause the futures to block forever when waiting for
# the `result()` which will cause a deadlock when shutting down the executor.
shutdown_run_callback_threadsafe(self.loop)

async def _cancel_tasks(tasks_to_cancel: Iterable) -> None:
tasks = [t for t in tasks_to_cancel if not (t.done() or t.cancelled())]
for task in tasks:
Expand All @@ -398,25 +397,9 @@ async def _cancel_tasks(tasks_to_cancel: Iterable) -> None:
await _cancel_tasks(self._tracked_completable_tasks)
await _cancel_tasks(self._device_init_tasks.values())
await _cancel_tasks(self._untracked_background_tasks)

# Prevent run_callback_threadsafe from scheduling any additional
# callbacks in the event loop as callbacks created on the futures
# it returns will never run after the final `self.async_block_till_done`
# which will cause the futures to block forever when waiting for
# the `result()` which will cause a deadlock when shutting down the executor.
shutdown_run_callback_threadsafe(self.loop)

try:
async with async_timeout(30):
await self.async_block_till_done()
except TimeoutError:
_LOGGER.warning(
"Timed out waiting for tasks to be processed, the shutdown will continue"
)
for task in self._tasks:
_LOGGER.warning("Shutdown: task still running: %s", task)
self._cancel_cancellable_timers()
self.import_executor.shutdown()
self.import_executor = None

async def async_block_till_done(self, wait_background_tasks: bool = False) -> None:
"""Block until all pending work is done."""
Expand Down

0 comments on commit 1a56365

Please sign in to comment.