Skip to content

Commit 152d990

Browse files
committed
mypy
1 parent 89d2064 commit 152d990

File tree

4 files changed

+97
-23
lines changed

4 files changed

+97
-23
lines changed

asgiref/_asyncio.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,33 @@
66
]
77

88
import asyncio
9+
import concurrent.futures
910
import contextvars
1011
import functools
1112
import sys
13+
import types
1214
from asyncio import get_running_loop
13-
from collections.abc import Callable
14-
from typing import Any, TypeVar
15+
from collections.abc import Awaitable, Callable, Coroutine
16+
from typing import Any, Generic, Protocol, TypeVar, Union
1517

1618
from ._context import restore_context as _restore_context
1719

1820
_R = TypeVar("_R")
1921

22+
Coro = Coroutine[Any, Any, _R]
2023

21-
def create_task_threadsafe(loop, awaitable) -> None:
24+
25+
def create_task_threadsafe(
26+
loop: asyncio.AbstractEventLoop, awaitable: Coro[object]
27+
) -> None:
2228
loop.call_soon_threadsafe(loop.create_task, awaitable)
2329

2430

25-
async def wrap_task_context(loop, task_context, awaitable):
31+
async def wrap_task_context(
32+
loop: asyncio.AbstractEventLoop,
33+
task_context: list[asyncio.Task[Any]],
34+
awaitable: Awaitable[_R],
35+
) -> _R:
2636
if task_context is None:
2737
return await awaitable
2838

@@ -37,8 +47,30 @@ async def wrap_task_context(loop, task_context, awaitable):
3747
task_context.remove(current_task)
3848

3949

50+
ExcInfo = Union[
51+
tuple[type[BaseException], BaseException, types.TracebackType],
52+
tuple[None, None, None],
53+
]
54+
55+
56+
class ThreadHandlerType(Protocol, Generic[_R]):
57+
def __call__(
58+
self,
59+
loop: asyncio.AbstractEventLoop,
60+
exc_info: ExcInfo,
61+
task_context: list[asyncio.Task[Any]],
62+
func: Callable[[Callable[[], _R]], _R],
63+
child: Callable[[], _R],
64+
) -> _R:
65+
...
66+
67+
4068
async def run_in_executor(
41-
*, loop, executor, thread_handler, child: Callable[[], _R]
69+
*,
70+
loop: asyncio.AbstractEventLoop,
71+
executor: concurrent.futures.ThreadPoolExecutor,
72+
thread_handler: ThreadHandlerType[_R],
73+
child: Callable[[], _R],
4274
) -> _R:
4375
context = contextvars.copy_context()
4476
func = context.run

asgiref/_trio.py

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import asyncio
2+
import concurrent.futures
23
import contextvars
34
import functools
45
import sys
5-
from typing import Any
6+
import types
7+
from collections.abc import Awaitable, Callable, Coroutine
8+
from typing import Any, Generic, Protocol, TypeVar, Union
69

710
import sniffio
811
import trio.lowlevel
@@ -11,12 +14,20 @@
1114
from . import _asyncio
1215
from ._context import restore_context as _restore_context
1316

17+
_R = TypeVar("_R")
18+
19+
Coro = Coroutine[Any, Any, _R]
20+
21+
Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken]
22+
TaskContext = list[Any]
23+
1424

1525
class TrioThreadCancelled(BaseException):
1626
pass
1727

1828

19-
def get_running_loop():
29+
def get_running_loop() -> Loop:
30+
2031
try:
2132
asynclib = sniffio.current_async_library()
2233
except sniffio.AsyncLibraryNotFoundError:
@@ -25,16 +36,16 @@ def get_running_loop():
2536
if asynclib == "asyncio":
2637
return asyncio.get_running_loop()
2738
if asynclib == "trio":
28-
return trio.lowlevel.current_token()
39+
return trio.lowlevel.current_trio_token()
2940
raise RuntimeError(f"unsupported library {asynclib}")
3041

3142

3243
@trio.lowlevel.disable_ki_protection
33-
async def wrap_awaitable(awaitable):
44+
async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R:
3445
return await awaitable
3546

3647

37-
def create_task_threadsafe(loop, awaitable):
48+
def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None:
3849
if isinstance(loop, trio.lowlevel.TrioToken):
3950
try:
4051
loop.run_sync_soon(
@@ -44,15 +55,40 @@ def create_task_threadsafe(loop, awaitable):
4455
)
4556
except trio.RunFinishedError:
4657
raise RuntimeError("trio loop no-longer running")
58+
return
59+
60+
_asyncio.create_task_threadsafe(loop, awaitable)
4761

48-
return _asyncio.create_task_threadsafe(loop, awaitable)
4962

63+
ExcInfo = Union[
64+
tuple[type[BaseException], BaseException, types.TracebackType],
65+
tuple[None, None, None],
66+
]
5067

51-
async def run_in_executor(*, loop, executor, thread_handler, child):
68+
69+
class ThreadHandlerType(Protocol, Generic[_R]):
70+
def __call__(
71+
self,
72+
loop: Loop,
73+
exc_info: ExcInfo,
74+
task_context: TaskContext,
75+
func: Callable[[Callable[[], _R]], _R],
76+
child: Callable[[], _R],
77+
) -> _R:
78+
...
79+
80+
81+
async def run_in_executor(
82+
*,
83+
loop: Loop,
84+
executor: concurrent.futures.ThreadPoolExecutor,
85+
thread_handler: ThreadHandlerType[_R],
86+
child: Callable[[], _R],
87+
) -> _R:
5288
if isinstance(loop, trio.lowlevel.TrioToken):
5389
context = contextvars.copy_context()
5490
func = context.run
55-
task_context: list[asyncio.Task[Any]] = []
91+
task_context: TaskContext = []
5692

5793
# Run the code in the right thread
5894
full_func = functools.partial(
@@ -66,7 +102,7 @@ async def run_in_executor(*, loop, executor, thread_handler, child):
66102
try:
67103
if executor is None:
68104

69-
async def handle_cancel():
105+
async def handle_cancel() -> None:
70106
try:
71107
await trio.sleep_forever()
72108
except trio.Cancelled:
@@ -84,16 +120,17 @@ async def handle_cancel():
84120
pass
85121
finally:
86122
nursery.cancel_scope.cancel()
123+
assert False
87124
else:
88125
event = trio.Event()
89126

90-
def callback(fut):
127+
def callback(fut: object) -> None:
91128
loop.run_sync_soon(event.set)
92129

93130
fut = executor.submit(full_func)
94131
fut.add_done_callback(callback)
95132

96-
async def handle_cancel_fut():
133+
async def handle_cancel_fut() -> None:
97134
try:
98135
await trio.sleep_forever()
99136
except trio.Cancelled:
@@ -111,15 +148,19 @@ async def handle_cancel_fut():
111148
return fut.result()
112149
except TrioThreadCancelled:
113150
pass
151+
assert False
114152
finally:
115153
_restore_context(context)
116154

117-
return await _asyncio.run_in_executor(
118-
loop=loop, executor=executor, thread_handler=thread_handler, func=func
119-
)
155+
else:
156+
return await _asyncio.run_in_executor(
157+
loop=loop, executor=executor, thread_handler=thread_handler, child=child
158+
)
120159

121160

122-
async def wrap_task_context(loop, task_context, awaitable):
161+
async def wrap_task_context(
162+
loop: Loop, task_context: Union[TaskContext, None], awaitable: Awaitable[_R]
163+
) -> _R:
123164
if task_context is None:
124165
return await awaitable
125166

@@ -130,7 +171,6 @@ async def wrap_task_context(loop, task_context, awaitable):
130171
return await awaitable
131172
finally:
132173
task_context.remove(scope)
133-
if scope.cancelled_caught:
134-
raise TrioThreadCancelled
174+
raise TrioThreadCancelled
135175

136176
return await _asyncio.wrap_task_context(loop, task_context, awaitable)

asgiref/sync.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def __init__(
165165
],
166166
force_new_loop: Union[LoopType, bool] = False,
167167
):
168-
if force_new_loop and not isinstance(LoopType):
168+
if force_new_loop and not isinstance(force_new_loop, LoopType):
169169
force_new_loop = LoopType.ASYNCIO
170170

171171
if not callable(awaitable) or (
@@ -319,6 +319,7 @@ async def main_wrap(
319319
if context is not None:
320320
_restore_context(context[0])
321321

322+
result: _R
322323
try:
323324
# If we have an exception, run the function inside the except block
324325
# after raising it so exc_info is correctly populated.

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ commands =
1111
mypy: mypy . {posargs}
1212
deps =
1313
setuptools
14+
mypy: trio
1415

1516
[testenv:qa]
1617
skip_install = true

0 commit comments

Comments
 (0)