Skip to content

Commit 437aef4

Browse files
author
Joseph Perez
committed
feat: handle coroutine cancellation
1 parent 8695c4f commit 437aef4

File tree

10 files changed

+300
-108
lines changed

10 files changed

+300
-108
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ unindent = { version = "0.2.1", optional = true }
3232
inventory = { version = "0.3.0", optional = true }
3333

3434
# coroutine implementation
35-
futures-task = "0.3"
35+
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
3636

3737
# crate integrations that can be added using the eponymous features
3838
anyhow = { version = "1.0", optional = true }
@@ -57,7 +57,7 @@ serde = { version = "1.0", features = ["derive"] }
5757
serde_json = "1.0.61"
5858
rayon = "1.6.1"
5959
widestring = "0.5.1"
60-
futures = "0.3.28"
60+
futures = "0.3.29"
6161

6262
[build-dependencies]
6363
pyo3-build-config = { path = "pyo3-build-config", version = "0.20.0", features = ["resolve-config"] }

pyo3-macros-backend/src/method.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ pub struct FnArg<'a> {
2121
pub optional: Option<&'a syn::Type>,
2222
pub default: Option<syn::Expr>,
2323
pub py: bool,
24+
pub coroutine_cancel: bool,
2425
pub attrs: PyFunctionArgPyO3Attributes,
2526
pub is_varargs: bool,
2627
pub is_kwargs: bool,
@@ -50,6 +51,7 @@ impl<'a> FnArg<'a> {
5051
optional: utils::option_type_argument(&cap.ty),
5152
default: None,
5253
py: utils::is_python(&cap.ty),
54+
coroutine_cancel: utils::is_coroutine_cancel(&cap.ty),
5355
attrs: arg_attrs,
5456
is_varargs: false,
5557
is_kwargs: false,
@@ -446,10 +448,27 @@ impl<'a> FnSpec<'a> {
446448
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
447449
let func_name = &self.name;
448450

451+
let coroutine_cancel = self
452+
.signature
453+
.arguments
454+
.iter()
455+
.find(|arg| arg.coroutine_cancel);
456+
if let (None, Some(arg)) = (&self.asyncness, coroutine_cancel) {
457+
bail_spanned!(arg.ty.span() => "`CoroutineCancel` argument only allowed with `async fn`");
458+
}
459+
449460
let rust_call = |args: Vec<TokenStream>| {
450461
let mut call = quote! { function(#self_arg #(#args),*) };
451462
if self.asyncness.is_some() {
452-
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
463+
call = if coroutine_cancel.is_some() {
464+
quote! {{
465+
let __coroutine_cancel = _pyo3::coroutine::CoroutineCancel::new();
466+
let __cancel_handle = __coroutine_cancel.handle();
467+
_pyo3::impl_::coroutine::wrap_future({ #call }, Some(__cancel_handle))
468+
}}
469+
} else {
470+
quote! { _pyo3::impl_::coroutine::wrap_future(#call, None) }
471+
};
453472
}
454473
quotes::map_result_into_ptr(quotes::ok_wrap(call))
455474
};

pyo3-macros-backend/src/params.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ fn impl_arg_param(
155155
return Ok(quote! { py });
156156
}
157157

158+
if arg.coroutine_cancel {
159+
return Ok(quote! { __coroutine_cancel });
160+
}
161+
158162
let name = arg.name;
159163
let name_str = name.to_string();
160164

pyo3-macros-backend/src/pyfunction/signature.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
361361
// Otherwise try next argument.
362362
continue;
363363
}
364+
if fn_arg.coroutine_cancel {
365+
// If the user incorrectly tried to include cancel: CoroutineCancel in the
366+
// signature, give a useful error as a hint.
367+
ensure_spanned!(
368+
name != fn_arg.name,
369+
name.span() => "arguments of type `CoroutineCancel` must not be part of the signature"
370+
);
371+
// Otherwise try next argument.
372+
continue;
373+
}
364374

365375
ensure_spanned!(
366376
name == fn_arg.name,
@@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
411421
}
412422

413423
// Ensure no non-py arguments remain
414-
if let Some(arg) = args_iter.find(|arg| !arg.py) {
424+
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.coroutine_cancel) {
415425
bail_spanned!(
416426
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
417427
);
@@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
429439
let mut python_signature = PythonSignature::default();
430440
for arg in &arguments {
431441
// Python<'_> arguments don't show in Python signature
432-
if arg.py {
442+
if arg.py || arg.coroutine_cancel {
433443
continue;
434444
}
435445

pyo3-macros-backend/src/utils.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,19 @@ pub fn is_python(ty: &syn::Type) -> bool {
4141
}
4242
}
4343

44+
/// Check if the given type `ty` is `pyo3::coroutine::CoroutineCancel`.
45+
pub fn is_coroutine_cancel(ty: &syn::Type) -> bool {
46+
match unwrap_ty_group(ty) {
47+
syn::Type::Path(typath) => typath
48+
.path
49+
.segments
50+
.last()
51+
.map(|seg| seg.ident == "CoroutineCancel")
52+
.unwrap_or(false),
53+
_ => false,
54+
}
55+
}
56+
4457
/// If `ty` is `Option<T>`, return `Some(T)`, else `None`.
4558
pub fn option_type_argument(ty: &syn::Type) -> Option<&syn::Type> {
4659
if let syn::Type::Path(syn::TypePath { path, .. }) = ty {

src/coroutine.rs

Lines changed: 45 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
//! Python coroutine implementation, used notably when wrapping `async fn`
22
//! with `#[pyfunction]`/`#[pymethods]`.
3+
use crate::coroutine::waker::AsyncioWaker;
34
use crate::exceptions::{PyRuntimeError, PyStopIteration};
45
use crate::pyclass::IterNextOutput;
5-
use crate::sync::GILOnceCell;
6-
use crate::types::{PyCFunction, PyIterator};
7-
use crate::{intern, wrap_pyfunction, IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
8-
use pyo3_macros::{pyclass, pyfunction, pymethods};
6+
use crate::types::PyIterator;
7+
use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
8+
use pyo3_macros::{pyclass, pymethods};
99
use std::future::Future;
1010
use std::pin::Pin;
1111
use std::sync::Arc;
1212
use std::task::{Context, Poll};
1313

14+
mod cancel;
15+
mod waker;
16+
17+
pub use crate::coroutine::cancel::{CancelHandle, CoroutineCancel};
18+
1419
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
1520

1621
/// Python coroutine wrapping a [`Future`].
1722
#[pyclass(crate = "crate")]
1823
pub struct Coroutine {
1924
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
25+
cancel: Option<CancelHandle>,
2026
waker: Option<Arc<AsyncioWaker>>,
2127
}
2228

@@ -41,14 +47,40 @@ impl Coroutine {
4147
};
4248
Self {
4349
future: Some(Box::pin(wrap)),
50+
cancel: None,
51+
waker: None,
52+
}
53+
}
54+
55+
/// Wrap a future into a Python coroutine.
56+
///
57+
/// Coroutine `send` polls the wrapped future, ignoring the value passed
58+
/// (should always be `None` anyway).
59+
///
60+
/// Coroutine `throw` registers the exception in `cancel`, and polls the wrapped future
61+
pub fn from_future_with_cancel<F, T, E>(future: F, cancel: CancelHandle) -> Self
62+
where
63+
F: Future<Output = Result<T, E>> + Send + 'static,
64+
T: IntoPy<PyObject> + Send,
65+
E: Send,
66+
PyErr: From<E>,
67+
{
68+
let wrap = async move {
69+
let obj = future.await?;
70+
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
71+
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
72+
};
73+
Self {
74+
future: Some(Box::pin(wrap)),
75+
cancel: Some(cancel),
4476
waker: None,
4577
}
4678
}
4779

4880
fn poll(
4981
&mut self,
5082
py: Python<'_>,
51-
throw: Option<&PyAny>,
83+
throw: Option<PyObject>,
5284
) -> PyResult<IterNextOutput<PyObject, PyObject>> {
5385
// raise if the coroutine has already been run to completion
5486
let future_rs = match self.future {
@@ -57,16 +89,20 @@ impl Coroutine {
5789
};
5890
// reraise thrown exception it
5991
if let Some(exc) = throw {
60-
self.close();
61-
return Err(PyErr::from_value(exc));
92+
if let Some(ref handle) = self.cancel {
93+
handle.cancel(py, exc)
94+
} else {
95+
self.close();
96+
return Err(PyErr::from_value(exc.as_ref(py)));
97+
}
6298
}
6399
// create a new waker, or try to reset it in place
64100
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
65101
waker.reset();
66102
} else {
67103
self.waker = Some(Arc::new(AsyncioWaker::new()));
68104
}
69-
let waker = futures_task::waker(self.waker.clone().unwrap());
105+
let waker = futures_util::task::waker(self.waker.clone().unwrap());
70106
// poll the Rust future and forward its results if ready
71107
if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) {
72108
self.close();
@@ -101,7 +137,7 @@ impl Coroutine {
101137
iter_result(self.poll(py, None)?)
102138
}
103139

104-
fn throw(&mut self, py: Python<'_>, exc: &PyAny) -> PyResult<PyObject> {
140+
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
105141
iter_result(self.poll(py, Some(exc))?)
106142
}
107143

@@ -119,93 +155,3 @@ impl Coroutine {
119155
self.poll(py, None)
120156
}
121157
}
122-
123-
/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`.
124-
///
125-
/// asyncio future is let uninitialized until [`initialize_future`][1] is called.
126-
/// If [`wake`][2] is called before future initialization (during Rust future polling),
127-
/// [`initialize_future`][1] will return `None` (it is roughly equivalent to `asyncio.sleep(0)`)
128-
///
129-
/// [1]: AsyncioWaker::initialize_future
130-
/// [2]: AsyncioWaker::wake
131-
struct AsyncioWaker(GILOnceCell<Option<LoopAndFuture>>);
132-
133-
impl AsyncioWaker {
134-
fn new() -> Self {
135-
Self(GILOnceCell::new())
136-
}
137-
138-
fn reset(&mut self) {
139-
self.0.take();
140-
}
141-
142-
fn initialize_future<'a>(&'a self, py: Python<'a>) -> PyResult<Option<&'a PyAny>> {
143-
let init = || LoopAndFuture::new(py).map(Some);
144-
let loop_and_future = self.0.get_or_try_init(py, init)?.as_ref();
145-
Ok(loop_and_future.map(|LoopAndFuture { future, .. }| future.as_ref(py)))
146-
}
147-
}
148-
149-
impl futures_task::ArcWake for AsyncioWaker {
150-
fn wake_by_ref(arc_self: &Arc<Self>) {
151-
Python::with_gil(|gil| {
152-
if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) {
153-
loop_and_future
154-
.set_result(gil)
155-
.expect("unexpected error in coroutine waker");
156-
}
157-
});
158-
}
159-
}
160-
161-
struct LoopAndFuture {
162-
event_loop: PyObject,
163-
future: PyObject,
164-
}
165-
166-
impl LoopAndFuture {
167-
fn new(py: Python<'_>) -> PyResult<Self> {
168-
static GET_RUNNING_LOOP: GILOnceCell<PyObject> = GILOnceCell::new();
169-
let import = || -> PyResult<_> {
170-
let module = py.import("asyncio")?;
171-
Ok(module.getattr("get_running_loop")?.into())
172-
};
173-
let event_loop = GET_RUNNING_LOOP.get_or_try_init(py, import)?.call0(py)?;
174-
let future = event_loop.call_method0(py, "create_future")?;
175-
Ok(Self { event_loop, future })
176-
}
177-
178-
fn set_result(&self, py: Python<'_>) -> PyResult<()> {
179-
static RELEASE_WAITER: GILOnceCell<Py<PyCFunction>> = GILOnceCell::new();
180-
let release_waiter = RELEASE_WAITER
181-
.get_or_try_init(py, || wrap_pyfunction!(release_waiter, py).map(Into::into))?;
182-
// `Future.set_result` must be called in event loop thread,
183-
// so it requires `call_soon_threadsafe`
184-
let call_soon_threadsafe = self.event_loop.call_method1(
185-
py,
186-
intern!(py, "call_soon_threadsafe"),
187-
(release_waiter, self.future.as_ref(py)),
188-
);
189-
if let Err(err) = call_soon_threadsafe {
190-
// `call_soon_threadsafe` will raise if the event loop is closed;
191-
// instead of catching an unspecific `RuntimeError`, check directly if it's closed.
192-
let is_closed = self.event_loop.call_method0(py, "is_closed")?;
193-
if !is_closed.extract(py)? {
194-
return Err(err);
195-
}
196-
}
197-
Ok(())
198-
}
199-
}
200-
201-
/// Call `future.set_result` if the future is not done.
202-
///
203-
/// Future can be cancelled by the event loop before being waken.
204-
/// See https://github.com/python/cpython/blob/main/Lib/asyncio/tasks.py#L452C5-L452C5
205-
#[pyfunction(crate = "crate")]
206-
fn release_waiter(future: &PyAny) -> PyResult<()> {
207-
if !future.call_method0("done")?.extract::<bool>()? {
208-
future.call_method1("set_result", (future.py().None(),))?;
209-
}
210-
Ok(())
211-
}

src/coroutine/cancel.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
use crate::{ffi, Py, PyObject, Python};
2+
use futures_util::future::poll_fn;
3+
use futures_util::task::AtomicWaker;
4+
use std::ptr;
5+
use std::sync::atomic::{AtomicPtr, Ordering};
6+
use std::sync::Arc;
7+
use std::task::{Context, Poll};
8+
9+
#[derive(Debug, Default)]
10+
struct Inner {
11+
exception: AtomicPtr<ffi::PyObject>,
12+
waker: AtomicWaker,
13+
}
14+
15+
/// Helper used to wait and retrieve exception thrown in coroutine.
16+
#[derive(Debug, Default)]
17+
pub struct CoroutineCancel(Arc<Inner>);
18+
19+
impl CoroutineCancel {
20+
/// Create a new `CoroutineCancel`.
21+
pub fn new() -> Self {
22+
Default::default()
23+
}
24+
25+
/// Return an associated [`CancelHandle`].
26+
pub fn handle(&self) -> CancelHandle {
27+
CancelHandle(self.0.clone())
28+
}
29+
30+
fn take_exception(&self) -> PyObject {
31+
let ptr = self.0.exception.swap(ptr::null_mut(), Ordering::Relaxed);
32+
Python::with_gil(|gil| unsafe { Py::from_owned_ptr(gil, ptr) })
33+
}
34+
35+
/// Returns whether the associated coroutine has been cancelled.
36+
pub fn is_cancelled(&self) -> bool {
37+
!self.0.exception.load(Ordering::Relaxed).is_null()
38+
}
39+
40+
/// Poll to retrieve the exception thrown in the associated coroutine.
41+
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
42+
if self.is_cancelled() {
43+
return Poll::Ready(self.take_exception());
44+
}
45+
self.0.waker.register(cx.waker());
46+
if self.is_cancelled() {
47+
return Poll::Ready(self.take_exception());
48+
}
49+
Poll::Pending
50+
}
51+
52+
/// Retrieve the exception thrown in the associated coroutine.
53+
pub async fn cancelled(&mut self) -> PyObject {
54+
poll_fn(|cx| self.poll_cancelled(cx)).await
55+
}
56+
}
57+
58+
/// [`CoroutineCancel`] handle used in
59+
/// [`Coroutine::from_future_with_cancel`](crate::coroutine::Coroutine::from_future_with_cancel)
60+
pub struct CancelHandle(Arc<Inner>);
61+
62+
impl CancelHandle {
63+
pub(super) fn cancel(&self, py: Python<'_>, exc: PyObject) {
64+
let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed);
65+
drop(unsafe { PyObject::from_owned_ptr_or_opt(py, ptr) });
66+
self.0.waker.wake();
67+
}
68+
}

0 commit comments

Comments
 (0)