Skip to content

Commit 07726ae

Browse files
authored
Merge pull request #3609 from wyfo/async_receiver
feat: allow async methods to accept `&self`/`&mut self`
2 parents 4baf023 + f34c70c commit 07726ae

File tree

6 files changed

+167
-22
lines changed

6 files changed

+167
-22
lines changed

guide/src/async-await.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '
3030

3131
As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.
3232

33-
It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*
34-
33+
However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`
3534

3635
## Implicit GIL holding
3736

newsfragments/3609.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow async methods to accept `&self`/`&mut self`

pyo3-macros-backend/src/method.rs

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
use std::fmt::Display;
22

3-
use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
4-
use crate::deprecations::{Deprecation, Deprecations};
5-
use crate::params::impl_arg_params;
6-
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
7-
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
8-
use crate::quotes;
9-
use crate::utils::{self, PythonDoc};
103
use proc_macro2::{Span, TokenStream};
11-
use quote::ToTokens;
12-
use quote::{quote, quote_spanned};
13-
use syn::ext::IdentExt;
14-
use syn::spanned::Spanned;
15-
use syn::{Ident, Result};
4+
use quote::{quote, quote_spanned, ToTokens};
5+
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
6+
7+
use crate::{
8+
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
9+
deprecations::{Deprecation, Deprecations},
10+
params::impl_arg_params,
11+
pyfunction::{
12+
FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute,
13+
},
14+
quotes,
15+
utils::{self, PythonDoc},
16+
};
1617

1718
#[derive(Clone, Debug)]
1819
pub struct FnArg<'a> {
@@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> {
473474
}
474475

475476
let rust_call = |args: Vec<TokenStream>| {
476-
let mut call = quote! { function(#self_arg #(#args),*) };
477-
if self.asyncness.is_some() {
477+
let call = if self.asyncness.is_some() {
478478
let throw_callback = if cancel_handle.is_some() {
479479
quote! { Some(__throw_callback) }
480480
} else {
@@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> {
485485
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
486486
None => quote!(None),
487487
};
488-
call = quote! {{
489-
let future = #call;
488+
let future = match self.tp {
489+
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
490+
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
491+
async move { function(&__guard, #(#args),*).await }
492+
}},
493+
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
494+
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
495+
async move { function(&mut __guard, #(#args),*).await }
496+
}},
497+
_ => quote! { function(#self_arg #(#args),*) },
498+
};
499+
let mut call = quote! {{
500+
let future = #future;
490501
_pyo3::impl_::coroutine::new_coroutine(
491502
_pyo3::intern!(py, stringify!(#python_name)),
492503
#qualname_prefix,
@@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> {
501512
#call
502513
}};
503514
}
504-
}
515+
call
516+
} else {
517+
quote! { function(#self_arg #(#args),*) }
518+
};
505519
quotes::map_result_into_ptr(quotes::ok_wrap(call))
506520
};
507521

src/impl_/coroutine.rs

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
1-
use std::future::Future;
1+
use std::{
2+
future::Future,
3+
mem,
4+
ops::{Deref, DerefMut},
5+
};
26

3-
use crate::coroutine::cancel::ThrowCallback;
4-
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
7+
use crate::{
8+
coroutine::{cancel::ThrowCallback, Coroutine},
9+
pyclass::boolean_struct::False,
10+
types::PyString,
11+
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, Python,
12+
};
513

614
pub fn new_coroutine<F, T, E>(
715
name: &PyString,
@@ -16,3 +24,63 @@ where
1624
{
1725
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
1826
}
27+
28+
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
29+
// SAFETY: Py<T> can be casted as *const PyCell<T>
30+
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
31+
}
32+
33+
pub struct RefGuard<T: PyClass>(Py<T>);
34+
35+
impl<T: PyClass> RefGuard<T> {
36+
pub fn new(obj: &PyAny) -> PyResult<Self> {
37+
let owned: Py<T> = obj.extract()?;
38+
mem::forget(owned.try_borrow(obj.py())?);
39+
Ok(RefGuard(owned))
40+
}
41+
}
42+
43+
impl<T: PyClass> Deref for RefGuard<T> {
44+
type Target = T;
45+
fn deref(&self) -> &Self::Target {
46+
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
47+
unsafe { &*get_ptr(&self.0) }
48+
}
49+
}
50+
51+
impl<T: PyClass> Drop for RefGuard<T> {
52+
fn drop(&mut self) {
53+
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
54+
}
55+
}
56+
57+
pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);
58+
59+
impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
60+
pub fn new(obj: &PyAny) -> PyResult<Self> {
61+
let owned: Py<T> = obj.extract()?;
62+
mem::forget(owned.try_borrow_mut(obj.py())?);
63+
Ok(RefMutGuard(owned))
64+
}
65+
}
66+
67+
impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
68+
type Target = T;
69+
fn deref(&self) -> &Self::Target {
70+
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
71+
unsafe { &*get_ptr(&self.0) }
72+
}
73+
}
74+
75+
impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
76+
fn deref_mut(&mut self) -> &mut Self::Target {
77+
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
78+
unsafe { &mut *get_ptr(&self.0) }
79+
}
80+
}
81+
82+
impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
83+
fn drop(&mut self) {
84+
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
85+
}
86+
}

src/pycell.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
516516
#[allow(clippy::useless_conversion)]
517517
offset.try_into().expect("offset should fit in Py_ssize_t")
518518
}
519+
520+
#[cfg(feature = "macros")]
521+
pub(crate) fn release_ref(&self) {
522+
self.borrow_checker().release_borrow();
523+
}
524+
525+
#[cfg(feature = "macros")]
526+
pub(crate) fn release_mut(&self) {
527+
self.borrow_checker().release_borrow_mut();
528+
}
519529
}
520530

521531
impl<T: PyClassImpl> PyCell<T> {

tests/test_coroutine.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,3 +234,56 @@ fn coroutine_panic() {
234234
py_run!(gil, panic, &handle_windows(test));
235235
})
236236
}
237+
238+
#[test]
239+
fn test_async_method_receiver() {
240+
#[pyclass]
241+
struct Counter(usize);
242+
#[pymethods]
243+
impl Counter {
244+
#[new]
245+
fn new() -> Self {
246+
Self(0)
247+
}
248+
async fn get(&self) -> usize {
249+
self.0
250+
}
251+
async fn incr(&mut self) -> usize {
252+
self.0 += 1;
253+
self.0
254+
}
255+
}
256+
Python::with_gil(|gil| {
257+
let test = r#"
258+
import asyncio
259+
260+
obj = Counter()
261+
coro1 = obj.get()
262+
coro2 = obj.get()
263+
try:
264+
obj.incr() # borrow checking should fail
265+
except RuntimeError as err:
266+
pass
267+
else:
268+
assert False
269+
assert asyncio.run(coro1) == 0
270+
coro2.close()
271+
coro3 = obj.incr()
272+
try:
273+
obj.incr() # borrow checking should fail
274+
except RuntimeError as err:
275+
pass
276+
else:
277+
assert False
278+
try:
279+
obj.get() # borrow checking should fail
280+
except RuntimeError as err:
281+
pass
282+
else:
283+
assert False
284+
assert asyncio.run(coro3) == 1
285+
"#;
286+
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
287+
py_run!(gil, *locals, test);
288+
})
289+
}

0 commit comments

Comments
 (0)