Skip to content

Commit 639a1aa

Browse files
committed
feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions
1 parent 07726ae commit 639a1aa

File tree

9 files changed

+118
-16
lines changed

9 files changed

+118
-16
lines changed

newsfragments/3610.added.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions

pyo3-macros-backend/src/attributes.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use syn::{
99
};
1010

1111
pub mod kw {
12+
syn::custom_keyword!(allow_threads);
1213
syn::custom_keyword!(annotation);
1314
syn::custom_keyword!(attribute);
1415
syn::custom_keyword!(cancel_handle);

pyo3-macros-backend/src/method.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use std::fmt::Display;
22

33
use proc_macro2::{Span, TokenStream};
4-
use quote::{quote, quote_spanned, ToTokens};
4+
use quote::{format_ident, quote, quote_spanned, ToTokens};
55
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};
66

77
use crate::{
8+
attributes,
89
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
910
deprecations::{Deprecation, Deprecations},
1011
params::impl_arg_params,
@@ -241,6 +242,7 @@ pub struct FnSpec<'a> {
241242
pub asyncness: Option<syn::Token![async]>,
242243
pub unsafety: Option<syn::Token![unsafe]>,
243244
pub deprecations: Deprecations,
245+
pub allow_threads: Option<attributes::kw::allow_threads>,
244246
}
245247

246248
pub fn get_return_info(output: &syn::ReturnType) -> syn::Type {
@@ -284,6 +286,7 @@ impl<'a> FnSpec<'a> {
284286
text_signature,
285287
name,
286288
signature,
289+
allow_threads,
287290
..
288291
} = options;
289292

@@ -331,6 +334,7 @@ impl<'a> FnSpec<'a> {
331334
asyncness: sig.asyncness,
332335
unsafety: sig.unsafety,
333336
deprecations,
337+
allow_threads,
334338
})
335339
}
336340

@@ -474,6 +478,7 @@ impl<'a> FnSpec<'a> {
474478
}
475479

476480
let rust_call = |args: Vec<TokenStream>| {
481+
let allow_threads = self.allow_threads.is_some();
477482
let call = if self.asyncness.is_some() {
478483
let throw_callback = if cancel_handle.is_some() {
479484
quote! { Some(__throw_callback) }
@@ -502,6 +507,7 @@ impl<'a> FnSpec<'a> {
502507
_pyo3::intern!(py, stringify!(#python_name)),
503508
#qualname_prefix,
504509
#throw_callback,
510+
#allow_threads,
505511
async move { _pyo3::impl_::wrap::OkWrap::wrap(future.await) },
506512
)
507513
}};
@@ -513,6 +519,25 @@ impl<'a> FnSpec<'a> {
513519
}};
514520
}
515521
call
522+
} else if allow_threads {
523+
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
524+
(quote!(), quote!())
525+
} else {
526+
(quote!(__self), quote! { let __self = #self_arg; })
527+
};
528+
let arg_names: Vec<Ident> = (0..args.len())
529+
.map(|i| format_ident!("__arg{}", i))
530+
.collect();
531+
let arg_decls: Vec<TokenStream> = args
532+
.into_iter()
533+
.zip(&arg_names)
534+
.map(|(arg, name)| quote! { let #name = #arg; })
535+
.collect();
536+
quote! {{
537+
#self_arg_decl
538+
#(#arg_decls)*
539+
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
540+
}}
516541
} else {
517542
quote! { function(#self_arg #(#args),*) }
518543
};

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
9191
pub signature: Option<SignatureAttribute>,
9292
pub text_signature: Option<TextSignatureAttribute>,
9393
pub krate: Option<CrateAttribute>,
94+
pub allow_threads: Option<attributes::kw::allow_threads>,
9495
}
9596

9697
impl Parse for PyFunctionOptions {
@@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {
99100

100101
while !input.is_empty() {
101102
let lookahead = input.lookahead1();
102-
if lookahead.peek(attributes::kw::name)
103+
if lookahead.peek(attributes::kw::allow_threads)
104+
|| lookahead.peek(attributes::kw::name)
103105
|| lookahead.peek(attributes::kw::pass_module)
104106
|| lookahead.peek(attributes::kw::signature)
105107
|| lookahead.peek(attributes::kw::text_signature)
@@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
121123
}
122124

123125
pub enum PyFunctionOption {
126+
AllowThreads(attributes::kw::allow_threads),
124127
Name(NameAttribute),
125128
PassModule(attributes::kw::pass_module),
126129
Signature(SignatureAttribute),
@@ -131,7 +134,9 @@ pub enum PyFunctionOption {
131134
impl Parse for PyFunctionOption {
132135
fn parse(input: ParseStream<'_>) -> Result<Self> {
133136
let lookahead = input.lookahead1();
134-
if lookahead.peek(attributes::kw::name) {
137+
if lookahead.peek(attributes::kw::allow_threads) {
138+
input.parse().map(PyFunctionOption::AllowThreads)
139+
} else if lookahead.peek(attributes::kw::name) {
135140
input.parse().map(PyFunctionOption::Name)
136141
} else if lookahead.peek(attributes::kw::pass_module) {
137142
input.parse().map(PyFunctionOption::PassModule)
@@ -171,6 +176,7 @@ impl PyFunctionOptions {
171176
}
172177
for attr in attrs {
173178
match attr {
179+
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
174180
PyFunctionOption::Name(name) => set_option!(name),
175181
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
176182
PyFunctionOption::Signature(signature) => set_option!(signature),
@@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
198204
) -> syn::Result<TokenStream> {
199205
check_generic(&func.sig)?;
200206
let PyFunctionOptions {
207+
allow_threads,
201208
pass_module,
202209
name,
203210
signature,
@@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
247254
signature,
248255
output: ty,
249256
text_signature,
257+
allow_threads,
250258
asyncness: func.sig.asyncness,
251259
unsafety: func.sig.unsafety,
252260
deprecations: Deprecations::new(),

pyo3-macros/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
125125
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
126126
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
127127
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
128+
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
128129
///
129130
/// For more on exposing functions see the [function section of the guide][1].
130131
///

src/coroutine.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub struct Coroutine {
3232
name: Option<Py<PyString>>,
3333
qualname_prefix: Option<&'static str>,
3434
throw_callback: Option<ThrowCallback>,
35+
allow_threads: bool,
3536
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
3637
waker: Option<Arc<AsyncioWaker>>,
3738
}
@@ -47,6 +48,7 @@ impl Coroutine {
4748
name: Option<Py<PyString>>,
4849
qualname_prefix: Option<&'static str>,
4950
throw_callback: Option<ThrowCallback>,
51+
allow_threads: bool,
5052
future: F,
5153
) -> Self
5254
where
@@ -63,6 +65,7 @@ impl Coroutine {
6365
name,
6466
qualname_prefix,
6567
throw_callback,
68+
allow_threads,
6669
future: Some(Box::pin(wrap)),
6770
waker: None,
6871
}
@@ -96,7 +99,13 @@ impl Coroutine {
9699
let waker = Waker::from(self.waker.clone().unwrap());
97100
// poll the Rust future and forward its results if ready
98101
// polling is UnwindSafe because the future is dropped in case of panic
99-
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
102+
let poll = || {
103+
if self.allow_threads {
104+
py.allow_threads(|| future_rs.as_mut().poll(&mut Context::from_waker(&waker)))
105+
} else {
106+
future_rs.as_mut().poll(&mut Context::from_waker(&waker))
107+
}
108+
};
100109
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
101110
Ok(Poll::Ready(res)) => {
102111
self.close();

src/gil.rs

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
//! Interaction with Python's global interpreter lock
22
3-
use crate::impl_::not_send::{NotSend, NOT_SEND};
4-
use crate::{ffi, Python};
5-
use parking_lot::{const_mutex, Mutex, Once};
6-
use std::cell::Cell;
73
#[cfg(debug_assertions)]
84
use std::cell::RefCell;
95
#[cfg(not(debug_assertions))]
106
use std::cell::UnsafeCell;
11-
use std::{mem, ptr::NonNull};
7+
use std::{cell::Cell, mem, ptr::NonNull};
8+
9+
use parking_lot::{const_mutex, Mutex, Once};
10+
11+
use crate::{
12+
ffi,
13+
impl_::not_send::{NotSend, NOT_SEND},
14+
Python,
15+
};
1216

1317
static START: Once = Once::new();
1418

@@ -506,11 +510,13 @@ fn decrement_gil_count() {
506510

507511
#[cfg(test)]
508512
mod tests {
509-
use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
510-
use crate::{ffi, gil, PyObject, Python, ToPyObject};
513+
use std::ptr::NonNull;
514+
511515
#[cfg(not(target_arch = "wasm32"))]
512516
use parking_lot::{const_mutex, Condvar, Mutex};
513-
use std::ptr::NonNull;
517+
518+
use super::{gil_is_acquired, GILPool, GIL_COUNT, OWNED_OBJECTS, POOL};
519+
use crate::{ffi, gil, PyObject, Python, ToPyObject};
514520

515521
fn get_object(py: Python<'_>) -> PyObject {
516522
// Convenience function for getting a single unique object, using `new_pool` so as to leave
@@ -786,9 +792,10 @@ mod tests {
786792
#[test]
787793
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
788794
fn test_clone_without_gil() {
789-
use crate::{Py, PyAny};
790795
use std::{sync::Arc, thread};
791796

797+
use crate::{Py, PyAny};
798+
792799
// Some events for synchronizing
793800
static GIL_ACQUIRED: Event = Event::new();
794801
static OBJECT_CLONED: Event = Event::new();
@@ -851,9 +858,10 @@ mod tests {
851858
#[test]
852859
#[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
853860
fn test_clone_in_other_thread() {
854-
use crate::Py;
855861
use std::{sync::Arc, thread};
856862

863+
use crate::Py;
864+
857865
// Some events for synchronizing
858866
static OBJECT_CLONED: Event = Event::new();
859867

@@ -925,4 +933,46 @@ mod tests {
925933
POOL.update_counts(py);
926934
})
927935
}
936+
937+
#[cfg(feature = "macros")]
938+
#[test]
939+
fn allow_threads_fn() {
940+
#[crate::pyfunction(allow_threads, crate = "crate")]
941+
fn without_gil() {
942+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
943+
}
944+
Python::with_gil(|gil| {
945+
let without_gil = crate::wrap_pyfunction!(without_gil, gil).unwrap();
946+
crate::py_run!(gil, without_gil, "without_gil()");
947+
})
948+
}
949+
950+
#[cfg(feature = "macros")]
951+
#[test]
952+
fn allow_threads_async_fn() {
953+
#[crate::pyfunction(allow_threads, crate = "crate")]
954+
async fn without_gil() {
955+
use std::task::Poll;
956+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
957+
let mut ready = false;
958+
futures::future::poll_fn(|cx| {
959+
if ready {
960+
return Poll::Ready(());
961+
}
962+
ready = true;
963+
cx.waker().wake_by_ref();
964+
Poll::Pending
965+
})
966+
.await;
967+
GIL_COUNT.with(|c| assert_eq!(c.get(), 0));
968+
}
969+
Python::with_gil(|gil| {
970+
let without_gil = crate::wrap_pyfunction!(without_gil, gil).unwrap();
971+
crate::py_run!(
972+
gil,
973+
without_gil,
974+
"import asyncio; asyncio.run(without_gil())"
975+
);
976+
})
977+
}
928978
}

src/impl_/coroutine.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,21 @@ pub fn new_coroutine<F, T, E>(
1515
name: &PyString,
1616
qualname_prefix: Option<&'static str>,
1717
throw_callback: Option<ThrowCallback>,
18+
allow_threads: bool,
1819
future: F,
1920
) -> Coroutine
2021
where
2122
F: Future<Output = Result<T, E>> + Send + 'static,
2223
T: IntoPy<PyObject>,
2324
E: Into<PyErr>,
2425
{
25-
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
26+
Coroutine::new(
27+
Some(name.into()),
28+
qualname_prefix,
29+
throw_callback,
30+
allow_threads,
31+
future,
32+
)
2633
}
2734

2835
fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {

tests/ui/invalid_pyfunction_signatures.stderr

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ error: expected argument from function definition `y` but got argument `x`
1616
13 | #[pyo3(signature = (x))]
1717
| ^
1818

19-
error: expected one of: `name`, `pass_module`, `signature`, `text_signature`, `crate`
19+
error: expected one of: `allow_threads`, `name`, `pass_module`, `signature`, `text_signature`, `crate`
2020
--> tests/ui/invalid_pyfunction_signatures.rs:18:14
2121
|
2222
18 | #[pyfunction(x)]

0 commit comments

Comments
 (0)