Skip to content

Commit 77b24fa

Browse files
committed
feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions
1 parent 6ad6fea commit 77b24fa

File tree

9 files changed

+99
-8
lines changed

9 files changed

+99
-8
lines changed

newsfragments/3610.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions

pyo3-macros-backend/src/attributes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use syn::{
99
};
1010

1111
pub mod kw {
12+
syn::custom_keyword!(allow_threads);
1213
syn::custom_keyword!(annotation);
1314
syn::custom_keyword!(attribute);
1415
syn::custom_keyword!(cancel_handle);

pyo3-macros-backend/src/method.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ use crate::deprecations::{Deprecation, Deprecations};
55
use crate::params::impl_arg_params;
66
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
77
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
8-
use crate::quotes;
98
use crate::utils::{self, PythonDoc};
9+
use crate::{attributes, quotes};
1010
use proc_macro2::{Span, TokenStream};
11-
use quote::ToTokens;
11+
use quote::{format_ident, ToTokens};
1212
use quote::{quote, quote_spanned};
1313
use syn::ext::IdentExt;
1414
use syn::spanned::Spanned;
@@ -239,6 +239,7 @@ pub struct FnSpec<'a> {
239239
pub asyncness: Option<syn::Token![async]>,
240240
pub unsafety: Option<syn::Token![unsafe]>,
241241
pub deprecations: Deprecations,
242+
pub allow_threads: Option<attributes::kw::allow_threads>,
242243
}
243244

244245
pub fn get_return_info(output: &syn::ReturnType) -> syn::Type {
@@ -282,6 +283,7 @@ impl<'a> FnSpec<'a> {
282283
text_signature,
283284
name,
284285
signature,
286+
allow_threads,
285287
..
286288
} = options;
287289

@@ -329,6 +331,7 @@ impl<'a> FnSpec<'a> {
329331
asyncness: sig.asyncness,
330332
unsafety: sig.unsafety,
331333
deprecations,
334+
allow_threads,
332335
})
333336
}
334337

@@ -472,6 +475,7 @@ impl<'a> FnSpec<'a> {
472475
}
473476

474477
let rust_call = |args: Vec<TokenStream>| {
478+
let allow_threads = self.allow_threads.is_some();
475479
let call = if self.asyncness.is_some() {
476480
let throw_callback = if cancel_handle.is_some() {
477481
quote! { Some(__throw_callback) }
@@ -504,6 +508,7 @@ impl<'a> FnSpec<'a> {
504508
_pyo3::intern!(py, stringify!(#python_name)),
505509
#qualname_prefix,
506510
#throw_callback,
511+
#allow_threads,
507512
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
508513
)
509514
}};
@@ -515,6 +520,25 @@ impl<'a> FnSpec<'a> {
515520
}};
516521
}
517522
call
523+
} else if allow_threads {
524+
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
525+
(quote!(), quote!())
526+
} else {
527+
(quote!(__self), quote! { let __self = #self_arg; })
528+
};
529+
let arg_names: Vec<Ident> = (0..args.len())
530+
.map(|i| format_ident!("__arg{}", i))
531+
.collect();
532+
let arg_decls: Vec<TokenStream> = args
533+
.into_iter()
534+
.zip(&arg_names)
535+
.map(|(arg, name)| quote! { let #name = #arg; })
536+
.collect();
537+
quote! {{
538+
#self_arg_decl
539+
#(#arg_decls)*
540+
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
541+
}}
518542
} else {
519543
quote! { function(#self_arg #(#args),*) }
520544
};

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
9191
pub signature: Option<SignatureAttribute>,
9292
pub text_signature: Option<TextSignatureAttribute>,
9393
pub krate: Option<CrateAttribute>,
94+
pub allow_threads: Option<attributes::kw::allow_threads>,
9495
}
9596

9697
impl Parse for PyFunctionOptions {
@@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {
99100

100101
while !input.is_empty() {
101102
let lookahead = input.lookahead1();
102-
if lookahead.peek(attributes::kw::name)
103+
if lookahead.peek(attributes::kw::allow_threads)
104+
|| lookahead.peek(attributes::kw::name)
103105
|| lookahead.peek(attributes::kw::pass_module)
104106
|| lookahead.peek(attributes::kw::signature)
105107
|| lookahead.peek(attributes::kw::text_signature)
@@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
121123
}
122124

123125
pub enum PyFunctionOption {
126+
AllowThreads(attributes::kw::allow_threads),
124127
Name(NameAttribute),
125128
PassModule(attributes::kw::pass_module),
126129
Signature(SignatureAttribute),
@@ -131,7 +134,9 @@ pub enum PyFunctionOption {
131134
impl Parse for PyFunctionOption {
132135
fn parse(input: ParseStream<'_>) -> Result<Self> {
133136
let lookahead = input.lookahead1();
134-
if lookahead.peek(attributes::kw::name) {
137+
if lookahead.peek(attributes::kw::allow_threads) {
138+
input.parse().map(PyFunctionOption::AllowThreads)
139+
} else if lookahead.peek(attributes::kw::name) {
135140
input.parse().map(PyFunctionOption::Name)
136141
} else if lookahead.peek(attributes::kw::pass_module) {
137142
input.parse().map(PyFunctionOption::PassModule)
@@ -171,6 +176,7 @@ impl PyFunctionOptions {
171176
}
172177
for attr in attrs {
173178
match attr {
179+
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
174180
PyFunctionOption::Name(name) => set_option!(name),
175181
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
176182
PyFunctionOption::Signature(signature) => set_option!(signature),
@@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
198204
) -> syn::Result<TokenStream> {
199205
check_generic(&func.sig)?;
200206
let PyFunctionOptions {
207+
allow_threads,
201208
pass_module,
202209
name,
203210
signature,
@@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
247254
signature,
248255
output: ty,
249256
text_signature,
257+
allow_threads,
250258
asyncness: func.sig.asyncness,
251259
unsafety: func.sig.unsafety,
252260
deprecations: Deprecations::new(),

pyo3-macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
125125
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
126126
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
127127
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
128+
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
128129
///
129130
/// For more on exposing functions see the [function section of the guide][1].
130131
///

src/coroutine.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub struct Coroutine {
3434
name: Option<Py<PyString>>,
3535
qualname_prefix: Option<&'static str>,
3636
throw_callback: Option<ThrowCallback>,
37+
allow_threads: bool,
3738
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
3839
waker: Option<Arc<AsyncioWaker>>,
3940
}
@@ -49,6 +50,7 @@ impl Coroutine {
4950
name: Option<Py<PyString>>,
5051
qualname_prefix: Option<&'static str>,
5152
throw_callback: Option<ThrowCallback>,
53+
allow_threads: bool,
5254
future: F,
5355
) -> Self
5456
where
@@ -65,6 +67,7 @@ impl Coroutine {
6567
name,
6668
qualname_prefix,
6769
throw_callback,
70+
allow_threads,
6871
future: Some(Box::pin(wrap)),
6972
waker: None,
7073
}
@@ -98,7 +101,13 @@ impl Coroutine {
98101
let waker = Waker::from(self.waker.clone().unwrap());
99102
// poll the Rust future and forward its results if ready
100103
// polling is UnwindSafe because the future is dropped in case of panic
101-
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
104+
let poll = || {
105+
if self.allow_threads {
106+
py.allow_threads(|| future_rs.as_mut().poll(&mut Context::from_waker(&waker)))
107+
} else {
108+
future_rs.as_mut().poll(&mut Context::from_waker(&waker))
109+
}
110+
};
102111
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
103112
Ok(Poll::Ready(res)) => {
104113
self.close();

src/gil.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ fn decrement_gil_count() {
507507
#[cfg(test)]
508508
mod tests {
509509
use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
510-
use crate::{ffi, gil, PyObject, Python, ToPyObject};
510+
use crate::{ffi, gil, py_run, wrap_pyfunction, PyObject, Python, ToPyObject};
511511
#[cfg(not(target_arch = "wasm32"))]
512512
use parking_lot::{const_mutex, Condvar, Mutex};
513513
use std::ptr::NonNull;
@@ -925,4 +925,44 @@ mod tests {
925925
POOL.update_counts(py);
926926
})
927927
}
928+
929+
#[test]
930+
fn allow_threads_fn() {
931+
#[crate::pyfunction(allow_threads, crate = "crate")]
932+
fn without_gil() {
933+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
934+
}
935+
Python::with_gil(|gil| {
936+
let without_gil = wrap_pyfunction!(without_gil, gil).unwrap();
937+
py_run!(gil, without_gil, "without_gil()");
938+
})
939+
}
940+
941+
#[test]
942+
fn allow_threads_async_fn() {
943+
#[crate::pyfunction(allow_threads, crate = "crate")]
944+
async fn without_gil() {
945+
use std::task::Poll;
946+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
947+
let mut ready = false;
948+
futures::future::poll_fn(|cx| {
949+
if ready {
950+
return Poll::Ready(());
951+
}
952+
ready = true;
953+
cx.waker().wake_by_ref();
954+
Poll::Pending
955+
})
956+
.await;
957+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
958+
}
959+
Python::with_gil(|gil| {
960+
let without_gil = wrap_pyfunction!(without_gil, gil).unwrap();
961+
py_run!(
962+
gil,
963+
without_gil,
964+
"import asyncio; asyncio.run(without_gil())"
965+
);
966+
})
967+
}
928968
}

src/impl_/coroutine.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,21 @@ pub fn new_coroutine<F, T, E>(
1212
name: &PyString,
1313
qualname_prefix: Option<&'static str>,
1414
throw_callback: Option<ThrowCallback>,
15+
allow_threads: bool,
1516
future: F,
1617
) -> Coroutine
1718
where
1819
F: Future<Output = Result<T, E>> + Send + 'static,
1920
T: IntoPy<PyObject>,
2021
E: Into<PyErr>,
2122
{
22-
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
23+
Coroutine::new(
24+
Some(name.into()),
25+
qualname_prefix,
26+
throw_callback,
27+
allow_threads,
28+
future,
29+
)
2330
}
2431

2532
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {

tests/ui/invalid_pyfunction_signatures.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x`
1616
13 | #[pyo3(signature = (x))]
1717
| ^
1818

19-
error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`
19+
error: expected one of: `allow_threads`, `name`, `pass_module`, `signature`, `text_signature`, `crate`
2020
--> tests/ui/invalid_pyfunction_signatures.rs:18:14
2121
|
2222
18 | #[pyfunction(x)]

0 commit comments

Comments
 (0)