Skip to content

Commit 2752929

Browse files
committed
feat: allow async methods to accept &self/&mut self
1 parent b86bce4 commit 2752929

File tree

6 files changed

+122
-9
lines changed

6 files changed

+122
-9
lines changed

guide/src/async-await.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '
3030

3131
As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.
3232

33-
It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*
34-
33+
However, an exception is done for method receiver, so async methods can accept `&self`/`&mut self`
3534

3635
## Implicit GIL holding
3736

newsfragments/3609.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow async methods to accept `&self`/`&mut self`

pyo3-macros-backend/src/method.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -472,8 +472,7 @@ impl<'a> FnSpec<'a> {
472472
}
473473

474474
let rust_call = |args: Vec<TokenStream>| {
475-
let mut call = quote! { function(#self_arg #(#args),*) };
476-
if self.asyncness.is_some() {
475+
let call = if self.asyncness.is_some() {
477476
let throw_callback = if cancel_handle.is_some() {
478477
quote! { Some(__throw_callback) }
479478
} else {
@@ -484,8 +483,23 @@ impl<'a> FnSpec<'a> {
484483
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
485484
None => quote!(None),
486485
};
487-
call = quote! {{
488-
let future = #call;
486+
let future = match self.tp {
487+
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {
488+
_pyo3::impl_::coroutine::ref_method_future(
489+
py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf),
490+
move |__self| function(__self, #(#args),*)
491+
)?
492+
},
493+
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {
494+
_pyo3::impl_::coroutine::mut_method_future(
495+
py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf),
496+
move |__self| function(__self, #(#args),*)
497+
)?
498+
},
499+
_ => quote! { function(#self_arg #(#args),*) },
500+
};
501+
let mut call = quote! {{
502+
let future = #future;
489503
_pyo3::impl_::coroutine::new_coroutine(
490504
_pyo3::intern!(py, stringify!(#python_name)),
491505
#qualname_prefix,
@@ -500,7 +514,10 @@ impl<'a> FnSpec<'a> {
500514
#call
501515
}};
502516
}
503-
}
517+
call
518+
} else {
519+
quote! { function(#self_arg #(#args),*) }
520+
};
504521
quotes::map_result_into_ptr(quotes::ok_wrap(call))
505522
};
506523

src/impl_/coroutine.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
use std::future::Future;
2+
use std::mem;
23

34
use crate::coroutine::cancel::ThrowCallback;
4-
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
5+
use crate::pyclass::boolean_struct::False;
6+
use crate::{
7+
coroutine::Coroutine, types::PyString, IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject,
8+
PyRef, PyRefMut, PyResult, Python,
9+
};
510

611
pub fn new_coroutine<F, T, E>(
712
name: &PyString,
@@ -16,3 +21,46 @@ where
1621
{
1722
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
1823
}
24+
25+
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
26+
// SAFETY: Py<T> can be casted as *const PyCell<T>
27+
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
28+
}
29+
30+
struct RefGuard<T: PyClass>(Py<T>);
31+
32+
impl<T: PyClass> Drop for RefGuard<T> {
33+
fn drop(&mut self) {
34+
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
35+
}
36+
}
37+
38+
pub unsafe fn ref_method_future<'a, T: PyClass, F: Future + 'a>(
39+
self_: &PyAny,
40+
fut: impl FnOnce(&'a T) -> F,
41+
) -> PyResult<impl Future<Output = F::Output>> {
42+
let ref_: PyRef<'_, T> = self_.extract()?;
43+
// SAFETY: `PyRef::as_ptr` returns a borrowed reference
44+
let guard = RefGuard(unsafe { Py::<T>::from_borrowed_ptr(self_.py(), ref_.as_ptr()) });
45+
mem::forget(ref_);
46+
Ok(async move { fut(unsafe { &*get_ptr(&guard.0) }).await })
47+
}
48+
49+
struct RefMutGuard<T: PyClass>(Py<T>);
50+
51+
impl<T: PyClass> Drop for RefMutGuard<T> {
52+
fn drop(&mut self) {
53+
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
54+
}
55+
}
56+
57+
pub fn mut_method_future<'a, T: PyClass<Frozen = False>, F: Future + 'a>(
58+
self_: &PyAny,
59+
fut: impl FnOnce(&'a mut T) -> F,
60+
) -> PyResult<impl Future<Output = F::Output>> {
61+
let mut_: PyRefMut<'_, T> = self_.extract()?;
62+
// SAFETY: `PyRefMut::as_ptr` returns a borrowed reference
63+
let guard = RefMutGuard(unsafe { Py::<T>::from_borrowed_ptr(self_.py(), mut_.as_ptr()) });
64+
mem::forget(mut_);
65+
Ok(async move { fut(unsafe { &mut *get_ptr(&guard.0) }).await })
66+
}

src/pycell.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
516516
#[allow(clippy::useless_conversion)]
517517
offset.try_into().expect("offset should fit in Py_ssize_t")
518518
}
519+
520+
#[cfg(feature = "macros")]
521+
pub(crate) fn release_ref(&self) {
522+
self.borrow_checker().release_borrow();
523+
}
524+
525+
#[cfg(feature = "macros")]
526+
pub(crate) fn release_mut(&self) {
527+
self.borrow_checker().release_borrow_mut();
528+
}
519529
}
520530

521531
impl<T: PyClassImpl> PyCell<T> {

tests/test_coroutine.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ fn test_coroutine_qualname() {
4545
fn new() -> Self {
4646
Self
4747
}
48-
// TODO use &self when possible
4948
async fn my_method(_self: Py<Self>) {}
5049
#[classmethod]
5150
async fn my_classmethod(_cls: Py<PyType>) {}
@@ -173,3 +172,42 @@ fn coroutine_cancel_handle() {
173172
.unwrap();
174173
})
175174
}
175+
176+
#[test]
177+
fn test_async_method_receiver() {
178+
#[pyclass]
179+
struct MyClass;
180+
#[pymethods]
181+
impl MyClass {
182+
#[new]
183+
fn new() -> Self {
184+
Self
185+
}
186+
async fn my_method(&self) {}
187+
async fn my_mut_method(&mut self) {}
188+
}
189+
Python::with_gil(|gil| {
190+
let test = r#"
191+
obj = MyClass()
192+
coro1 = obj.my_method()
193+
coro2 = obj.my_method()
194+
try:
195+
assert obj.my_mut_method() == 42, "borrow checking should fail"
196+
except RuntimeError as err:
197+
pass
198+
coro1.close()
199+
coro2.close()
200+
coro3 = obj.my_mut_method()
201+
try:
202+
assert obj.my_mut_method() == 42, "borrow checking should fail"
203+
except RuntimeError as err:
204+
pass
205+
try:
206+
assert obj.my_method() == 42, "borrow checking should fail"
207+
except RuntimeError as err:
208+
pass
209+
"#;
210+
let locals = [("MyClass", gil.get_type::<MyClass>())].into_py_dict(gil);
211+
py_run!(gil, *locals, test);
212+
})
213+
}

0 commit comments

Comments
 (0)