From 7c582ca21ec743218cee791f3de8a77d4b0be0c2 Mon Sep 17 00:00:00 2001 From: Joseph Perez Date: Sat, 25 Nov 2023 07:30:41 +0100 Subject: [PATCH] feat: add `coroutine::CancelHandle` --- newsfragments/3599.added.md | 1 + pyo3-macros-backend/src/attributes.rs | 3 +- pyo3-macros-backend/src/method.rs | 50 ++++++++++++- pyo3-macros-backend/src/params.rs | 4 + pyo3-macros-backend/src/pyfunction.rs | 22 +++++- .../src/pyfunction/signature.rs | 14 +++- src/coroutine.rs | 17 ++++- src/coroutine/cancel.rs | 74 +++++++++++++++++++ src/impl_/coroutine.rs | 4 +- src/instance.rs | 2 +- tests/test_coroutine.rs | 32 +++++++- tests/ui/invalid_pyfunction_signatures.stderr | 2 +- 12 files changed, 207 insertions(+), 18 deletions(-) create mode 100644 newsfragments/3599.added.md create mode 100644 src/coroutine/cancel.rs diff --git a/newsfragments/3599.added.md b/newsfragments/3599.added.md new file mode 100644 index 00000000000..36078fbcdb6 --- /dev/null +++ b/newsfragments/3599.added.md @@ -0,0 +1 @@ +Add `coroutine::CancelHandle` to catch coroutine cancellation \ No newline at end of file diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index e5e91cef42c..1209b1bb755 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -9,9 +9,10 @@ use syn::{ }; pub mod kw { - syn::custom_keyword!(args); syn::custom_keyword!(annotation); + syn::custom_keyword!(args); syn::custom_keyword!(attribute); + syn::custom_keyword!(cancel_handle); syn::custom_keyword!(dict); syn::custom_keyword!(extends); syn::custom_keyword!(freelist); diff --git a/pyo3-macros-backend/src/method.rs b/pyo3-macros-backend/src/method.rs index d8d901bb59f..9a1919d2f51 100644 --- a/pyo3-macros-backend/src/method.rs +++ b/pyo3-macros-backend/src/method.rs @@ -3,7 +3,7 @@ use std::fmt::Display; use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue}; use crate::deprecations::{Deprecation, Deprecations}; use crate::params::impl_arg_params; -use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes}; +use crate::pyfunction::{CancelHandleAttribute, FunctionSignature, PyFunctionArgPyO3Attributes}; use crate::pyfunction::{PyFunctionOptions, SignatureAttribute}; use crate::quotes; use crate::utils::{self, PythonDoc}; @@ -12,7 +12,7 @@ use quote::ToTokens; use quote::{quote, quote_spanned}; use syn::ext::IdentExt; use syn::spanned::Spanned; -use syn::{Ident, Result}; +use syn::{Ident, Result, Token}; #[derive(Clone, Debug)] pub struct FnArg<'a> { @@ -24,6 +24,7 @@ pub struct FnArg<'a> { pub attrs: PyFunctionArgPyO3Attributes, pub is_varargs: bool, pub is_kwargs: bool, + pub is_cancel_handle: bool, } impl<'a> FnArg<'a> { @@ -53,12 +54,30 @@ impl<'a> FnArg<'a> { attrs: arg_attrs, is_varargs: false, is_kwargs: false, + is_cancel_handle: false, }) } } } } +pub fn update_cancel_handle( + asyncness: Option, + arguments: &mut [FnArg<'_>], + cancel_handle: CancelHandleAttribute, +) -> Result<()> { + if asyncness.is_none() { + bail_spanned!(cancel_handle.kw.span() => "`cancel_handle` attribute only allowed with `async fn`"); + } + for arg in arguments { + if arg.name == &cancel_handle.value.0 { + arg.is_cancel_handle = true; + return Ok(()); + } + } + bail_spanned!(cancel_handle.value.span() => "missing cancel_handle argument") +} + fn handle_argument_error(pat: &syn::Pat) -> syn::Error { let span = pat.span(); let msg = match pat { @@ -278,6 +297,7 @@ impl<'a> FnSpec<'a> { text_signature, name, signature, + cancel_handle, .. } = options; @@ -291,7 +311,7 @@ impl<'a> FnSpec<'a> { let ty = get_return_info(&sig.output); let python_name = python_name.as_ref().unwrap_or(name).unraw(); - let arguments: Vec<_> = sig + let mut arguments: Vec<_> = sig .inputs .iter_mut() .skip(if fn_type.skip_first_rust_argument_in_python_signature() { @@ -302,6 +322,10 @@ impl<'a> FnSpec<'a> { .map(FnArg::parse) .collect::>()?; + if let Some(cancel_handle) = cancel_handle { + update_cancel_handle(sig.asyncness, &mut arguments, cancel_handle)?; + } + let signature = if let Some(signature) = signature { FunctionSignature::from_arguments_and_attribute(arguments, signature)? } else { @@ -457,6 +481,16 @@ impl<'a> FnSpec<'a> { let rust_call = |args: Vec| { let mut call = quote! { function(#self_arg #(#args),*) }; if self.asyncness.is_some() { + let cancel_handle = self + .signature + .arguments + .iter() + .find(|arg| arg.is_cancel_handle); + let throw_callback = if cancel_handle.is_some() { + quote! { Some(__throw_callback) } + } else { + quote! { None } + }; let python_name = &self.python_name; let qualname_prefix = match cls { Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)), @@ -467,9 +501,17 @@ impl<'a> FnSpec<'a> { _pyo3::impl_::coroutine::new_coroutine( _pyo3::intern!(py, stringify!(#python_name)), #qualname_prefix, - async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) } + #throw_callback, + async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }, ) }}; + if cancel_handle.is_some() { + call = quote! {{ + let __cancel_handle = _pyo3::coroutine::CancelHandle::new(); + let __throw_callback = __cancel_handle.throw_callback(); + #call + }}; + } } quotes::map_result_into_ptr(quotes::ok_wrap(call)) }; diff --git a/pyo3-macros-backend/src/params.rs b/pyo3-macros-backend/src/params.rs index e511ca754ac..5e7a86ff8e2 100644 --- a/pyo3-macros-backend/src/params.rs +++ b/pyo3-macros-backend/src/params.rs @@ -155,6 +155,10 @@ fn impl_arg_param( return Ok(quote! { py }); } + if arg.is_cancel_handle { + return Ok(quote! { __cancel_handle }); + } + let name = arg.name; let name_str = name.to_string(); diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index 6f60651628a..addb170a764 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -1,7 +1,8 @@ +use crate::method::update_cancel_handle; use crate::{ attributes::{ self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute, - FromPyWithAttribute, NameAttribute, TextSignatureAttribute, + FromPyWithAttribute, KeywordAttribute, NameAttribute, NameLitStr, TextSignatureAttribute, }, deprecations::Deprecations, method::{self, CallingConvention, FnArg}, @@ -73,6 +74,7 @@ pub struct PyFunctionOptions { pub signature: Option, pub text_signature: Option, pub krate: Option, + pub cancel_handle: Option, } impl Parse for PyFunctionOptions { @@ -81,7 +83,8 @@ impl Parse for PyFunctionOptions { while !input.is_empty() { let lookahead = input.lookahead1(); - if lookahead.peek(attributes::kw::name) + if lookahead.peek(attributes::kw::cancel_handle) + || lookahead.peek(attributes::kw::name) || lookahead.peek(attributes::kw::pass_module) || lookahead.peek(attributes::kw::signature) || lookahead.peek(attributes::kw::text_signature) @@ -103,6 +106,7 @@ impl Parse for PyFunctionOptions { } pub enum PyFunctionOption { + CancelHandle(CancelHandleAttribute), Name(NameAttribute), PassModule(attributes::kw::pass_module), Signature(SignatureAttribute), @@ -113,7 +117,9 @@ pub enum PyFunctionOption { impl Parse for PyFunctionOption { fn parse(input: ParseStream<'_>) -> Result { let lookahead = input.lookahead1(); - if lookahead.peek(attributes::kw::name) { + if lookahead.peek(attributes::kw::cancel_handle) { + input.parse().map(PyFunctionOption::CancelHandle) + } else if lookahead.peek(attributes::kw::name) { input.parse().map(PyFunctionOption::Name) } else if lookahead.peek(attributes::kw::pass_module) { input.parse().map(PyFunctionOption::PassModule) @@ -153,6 +159,7 @@ impl PyFunctionOptions { } for attr in attrs { match attr { + PyFunctionOption::CancelHandle(cancel_handle) => set_option!(cancel_handle), PyFunctionOption::Name(name) => set_option!(name), PyFunctionOption::PassModule(pass_module) => set_option!(pass_module), PyFunctionOption::Signature(signature) => set_option!(signature), @@ -164,6 +171,8 @@ impl PyFunctionOptions { } } +pub type CancelHandleAttribute = KeywordAttribute; + pub fn build_py_function( ast: &mut syn::ItemFn, mut options: PyFunctionOptions, @@ -180,6 +189,7 @@ pub fn impl_wrap_pyfunction( ) -> syn::Result { check_generic(&func.sig)?; let PyFunctionOptions { + cancel_handle, pass_module, name, signature, @@ -201,7 +211,7 @@ pub fn impl_wrap_pyfunction( method::FnType::FnStatic }; - let arguments = func + let mut arguments = func .sig .inputs .iter_mut() @@ -213,6 +223,10 @@ pub fn impl_wrap_pyfunction( .map(FnArg::parse) .collect::>>()?; + if let Some(cancel_handle) = cancel_handle { + update_cancel_handle(func.sig.asyncness, &mut arguments, cancel_handle)?; + } + let signature = if let Some(signature) = signature { FunctionSignature::from_arguments_and_attribute(arguments, signature)? } else { diff --git a/pyo3-macros-backend/src/pyfunction/signature.rs b/pyo3-macros-backend/src/pyfunction/signature.rs index ed3256ad461..baf01285658 100644 --- a/pyo3-macros-backend/src/pyfunction/signature.rs +++ b/pyo3-macros-backend/src/pyfunction/signature.rs @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> { // Otherwise try next argument. continue; } + if fn_arg.is_cancel_handle { + // If the user incorrectly tried to include cancel: CoroutineCancel in the + // signature, give a useful error as a hint. + ensure_spanned!( + name != fn_arg.name, + name.span() => "`cancel_handle` argument must not be part of the signature" + ); + // Otherwise try next argument. + continue; + } ensure_spanned!( name == fn_arg.name, @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> { } // Ensure no non-py arguments remain - if let Some(arg) = args_iter.find(|arg| !arg.py) { + if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) { bail_spanned!( attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name) ); @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> { let mut python_signature = PythonSignature::default(); for arg in &arguments { // Python<'_> arguments don't show in Python signature - if arg.py { + if arg.py || arg.is_cancel_handle { continue; } diff --git a/src/coroutine.rs b/src/coroutine.rs index c1a73938eeb..f44f8540156 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -21,8 +21,12 @@ use crate::{ IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python, }; +pub(crate) mod cancel; mod waker; +use crate::coroutine::cancel::ThrowCallback; +pub use cancel::CancelHandle; + const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine"; type FutureOutput = Result, Box>; @@ -32,6 +36,7 @@ type FutureOutput = Result, Box>; pub struct Coroutine { name: Option>, qualname_prefix: Option<&'static str>, + throw_callback: Option, future: Option + Send>>>, waker: Option>, } @@ -46,6 +51,7 @@ impl Coroutine { pub(crate) fn new( name: Option>, qualname_prefix: Option<&'static str>, + throw_callback: Option, future: F, ) -> Self where @@ -61,6 +67,7 @@ impl Coroutine { Self { name, qualname_prefix, + throw_callback, future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())), waker: None, } @@ -77,9 +84,13 @@ impl Coroutine { None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)), }; // reraise thrown exception it - if let Some(exc) = throw { - self.close(); - return Err(PyErr::from_value(exc.as_ref(py))); + match (throw, &self.throw_callback) { + (Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)), + (Some(exc), None) => { + self.close(); + return Err(PyErr::from_value(exc.as_ref(py))); + } + _ => {} } // create a new waker, or try to reset it in place if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) { diff --git a/src/coroutine/cancel.rs b/src/coroutine/cancel.rs new file mode 100644 index 00000000000..756678cd61e --- /dev/null +++ b/src/coroutine/cancel.rs @@ -0,0 +1,74 @@ +use crate::{ffi, Py, PyAny, PyObject}; +use futures_util::future::poll_fn; +use futures_util::task::AtomicWaker; +use std::ptr; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicPtr, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +#[derive(Debug, Default)] +struct Inner { + exception: AtomicPtr, + waker: AtomicWaker, +} + +/// Helper used to wait and retrieve exception thrown in [`Coroutine`](super::Coroutine). +/// +/// Only the last exception thrown can be retrieved. +#[derive(Debug, Default)] +pub struct CancelHandle(Arc); + +impl CancelHandle { + /// Create a new `CoroutineCancel`. + pub fn new() -> Self { + Default::default() + } + + /// Returns whether the associated coroutine has been cancelled. + pub fn is_cancelled(&self) -> bool { + !self.0.exception.load(Ordering::Relaxed).is_null() + } + + /// Poll to retrieve the exception thrown in the associated coroutine. + pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll { + // SAFETY: only valid owned pointer are set in `ThrowCallback::throw` + let take = || unsafe { + // pointer cannot be null because it is checked the line before, + // and the swap is protected by `&mut self` + Py::from_non_null( + NonNull::new(self.0.exception.swap(ptr::null_mut(), Ordering::Relaxed)).unwrap(), + ) + }; + if self.is_cancelled() { + return Poll::Ready(take()); + } + self.0.waker.register(cx.waker()); + if self.is_cancelled() { + return Poll::Ready(take()); + } + Poll::Pending + } + + /// Retrieve the exception thrown in the associated coroutine. + pub async fn cancelled(&mut self) -> PyObject { + poll_fn(|cx| self.poll_cancelled(cx)).await + } + + #[doc(hidden)] + pub fn throw_callback(&self) -> ThrowCallback { + ThrowCallback(self.0.clone()) + } +} + +#[doc(hidden)] +pub struct ThrowCallback(Arc); + +impl ThrowCallback { + pub(super) fn throw(&self, exc: &PyAny) { + let ptr = self.0.exception.swap(exc.into_ptr(), Ordering::Relaxed); + // SAFETY: non-null pointers set in `self.0.exceptions` are valid owned pointers + drop(unsafe { PyObject::from_owned_ptr_or_opt(exc.py(), ptr) }); + self.0.waker.wake(); + } +} diff --git a/src/impl_/coroutine.rs b/src/impl_/coroutine.rs index 6f66cc480f5..c8b2cdcce49 100644 --- a/src/impl_/coroutine.rs +++ b/src/impl_/coroutine.rs @@ -1,10 +1,12 @@ use std::future::Future; +use crate::coroutine::cancel::ThrowCallback; use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject}; pub fn new_coroutine( name: &PyString, qualname_prefix: Option<&'static str>, + throw_callback: Option, future: F, ) -> Coroutine where @@ -12,5 +14,5 @@ where T: IntoPy, E: Into, { - Coroutine::new(Some(name.into()), qualname_prefix, future) + Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future) } diff --git a/src/instance.rs b/src/instance.rs index 07a6f872bf4..4c43ef33cf1 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -1038,7 +1038,7 @@ impl Py { /// # Safety /// `ptr` must point to a Python object of type T. #[inline] - unsafe fn from_non_null(ptr: NonNull) -> Self { + pub(crate) unsafe fn from_non_null(ptr: NonNull) -> Self { Self(ptr, PhantomData) } diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 7420a0934aa..38a5f79ba33 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -3,7 +3,8 @@ use std::ops::Deref; use std::{task::Poll, thread, time::Duration}; -use futures::{channel::oneshot, future::poll_fn}; +use futures::{channel::oneshot, future::poll_fn, FutureExt}; +use pyo3::coroutine::CancelHandle; use pyo3::types::{IntoPyDict, PyType}; use pyo3::{prelude::*, py_run}; @@ -136,3 +137,32 @@ fn cancelled_coroutine() { assert_eq!(err.value(gil).get_type().name().unwrap(), "CancelledError"); }) } + +#[test] +fn coroutine_cancel_handle() { + #[pyfunction(cancel_handle = "cancel")] + async fn cancellable_sleep(seconds: f64, mut cancel: CancelHandle) -> usize { + futures::select! { + _ = sleep(seconds).fuse() => 42, + _ = cancel.cancelled().fuse() => 0, + } + } + Python::with_gil(|gil| { + let cancellable_sleep = wrap_pyfunction!(cancellable_sleep, gil).unwrap(); + let test = r#" + import asyncio; + async def main(): + task = asyncio.create_task(cancellable_sleep(1)) + await asyncio.sleep(0) + task.cancel() + return await task + assert asyncio.run(main()) == 0 + "#; + let globals = gil.import("__main__").unwrap().dict(); + globals + .set_item("cancellable_sleep", cancellable_sleep) + .unwrap(); + gil.run(&pyo3::unindent::unindent(test), Some(globals), None) + .unwrap(); + }) +} diff --git a/tests/ui/invalid_pyfunction_signatures.stderr b/tests/ui/invalid_pyfunction_signatures.stderr index dbca169d8ea..4b4488febfb 100644 --- a/tests/ui/invalid_pyfunction_signatures.stderr +++ b/tests/ui/invalid_pyfunction_signatures.stderr @@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x` 13 | #[pyo3(signature = (x))] | ^ -error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate` +error: expected one of: `cancel_handle`, `name`, `pass_module`, `signature`, `text_signature`, `crate` --> tests/ui/invalid_pyfunction_signatures.rs:18:14 | 18 | #[pyfunction(x)]