1
1
//! Python coroutine implementation, used notably when wrapping `async fn`
2
2
//! with `#[pyfunction]`/`#[pymethods]`.
3
+ use crate :: coroutine:: waker:: AsyncioWaker ;
3
4
use crate :: exceptions:: { PyRuntimeError , PyStopIteration } ;
4
5
use crate :: pyclass:: IterNextOutput ;
5
- use crate :: sync:: GILOnceCell ;
6
- use crate :: types:: { PyCFunction , PyIterator } ;
7
- use crate :: { intern, wrap_pyfunction, IntoPy , Py , PyAny , PyErr , PyObject , PyResult , Python } ;
8
- use pyo3_macros:: { pyclass, pyfunction, pymethods} ;
6
+ use crate :: types:: PyIterator ;
7
+ use crate :: { IntoPy , Py , PyAny , PyErr , PyObject , PyResult , Python } ;
8
+ use pyo3_macros:: { pyclass, pymethods} ;
9
9
use std:: future:: Future ;
10
10
use std:: pin:: Pin ;
11
11
use std:: sync:: Arc ;
12
12
use std:: task:: { Context , Poll } ;
13
13
14
+ mod cancel;
15
+ mod waker;
16
+
17
+ pub use crate :: coroutine:: cancel:: { CancelHandle , CoroutineCancel } ;
18
+
14
19
const COROUTINE_REUSED_ERROR : & str = "cannot reuse already awaited coroutine" ;
15
20
16
21
/// Python coroutine wrapping a [`Future`].
17
22
#[ pyclass( crate = "crate" ) ]
18
23
pub struct Coroutine {
19
24
future : Option < Pin < Box < dyn Future < Output = PyResult < PyObject > > + Send > > > ,
25
+ cancel : Option < CancelHandle > ,
20
26
waker : Option < Arc < AsyncioWaker > > ,
21
27
}
22
28
@@ -41,14 +47,40 @@ impl Coroutine {
41
47
} ;
42
48
Self {
43
49
future : Some ( Box :: pin ( wrap) ) ,
50
+ cancel : None ,
51
+ waker : None ,
52
+ }
53
+ }
54
+
55
+ /// Wrap a future into a Python coroutine.
56
+ ///
57
+ /// Coroutine `send` polls the wrapped future, ignoring the value passed
58
+ /// (should always be `None` anyway).
59
+ ///
60
+ /// Coroutine `throw` registers the exception in `cancel`, and polls the wrapped future
61
+ pub fn from_future_with_cancel < F , T , E > ( future : F , cancel : CancelHandle ) -> Self
62
+ where
63
+ F : Future < Output = Result < T , E > > + Send + ' static ,
64
+ T : IntoPy < PyObject > + Send ,
65
+ E : Send ,
66
+ PyErr : From < E > ,
67
+ {
68
+ let wrap = async move {
69
+ let obj = future. await ?;
70
+ // SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
71
+ Ok ( obj. into_py ( unsafe { Python :: assume_gil_acquired ( ) } ) )
72
+ } ;
73
+ Self {
74
+ future : Some ( Box :: pin ( wrap) ) ,
75
+ cancel : Some ( cancel) ,
44
76
waker : None ,
45
77
}
46
78
}
47
79
48
80
fn poll (
49
81
& mut self ,
50
82
py : Python < ' _ > ,
51
- throw : Option < & PyAny > ,
83
+ throw : Option < PyObject > ,
52
84
) -> PyResult < IterNextOutput < PyObject , PyObject > > {
53
85
// raise if the coroutine has already been run to completion
54
86
let future_rs = match self . future {
@@ -57,16 +89,20 @@ impl Coroutine {
57
89
} ;
58
90
// reraise thrown exception it
59
91
if let Some ( exc) = throw {
60
- self . close ( ) ;
61
- return Err ( PyErr :: from_value ( exc) ) ;
92
+ if let Some ( ref handle) = self . cancel {
93
+ handle. cancel ( py, exc)
94
+ } else {
95
+ self . close ( ) ;
96
+ return Err ( PyErr :: from_value ( exc. as_ref ( py) ) ) ;
97
+ }
62
98
}
63
99
// create a new waker, or try to reset it in place
64
100
if let Some ( waker) = self . waker . as_mut ( ) . and_then ( Arc :: get_mut) {
65
101
waker. reset ( ) ;
66
102
} else {
67
103
self . waker = Some ( Arc :: new ( AsyncioWaker :: new ( ) ) ) ;
68
104
}
69
- let waker = futures_task :: waker ( self . waker . clone ( ) . unwrap ( ) ) ;
105
+ let waker = futures_util :: task :: waker ( self . waker . clone ( ) . unwrap ( ) ) ;
70
106
// poll the Rust future and forward its results if ready
71
107
if let Poll :: Ready ( res) = future_rs. as_mut ( ) . poll ( & mut Context :: from_waker ( & waker) ) {
72
108
self . close ( ) ;
@@ -101,7 +137,7 @@ impl Coroutine {
101
137
iter_result ( self . poll ( py, None ) ?)
102
138
}
103
139
104
- fn throw ( & mut self , py : Python < ' _ > , exc : & PyAny ) -> PyResult < PyObject > {
140
+ fn throw ( & mut self , py : Python < ' _ > , exc : PyObject ) -> PyResult < PyObject > {
105
141
iter_result ( self . poll ( py, Some ( exc) ) ?)
106
142
}
107
143
@@ -119,93 +155,3 @@ impl Coroutine {
119
155
self . poll ( py, None )
120
156
}
121
157
}
122
-
123
- /// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`.
124
- ///
125
- /// asyncio future is let uninitialized until [`initialize_future`][1] is called.
126
- /// If [`wake`][2] is called before future initialization (during Rust future polling),
127
- /// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`)
128
- ///
129
- /// [1]: AsyncioWaker::initialize_future
130
- /// [2]: AsyncioWaker::wake
131
- struct AsyncioWaker ( GILOnceCell < Option < LoopAndFuture > > ) ;
132
-
133
- impl AsyncioWaker {
134
- fn new ( ) -> Self {
135
- Self ( GILOnceCell :: new ( ) )
136
- }
137
-
138
- fn reset ( & mut self ) {
139
- self . 0 . take ( ) ;
140
- }
141
-
142
- fn initialize_future < ' a > ( & ' a self , py : Python < ' a > ) -> PyResult < Option < & ' a PyAny > > {
143
- let init = || LoopAndFuture :: new ( py) . map ( Some ) ;
144
- let loop_and_future = self . 0 . get_or_try_init ( py, init) ?. as_ref ( ) ;
145
- Ok ( loop_and_future. map ( |LoopAndFuture { future, .. } | future. as_ref ( py) ) )
146
- }
147
- }
148
-
149
- impl futures_task:: ArcWake for AsyncioWaker {
150
- fn wake_by_ref ( arc_self : & Arc < Self > ) {
151
- Python :: with_gil ( |gil| {
152
- if let Some ( loop_and_future) = arc_self. 0 . get_or_init ( gil, || None ) {
153
- loop_and_future
154
- . set_result ( gil)
155
- . expect ( "unexpected error in coroutine waker" ) ;
156
- }
157
- } ) ;
158
- }
159
- }
160
-
161
- struct LoopAndFuture {
162
- event_loop : PyObject ,
163
- future : PyObject ,
164
- }
165
-
166
- impl LoopAndFuture {
167
- fn new ( py : Python < ' _ > ) -> PyResult < Self > {
168
- static GET_RUNNING_LOOP : GILOnceCell < PyObject > = GILOnceCell :: new ( ) ;
169
- let import = || -> PyResult < _ > {
170
- let module = py. import ( "asyncio" ) ?;
171
- Ok ( module. getattr ( "get_running_loop" ) ?. into ( ) )
172
- } ;
173
- let event_loop = GET_RUNNING_LOOP . get_or_try_init ( py, import) ?. call0 ( py) ?;
174
- let future = event_loop. call_method0 ( py, "create_future" ) ?;
175
- Ok ( Self { event_loop, future } )
176
- }
177
-
178
- fn set_result ( & self , py : Python < ' _ > ) -> PyResult < ( ) > {
179
- static RELEASE_WAITER : GILOnceCell < Py < PyCFunction > > = GILOnceCell :: new ( ) ;
180
- let release_waiter = RELEASE_WAITER
181
- . get_or_try_init ( py, || wrap_pyfunction ! ( release_waiter, py) . map ( Into :: into) ) ?;
182
- // `Future.set_result` must be called in event loop thread,
183
- // so it requires `call_soon_threadsafe`
184
- let call_soon_threadsafe = self . event_loop . call_method1 (
185
- py,
186
- intern ! ( py, "call_soon_threadsafe" ) ,
187
- ( release_waiter, self . future . as_ref ( py) ) ,
188
- ) ;
189
- if let Err ( err) = call_soon_threadsafe {
190
- // `call_soon_threadsafe` will raise if the event loop is closed;
191
- // instead of catching an unspecific `RuntimeError`, check directly if it's closed.
192
- let is_closed = self . event_loop . call_method0 ( py, "is_closed" ) ?;
193
- if !is_closed. extract ( py) ? {
194
- return Err ( err) ;
195
- }
196
- }
197
- Ok ( ( ) )
198
- }
199
- }
200
-
201
- /// Call `future.set_result` if the future is not done.
202
- ///
203
- /// Future can be cancelled by the event loop before being waken.
204
- /// See https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5
205
- #[ pyfunction( crate = "crate" ) ]
206
- fn release_waiter ( future : & PyAny ) -> PyResult < ( ) > {
207
- if !future. call_method0 ( "done" ) ?. extract :: < bool > ( ) ? {
208
- future. call_method1 ( "set_result" , ( future. py ( ) . None ( ) , ) ) ?;
209
- }
210
- Ok ( ( ) )
211
- }
0 commit comments