1
1
import asyncio
2
+ import concurrent .futures
2
3
import contextvars
3
4
import functools
4
5
import sys
5
- from typing import Any
6
+ import types
7
+ from collections .abc import Awaitable , Callable , Coroutine
8
+ from typing import Any , Generic , Protocol , TypeVar , Union
6
9
7
10
import sniffio
8
11
import trio .lowlevel
11
14
from . import _asyncio
12
15
from ._context import restore_context as _restore_context
13
16
17
+ _R = TypeVar ("_R" )
18
+
19
+ Coro = Coroutine [Any , Any , _R ]
20
+
21
+ Loop = Union [asyncio .AbstractEventLoop , trio .lowlevel .TrioToken ]
22
+ TaskContext = list [Any ]
23
+
14
24
15
25
class TrioThreadCancelled (BaseException ):
16
26
pass
17
27
18
28
19
- def get_running_loop ():
29
+ def get_running_loop () -> Loop :
30
+
20
31
try :
21
32
asynclib = sniffio .current_async_library ()
22
33
except sniffio .AsyncLibraryNotFoundError :
@@ -25,16 +36,16 @@ def get_running_loop():
25
36
if asynclib == "asyncio" :
26
37
return asyncio .get_running_loop ()
27
38
if asynclib == "trio" :
28
- return trio .lowlevel .current_token ()
39
+ return trio .lowlevel .current_trio_token ()
29
40
raise RuntimeError (f"unsupported library { asynclib } " )
30
41
31
42
32
43
@trio .lowlevel .disable_ki_protection
33
- async def wrap_awaitable (awaitable ) :
44
+ async def wrap_awaitable (awaitable : Awaitable [ _R ]) -> _R :
34
45
return await awaitable
35
46
36
47
37
- def create_task_threadsafe (loop , awaitable ) :
48
+ def create_task_threadsafe (loop : Loop , awaitable : Coro [ _R ]) -> None :
38
49
if isinstance (loop , trio .lowlevel .TrioToken ):
39
50
try :
40
51
loop .run_sync_soon (
@@ -44,15 +55,40 @@ def create_task_threadsafe(loop, awaitable):
44
55
)
45
56
except trio .RunFinishedError :
46
57
raise RuntimeError ("trio loop no-longer running" )
58
+ return
59
+
60
+ _asyncio .create_task_threadsafe (loop , awaitable )
47
61
48
- return _asyncio .create_task_threadsafe (loop , awaitable )
49
62
63
+ ExcInfo = Union [
64
+ tuple [type [BaseException ], BaseException , types .TracebackType ],
65
+ tuple [None , None , None ],
66
+ ]
50
67
51
- async def run_in_executor (* , loop , executor , thread_handler , child ):
68
+
69
+ class ThreadHandlerType (Protocol , Generic [_R ]):
70
+ def __call__ (
71
+ self ,
72
+ loop : Loop ,
73
+ exc_info : ExcInfo ,
74
+ task_context : TaskContext ,
75
+ func : Callable [[Callable [[], _R ]], _R ],
76
+ child : Callable [[], _R ],
77
+ ) -> _R :
78
+ ...
79
+
80
+
81
+ async def run_in_executor (
82
+ * ,
83
+ loop : Loop ,
84
+ executor : concurrent .futures .ThreadPoolExecutor ,
85
+ thread_handler : ThreadHandlerType [_R ],
86
+ child : Callable [[], _R ],
87
+ ) -> _R :
52
88
if isinstance (loop , trio .lowlevel .TrioToken ):
53
89
context = contextvars .copy_context ()
54
90
func = context .run
55
- task_context : list [ asyncio . Task [ Any ]] = []
91
+ task_context : TaskContext = []
56
92
57
93
# Run the code in the right thread
58
94
full_func = functools .partial (
@@ -66,7 +102,7 @@ async def run_in_executor(*, loop, executor, thread_handler, child):
66
102
try :
67
103
if executor is None :
68
104
69
- async def handle_cancel ():
105
+ async def handle_cancel () -> None :
70
106
try :
71
107
await trio .sleep_forever ()
72
108
except trio .Cancelled :
@@ -84,16 +120,17 @@ async def handle_cancel():
84
120
pass
85
121
finally :
86
122
nursery .cancel_scope .cancel ()
123
+ assert False
87
124
else :
88
125
event = trio .Event ()
89
126
90
- def callback (fut ) :
127
+ def callback (fut : object ) -> None :
91
128
loop .run_sync_soon (event .set )
92
129
93
130
fut = executor .submit (full_func )
94
131
fut .add_done_callback (callback )
95
132
96
- async def handle_cancel_fut ():
133
+ async def handle_cancel_fut () -> None :
97
134
try :
98
135
await trio .sleep_forever ()
99
136
except trio .Cancelled :
@@ -111,15 +148,19 @@ async def handle_cancel_fut():
111
148
return fut .result ()
112
149
except TrioThreadCancelled :
113
150
pass
151
+ assert False
114
152
finally :
115
153
_restore_context (context )
116
154
117
- return await _asyncio .run_in_executor (
118
- loop = loop , executor = executor , thread_handler = thread_handler , func = func
119
- )
155
+ else :
156
+ return await _asyncio .run_in_executor (
157
+ loop = loop , executor = executor , thread_handler = thread_handler , child = child
158
+ )
120
159
121
160
122
- async def wrap_task_context (loop , task_context , awaitable ):
161
+ async def wrap_task_context (
162
+ loop : Loop , task_context : Union [TaskContext , None ], awaitable : Awaitable [_R ]
163
+ ) -> _R :
123
164
if task_context is None :
124
165
return await awaitable
125
166
@@ -130,7 +171,6 @@ async def wrap_task_context(loop, task_context, awaitable):
130
171
return await awaitable
131
172
finally :
132
173
task_context .remove (scope )
133
- if scope .cancelled_caught :
134
- raise TrioThreadCancelled
174
+ raise TrioThreadCancelled
135
175
136
176
return await _asyncio .wrap_task_context (loop , task_context , awaitable )
0 commit comments