Skip to content

Commit

Permalink
feat: use parameter annotation instead of function annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Nov 29, 2023
1 parent 7c582ca commit c5cb1b5
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 61 deletions.
21 changes: 19 additions & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,27 @@ where

## Cancellation

*To be implemented*
Cancellation on the Python side can be caught using [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) type, by annotating a function parameter with `#[pyo3(cancel_handle)].

```rust
# #![allow(dead_code)]
use futures::FutureExt;
use pyo3::prelude::*;
use pyo3::coroutine::CancelHandle;

#[pyfunction]
async fn cancellable(#[pyo3(cancel_handle)]mut cancel: CancelHandle) {
futures::select! {
/* _ = ... => println!("done"), */
_ = cancel.cancelled().fuse() => println!("cancelled"),
}
}
```

## The `Coroutine` type

To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine). Each `coroutine.send` call is translated to `Future::poll` call, while `coroutine.throw` call reraise the exception *(this behavior will be configurable with cancellation support)*.
To make a Rust future awaitable in Python, PyO3 defines a [`Coroutine`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.Coroutine.html) type, which implements the Python [coroutine protocol](https://docs.python.org/3/library/collections.abc.html#collections.abc.Coroutine).

Each `coroutine.send` call is translated to `Future::poll` call. If a [`CancelHandle`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html) parameter is declared, the exception passed to `coroutine.throw` call is stored in it and can be retrieved with [`CancelHandle::cancelled`]({{#PYO3_DOCS_URL}}/pyo3/coroutine/struct.CancelHandle.html#method.cancelled); otherwise, it cancels the Rust future, and the exception is reraised;

*The type does not yet have a public constructor until the design is finalized.*
70 changes: 33 additions & 37 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Display;
use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{CancelHandleAttribute, FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
Expand All @@ -12,7 +12,7 @@ use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::{Ident, Result, Token};
use syn::{Ident, Result};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
Expand Down Expand Up @@ -45,6 +45,7 @@ impl<'a> FnArg<'a> {
other => return Err(handle_argument_error(other)),
};

let is_cancel_handle = arg_attrs.cancel_handle.is_some();
Ok(FnArg {
name: ident,
ty: &cap.ty,
Expand All @@ -54,30 +55,13 @@ impl<'a> FnArg<'a> {
attrs: arg_attrs,
is_varargs: false,
is_kwargs: false,
is_cancel_handle: false,
is_cancel_handle,
})
}
}
}
}

pub fn update_cancel_handle(
asyncness: Option<Token![async]>,
arguments: &mut [FnArg<'_>],
cancel_handle: CancelHandleAttribute,
) -> Result<()> {
if asyncness.is_none() {
bail_spanned!(cancel_handle.kw.span() => "`cancel_handle` attribute only allowed with `async fn`");
}
for arg in arguments {
if arg.name == &cancel_handle.value.0 {
arg.is_cancel_handle = true;
return Ok(());
}
}
bail_spanned!(cancel_handle.value.span() => "missing cancel_handle argument")
}

fn handle_argument_error(pat: &syn::Pat) -> syn::Error {
let span = pat.span();
let msg = match pat {
Expand Down Expand Up @@ -297,7 +281,6 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
cancel_handle,
..
} = options;

Expand All @@ -311,7 +294,7 @@ impl<'a> FnSpec<'a> {
let ty = get_return_info(&sig.output);
let python_name = python_name.as_ref().unwrap_or(name).unraw();

let mut arguments: Vec<_> = sig
let arguments: Vec<_> = sig
.inputs
.iter_mut()
.skip(if fn_type.skip_first_rust_argument_in_python_signature() {
Expand All @@ -322,10 +305,6 @@ impl<'a> FnSpec<'a> {
.map(FnArg::parse)
.collect::<Result<_>>()?;

if let Some(cancel_handle) = cancel_handle {
update_cancel_handle(sig.asyncness, &mut arguments, cancel_handle)?;
}

let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
} else {
Expand Down Expand Up @@ -478,14 +457,22 @@ impl<'a> FnSpec<'a> {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise);
let func_name = &self.name;

let mut cancel_handle_iter = self
.signature
.arguments
.iter()
.filter(|arg| arg.is_cancel_handle);
let cancel_handle = cancel_handle_iter.next();
if let Some(arg) = cancel_handle {
ensure_spanned!(self.asyncness.is_some(), arg.name.span() => "`cancel_handle` attribute can only be used with `async fn`");
if let Some(arg2) = cancel_handle_iter.next() {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
}
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let cancel_handle = self
.signature
.arguments
.iter()
.find(|arg| arg.is_cancel_handle);
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
Expand Down Expand Up @@ -524,12 +511,21 @@ impl<'a> FnSpec<'a> {

Ok(match self.convention {
CallingConvention::Noargs => {
let call = if !self.signature.arguments.is_empty() {
// Only `py` arg can be here
rust_call(vec![quote!(py)])
} else {
rust_call(vec![])
};
let args = self
.signature
.arguments
.iter()
.map(|arg| {
if arg.py {
quote!(py)
} else if arg.is_cancel_handle {
quote!(__cancel_handle)
} else {
unreachable!()
}
})
.collect();
let call = rust_call(args);

quote! {
unsafe fn #ident<'py>(
Expand Down
41 changes: 23 additions & 18 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::method::update_cancel_handle;
use crate::{
attributes::{
self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
FromPyWithAttribute, KeywordAttribute, NameAttribute, NameLitStr, TextSignatureAttribute,
FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
},
deprecations::Deprecations,
method::{self, CallingConvention, FnArg},
Expand All @@ -24,16 +23,20 @@ pub use self::signature::{FunctionSignature, SignatureAttribute};
#[derive(Clone, Debug)]
pub struct PyFunctionArgPyO3Attributes {
pub from_py_with: Option<FromPyWithAttribute>,
pub cancel_handle: Option<attributes::kw::cancel_handle>,
}

enum PyFunctionArgPyO3Attribute {
FromPyWith(FromPyWithAttribute),
CancelHandle(attributes::kw::cancel_handle),
}

impl Parse for PyFunctionArgPyO3Attribute {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::from_py_with) {
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionArgPyO3Attribute::CancelHandle)
} else if lookahead.peek(attributes::kw::from_py_with) {
input.parse().map(PyFunctionArgPyO3Attribute::FromPyWith)
} else {
Err(lookahead.error())
Expand All @@ -44,7 +47,10 @@ impl Parse for PyFunctionArgPyO3Attribute {
impl PyFunctionArgPyO3Attributes {
/// Parses #[pyo3(from_python_with = "func")]
pub fn from_attrs(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Self> {
let mut attributes = PyFunctionArgPyO3Attributes { from_py_with: None };
let mut attributes = PyFunctionArgPyO3Attributes {
from_py_with: None,
cancel_handle: None,
};
take_attributes(attrs, |attr| {
if let Some(pyo3_attrs) = get_pyo3_options(attr)? {
for attr in pyo3_attrs {
Expand All @@ -56,7 +62,18 @@ impl PyFunctionArgPyO3Attributes {
);
attributes.from_py_with = Some(from_py_with);
}
PyFunctionArgPyO3Attribute::CancelHandle(cancel_handle) => {
ensure_spanned!(
attributes.cancel_handle.is_none(),
cancel_handle.span() => "`cancel_handle` may only be specified once per argument"
);
attributes.cancel_handle = Some(cancel_handle);
}
}
ensure_spanned!(
attributes.from_py_with.is_none() || attributes.cancel_handle.is_none(),
attributes.cancel_handle.unwrap().span() => "`from_py_with` and `cancel_handle` cannot be specified together"
);
}
Ok(true)
} else {
Expand All @@ -74,7 +91,6 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub cancel_handle: Option<CancelHandleAttribute>,
}

impl Parse for PyFunctionOptions {
Expand Down Expand Up @@ -106,7 +122,6 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
CancelHandle(CancelHandleAttribute),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -117,9 +132,7 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::cancel_handle) {
input.parse().map(PyFunctionOption::CancelHandle)
} else if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -159,7 +172,6 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::CancelHandle(cancel_handle) => set_option!(cancel_handle),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand All @@ -171,8 +183,6 @@ impl PyFunctionOptions {
}
}

pub type CancelHandleAttribute = KeywordAttribute<attributes::kw::cancel_handle, NameLitStr>;

pub fn build_py_function(
ast: &mut syn::ItemFn,
mut options: PyFunctionOptions,
Expand All @@ -189,7 +199,6 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
cancel_handle,
pass_module,
name,
signature,
Expand All @@ -211,7 +220,7 @@ pub fn impl_wrap_pyfunction(
method::FnType::FnStatic
};

let mut arguments = func
let arguments = func
.sig
.inputs
.iter_mut()
Expand All @@ -223,10 +232,6 @@ pub fn impl_wrap_pyfunction(
.map(FnArg::parse)
.collect::<syn::Result<Vec<_>>>()?;

if let Some(cancel_handle) = cancel_handle {
update_cancel_handle(func.sig.asyncness, &mut arguments, cancel_handle)?;
}

let signature = if let Some(signature) = signature {
FunctionSignature::from_arguments_and_attribute(arguments, signature)?
} else {
Expand Down
7 changes: 5 additions & 2 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,11 @@ fn cancelled_coroutine() {

#[test]
fn coroutine_cancel_handle() {
#[pyfunction(cancel_handle = "cancel")]
async fn cancellable_sleep(seconds: f64, mut cancel: CancelHandle) -> usize {
#[pyfunction]
async fn cancellable_sleep(
seconds: f64,
#[pyo3(cancel_handle)] mut cancel: CancelHandle,
) -> usize {
futures::select! {
_ = sleep(seconds).fuse() => 42,
_ = cancel.cancelled().fuse() => 0,
Expand Down
22 changes: 22 additions & 0 deletions tests/ui/invalid_argument_attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,26 @@ fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
#[pyfunction]
fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}

#[pyfunction]
fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}

#[pyfunction]
async fn from_py_with_value_and_cancel_handle(
#[pyo3(from_py_with = "func", cancel_handle)] _param: String,
) {
}

#[pyfunction]
async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}

#[pyfunction]
async fn cancel_handle_repeated2(
#[pyo3(cancel_handle)] _param: String,
#[pyo3(cancel_handle)] _param2: String,
) {
}

#[pyfunction]
fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}

fn main() {}
34 changes: 32 additions & 2 deletions tests/ui/invalid_argument_attributes.stderr
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
error: expected `from_py_with`
error: expected `cancel_handle` or `from_py_with`
--> tests/ui/invalid_argument_attributes.rs:4:29
|
4 | fn invalid_attribute(#[pyo3(get)] param: String) {}
Expand All @@ -10,7 +10,7 @@ error: expected `=`
7 | fn from_py_with_no_value(#[pyo3(from_py_with)] param: String) {}
| ^

error: expected `from_py_with`
error: expected `cancel_handle` or `from_py_with`
--> tests/ui/invalid_argument_attributes.rs:10:31
|
10 | fn from_py_with_string(#[pyo3("from_py_with")] param: String) {}
Expand All @@ -21,3 +21,33 @@ error: expected string literal
|
13 | fn from_py_with_value_not_a_string(#[pyo3(from_py_with = func)] param: String) {}
| ^^^^

error: `from_py_with` may only be specified once per argument
--> tests/ui/invalid_argument_attributes.rs:16:56
|
16 | fn from_py_with_repeated(#[pyo3(from_py_with = "func", from_py_with = "func")] param: String) {}
| ^^^^^^^^^^^^

error: `from_py_with` and `cancel_handle` cannot be specified together
--> tests/ui/invalid_argument_attributes.rs:20:35
|
20 | #[pyo3(from_py_with = "func", cancel_handle)] _param: String,
| ^^^^^^^^^^^^^

error: `cancel_handle` may only be specified once per argument
--> tests/ui/invalid_argument_attributes.rs:25:55
|
25 | async fn cancel_handle_repeated(#[pyo3(cancel_handle, cancel_handle)] _param: String) {}
| ^^^^^^^^^^^^^

error: `cancel_handle` may only be specified once
--> tests/ui/invalid_argument_attributes.rs:30:28
|
30 | #[pyo3(cancel_handle)] _param2: String,
| ^^^^^^^

error: `cancel_handle` attribute can only be used with `async fn`
--> tests/ui/invalid_argument_attributes.rs:35:53
|
35 | fn cancel_handle_synchronous(#[pyo3(cancel_handle)] _param: String) {}
| ^^^^^^

0 comments on commit c5cb1b5

Please sign in to comment.