Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add coroutine::CancelHandle #3599

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,27 @@ where

## Cancellation

*To be implemented*
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].

```rust
# #![allow(dead_code)]
use futures::FutureExt;
use pyo3::prelude::*;
use pyo3::coroutine::CancelHandle;

#[pyfunction]
async fn cancellable(#[pyo3(cancel_handle)] mut cancel: CancelHandle) {
futures::select! {
/* _ = ... => println!("done"), */
_ = cancel.cancelled().fuse() => println!("cancelled"),
}
}
```

## The `Coroutine` type

To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).

Each `coroutine.send` call is translated to a `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;

*The type does not yet have a public constructor until the design is finalized.*
1 change: 1 addition & 0 deletions newsfragments/3599.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `coroutine::CancelHandle` to catch coroutine cancellation
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use syn::{
pub mod kw {
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);
Expand Down
53 changes: 46 additions & 7 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct FnArg<'a> {
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
}

impl<'a> FnArg<'a> {
Expand All @@ -44,6 +45,8 @@ impl<'a> FnArg<'a> {
other => return Err(handle_argument_error(other)),
};

let is_cancel_handle = arg_attrs.cancel_handle.is_some();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please leave an empty line after this statement.


Ok(FnArg {
name: ident,
ty: &cap.ty,
Expand All @@ -53,6 +56,7 @@ impl<'a> FnArg<'a> {
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle,
})
}
}
Expand Down Expand Up @@ -455,9 +459,27 @@ impl<'a> FnSpec<'a> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;

let mut cancel_handle_iter = self
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
}
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
quote! { None }
};
let python_name = &self.python_name;
let qualname_prefix = match cls {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
Expand All @@ -468,9 +490,17 @@ impl<'a> FnSpec<'a> {
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
#throw_callback,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
if cancel_handle.is_some() {
call = quote! {{
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
let __throw_callback = __cancel_handle.throw_callback();
#call
}};
}
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand All @@ -483,12 +513,21 @@ impl<'a> FnSpec<'a> {

Ok(match self.convention {
CallingConvention::Noargs => {
let call = if !self.signature.arguments.is_empty() {
// Only `py` arg can be here
rust_call(vec![quote!(py)])
} else {
rust_call(vec![])
};
let args = self
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
})
.collect();
let call = rust_call(args);

quote! {
unsafe fn #ident<'py>(
Expand Down
4 changes: 4 additions & 0 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ fn impl_arg_param(
return Ok(quote! { py });
}

if arg.is_cancel_handle {
return Ok(quote! { __cancel_handle });
}

let name = arg.name;
let name_str = name.to_string();

Expand Down
22 changes: 20 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
pub from_py_with: Option<FromPyWithAttribute>,
pub cancel_handle: Option<attributes::kw::cancel_handle>,
}

enum PyFunctionArgPyO3Attribute {
FromPyWith(FromPyWithAttribute),
CancelHandle(attributes::kw::cancel_handle),
}

impl Parse for PyFunctionArgPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::from_py_with) {
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
} else {
Err(lookahead.error())
Expand All @@ -43,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
impl PyFunctionArgPyO3Attributes {
/// Parses #[pyo3(from_python_with = "func")]
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
let mut attributes = PyFunctionArgPyO3Attributes {
from_py_with: None,
cancel_handle: None,
};
take_attributes(attrs, |attr| {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
for attr in pyo3_attrs {
Expand All @@ -55,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
);
attributes.from_py_with = Some(from_py_with);
}
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
ensure_spanned!(
attributes.cancel_handle.is_none(),
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
);
attributes.cancel_handle = Some(cancel_handle);
}
}
ensure_spanned!(
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
);
}
Ok(true)
} else {
Expand Down
14 changes: 12 additions & 2 deletions pyo3-macros-backend/src/pyfunction/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
// Otherwise try next argument.
continue;
}
if fn_arg.is_cancel_handle {
// If the user incorrectly tried to include cancel: CoroutineCancel in the
// signature, give a useful error as a hint.
ensure_spanned!(
name != fn_arg.name,
name.span() => "`cancel_handle` argument must not be part of the signature"
);
// Otherwise try next argument.
continue;
}

ensure_spanned!(
name == fn_arg.name,
Expand Down Expand Up @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
}

// Ensure no non-py arguments remain
if let Some(arg) = args_iter.find(|arg| !arg.py) {
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
bail_spanned!(
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
);
Expand All @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
let mut python_signature = PythonSignature::default();
for arg in &arguments {
// Python<'_> arguments don't show in Python signature
if arg.py {
if arg.py || arg.is_cancel_handle {
continue;
}

Expand Down
17 changes: 14 additions & 3 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ use crate::{
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};

pub(crate) mod cancel;
mod waker;

use crate::coroutine::cancel::ThrowCallback;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
Expand All @@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -46,6 +51,7 @@ impl Coroutine {
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Self
where
Expand All @@ -61,6 +67,7 @@ impl Coroutine {
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
Expand All @@ -77,9 +84,13 @@ impl Coroutine {
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
};
// reraise thrown exception it
if let Some(exc) = throw {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
match (throw, &self.throw_callback) {
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
(Some(exc), None) => {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
}
(None, _) => {}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
Expand Down
78 changes: 78 additions & 0 deletions src/coroutine/cancel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use crate::{PyAny, PyObject};
use parking_lot::Mutex;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};

#[derive(Debug, Default)]
struct Inner {
exception: Option<PyObject>,
waker: Option<Waker>,
}

/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine).
///
/// Only the last exception thrown can be retrieved.
#[derive(Debug, Default)]
pub struct CancelHandle(Arc<Mutex<Inner>>);

impl CancelHandle {
/// Create a new `CoroutineCancel`.
pub fn new() -> Self {
Default::default()
}

/// Returns whether the associated coroutine has been cancelled.
pub fn is_cancelled(&self) -> bool {
self.0.lock().exception.is_some()
}

/// Poll to retrieve the exception thrown in the associated coroutine.
pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
let mut inner = self.0.lock();
if let Some(exc) = inner.exception.take() {
return Poll::Ready(exc);
}
if let Some(ref waker) = inner.waker {
if cx.waker().will_wake(waker) {
return Poll::Pending;
}
}
inner.waker = Some(cx.waker().clone());
Poll::Pending
}

/// Retrieve the exception thrown in the associated coroutine.
pub async fn cancelled(&mut self) -> PyObject {
Cancelled(self).await
}

#[doc(hidden)]
pub fn throw_callback(&self) -> ThrowCallback {
ThrowCallback(self.0.clone())
}
}

// Because `poll_fn` is not available in MSRV
struct Cancelled<'a>(&'a mut CancelHandle);

impl Future for Cancelled<'_> {
type Output = PyObject;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.poll_cancelled(cx)
}
}

#[doc(hidden)]
pub struct ThrowCallback(Arc<Mutex<Inner>>);

impl ThrowCallback {
pub(super) fn throw(&self, exc: &PyAny) {
let mut inner = self.0.lock();
inner.exception = Some(exc.into());
if let Some(waker) = inner.waker.take() {
waker.wake();
}
}
}
4 changes: 3 additions & 1 deletion src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::future::Future;

use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};

pub fn new_coroutine<F, T, E>(
name: &PyString,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Coroutine
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
{
Coroutine::new(Some(name.into()), qualname_prefix, future)
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}
Loading