Skip to content

Commit

Permalink
feat: add coroutine::CancelHandle
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 29, 2023
1 parent d8002c4 commit 7c582ca
Show file tree
Hide file tree
Showing 12 changed files with 207 additions and 18 deletions.
1 change: 1 addition & 0 deletions newsfragments/3599.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `coroutine::CancelHandle` to catch coroutine cancellation
3 changes: 2 additions & 1 deletion pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(args);
syn::custom_keyword!(annotation);
syn::custom_keyword!(args);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
syn::custom_keyword!(dict);
syn::custom_keyword!(extends);
syn::custom_keyword!(freelist);
Expand Down
50 changes: 46 additions & 4 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Display;
use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{CancelHandleAttribute, FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
Expand All @@ -12,7 +12,7 @@ use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::{Ident, Result};
use syn::{Ident, Result, Token};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
Expand All @@ -24,6 +24,7 @@ pub struct FnArg<'a> {
pub attrs: PyFunctionArgPyO3Attributes,
pub is_varargs: bool,
pub is_kwargs: bool,
pub is_cancel_handle: bool,
}

impl<'a> FnArg<'a> {
Expand Down Expand Up @@ -53,12 +54,30 @@ impl<'a> FnArg<'a> {
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle: false,
})
}
}
}
}

pub fn update_cancel_handle(
asyncness: Option<Token![async]>,
arguments: &mut [FnArg<'_>],
cancel_handle: CancelHandleAttribute,
) -> Result<()> {
if asyncness.is_none() {
bail_spanned!(cancel_handle.kw.span() => "`cancel_handle` attribute only allowed with `async fn`");
}
for arg in arguments {
if arg.name == &cancel_handle.value.0 {
arg.is_cancel_handle = true;
return Ok(());
}
}
bail_spanned!(cancel_handle.value.span() => "missing cancel_handle argument")
}

fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
let span = pat.span();
let msg = match pat {
Expand Down Expand Up @@ -278,6 +297,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
cancel_handle,
..
} = options;

Expand All @@ -291,7 +311,7 @@ impl<'a> FnSpec<'a> {
let ty = get_return_info(&sig.output);
let python_name = python_name.as_ref().unwrap_or(name).unraw();

let arguments: Vec<_> = sig
let mut arguments: Vec<_> = sig
.inputs
.iter_mut()
.skip(if fn_type.skip_first_rust_argument_in_python_signature() {
Expand All @@ -302,6 +322,10 @@ impl<'a> FnSpec<'a> {
.map(FnArg::parse)
.collect::<Result<_>>()?;

if let Some(cancel_handle) = cancel_handle {
update_cancel_handle(sig.asyncness, &mut arguments, cancel_handle)?;
}

let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
} else {
Expand Down Expand Up @@ -457,6 +481,16 @@ impl<'a> FnSpec<'a> {
let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let cancel_handle = self
.signature
.arguments
.iter()
.find(|arg| arg.is_cancel_handle);
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
quote! { None }
};
let python_name = &self.python_name;
let qualname_prefix = match cls {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
Expand All @@ -467,9 +501,17 @@ impl<'a> FnSpec<'a> {
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) }
#throw_callback,
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
if cancel_handle.is_some() {
call = quote! {{
let __cancel_handle = _pyo3::coroutine::CancelHandle::new();
let __throw_callback = __cancel_handle.throw_callback();
#call
}};
}
}
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};
Expand Down
4 changes: 4 additions & 0 deletions pyo3-macros-backend/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ fn impl_arg_param(
return Ok(quote! { py });
}

if arg.is_cancel_handle {
return Ok(quote! { __cancel_handle });
}

let name = arg.name;
let name_str = name.to_string();

Expand Down
22 changes: 18 additions & 4 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use crate::method::update_cancel_handle;
use crate::{
attributes::{
self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
FromPyWithAttribute, KeywordAttribute, NameAttribute, NameLitStr, TextSignatureAttribute,
},
deprecations::Deprecations,
method::{self, CallingConvention, FnArg},
Expand Down Expand Up @@ -73,6 +74,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub cancel_handle: Option<CancelHandleAttribute>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -81,7 +83,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::cancel_handle)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -103,6 +106,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
CancelHandle(CancelHandleAttribute),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -113,7 +117,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionOption::CancelHandle)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -153,6 +159,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::CancelHandle(cancel_handle) => set_option!(cancel_handle),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand All @@ -164,6 +171,8 @@ impl PyFunctionOptions {
}
}

pub type CancelHandleAttribute = KeywordAttribute<attributes::kw::cancel_handle, NameLitStr>;

pub fn build_py_function(
ast: &mut syn::ItemFn,
mut options: PyFunctionOptions,
Expand All @@ -180,6 +189,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
cancel_handle,
pass_module,
name,
signature,
Expand All @@ -201,7 +211,7 @@ pub fn impl_wrap_pyfunction(
method::FnType::FnStatic
};

let arguments = func
let mut arguments = func
.sig
.inputs
.iter_mut()
Expand All @@ -213,6 +223,10 @@ pub fn impl_wrap_pyfunction(
.map(FnArg::parse)
.collect::<syn::Result<Vec<_>>>()?;

if let Some(cancel_handle) = cancel_handle {
update_cancel_handle(func.sig.asyncness, &mut arguments, cancel_handle)?;
}

let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
} else {
Expand Down
14 changes: 12 additions & 2 deletions pyo3-macros-backend/src/pyfunction/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,16 @@ impl<'a> FunctionSignature<'a> {
// Otherwise try next argument.
continue;
}
if fn_arg.is_cancel_handle {
// If the user incorrectly tried to include cancel: CoroutineCancel in the
// signature, give a useful error as a hint.
ensure_spanned!(
name != fn_arg.name,
name.span() => "`cancel_handle` argument must not be part of the signature"
);
// Otherwise try next argument.
continue;
}

ensure_spanned!(
name == fn_arg.name,
Expand Down Expand Up @@ -411,7 +421,7 @@ impl<'a> FunctionSignature<'a> {
}

// Ensure no non-py arguments remain
if let Some(arg) = args_iter.find(|arg| !arg.py) {
if let Some(arg) = args_iter.find(|arg| !arg.py && !arg.is_cancel_handle) {
bail_spanned!(
attribute.kw.span() => format!("missing signature entry for argument `{}`", arg.name)
);
Expand All @@ -429,7 +439,7 @@ impl<'a> FunctionSignature<'a> {
let mut python_signature = PythonSignature::default();
for arg in &arguments {
// Python<'_> arguments don't show in Python signature
if arg.py {
if arg.py || arg.is_cancel_handle {
continue;
}

Expand Down
17 changes: 14 additions & 3 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ use crate::{
IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python,
};

pub(crate) mod cancel;
mod waker;

use crate::coroutine::cancel::ThrowCallback;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
Expand All @@ -32,6 +36,7 @@ type FutureOutput = Result<PyResult<PyObject>, Box<dyn Any + Send>>;
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = FutureOutput> + Send>>>,
waker: Option<Arc<AsyncioWaker>>,
}
Expand All @@ -46,6 +51,7 @@ impl Coroutine {
pub(crate) fn new<F, T, E>(
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: F,
) -> Self
where
Expand All @@ -61,6 +67,7 @@ impl Coroutine {
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())),
waker: None,
}
Expand All @@ -77,9 +84,13 @@ impl Coroutine {
None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
};
// reraise thrown exception it
if let Some(exc) = throw {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
match (throw, &self.throw_callback) {
(Some(exc), Some(cb)) => cb.throw(exc.as_ref(py)),
(Some(exc), None) => {
self.close();
return Err(PyErr::from_value(exc.as_ref(py)));
}
_ => {}
}
// create a new waker, or try to reset it in place
if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
Expand Down
Loading

0 comments on commit 7c582ca

Please sign in to comment.