Skip to content

Commit 8df3961

Browse files
committed
Re-use threads when calling closures after releasing GIL.
1 parent 3e2dac8 commit 8df3961

File tree

1 file changed

+80
-3
lines changed

1 file changed

+80
-3
lines changed

src/marker.rs

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ use crate::{ffi, FromPyPointer, IntoPy, Py, PyObject, PyTypeCheck, PyTypeInfo};
5353
use std::ffi::{CStr, CString};
5454
use std::marker::PhantomData;
5555
use std::os::raw::c_int;
56-
use std::thread;
5756

5857
/// A marker token that represents holding the GIL.
5958
///
@@ -316,16 +315,94 @@ impl<'py> Python<'py> {
316315
F: Send + FnOnce() -> T,
317316
T: Send,
318317
{
318+
use parking_lot::{const_mutex, Mutex};
319+
use std::mem::{transmute, ManuallyDrop, MaybeUninit};
320+
use std::panic::{catch_unwind, AssertUnwindSafe};
321+
use std::sync::mpsc::{sync_channel, SendError, SyncSender};
322+
use std::thread::{spawn, Result};
323+
use std::time::Duration;
324+
319325
// Use a guard pattern to handle reacquiring the GIL,
320326
// so that the GIL will be reacquired even if `f` panics.
321327
// The `Send` bound on the closure prevents the user from
322328
// transferring the `Python` token into the closure.
323329
let _guard = unsafe { SuspendGIL::new() };
324330

325331
// To close soundness loopholes w.r.t. `send_wrapper` or `scoped-tls`,
326-
// we run the closure on a newly created thread so that it cannot
332+
// we run the closure on a separate thread so that it cannot
327333
// access thread-local storage from the current thread.
328-
thread::scope(|s| s.spawn(f).join().unwrap())
334+
335+
// Construct a task
336+
struct Task(*mut dyn FnMut());
337+
unsafe impl Send for Task {}
338+
339+
let mut f = ManuallyDrop::new(f);
340+
let mut result = MaybeUninit::<Result<T>>::uninit();
341+
342+
let (result_sender, result_receiver) = sync_channel(0);
343+
344+
let mut task = || {
345+
// SAFETY: `F` is `Send` and we ensure that this closure is called at most once
346+
let f = unsafe { ManuallyDrop::take(&mut f) };
347+
348+
result.write(catch_unwind(AssertUnwindSafe(f)));
349+
350+
result_sender.send(()).unwrap();
351+
};
352+
// SAFETY: the current thread will block until the closure has returned
353+
let task = Task(unsafe { transmute(&mut task as &mut dyn FnMut()) });
354+
355+
// Enqueue task and spawn thread if necessary
356+
static THREADS: Mutex<Vec<SyncSender<Task>>> = const_mutex(Vec::new());
357+
358+
enum State {
359+
Pending(Task),
360+
Dispatched(SyncSender<Task>),
361+
}
362+
363+
let mut state = State::Pending(task);
364+
365+
while let Some(task_sender) = THREADS.lock().pop() {
366+
match state {
367+
State::Pending(task) => match task_sender.send(task) {
368+
Ok(()) => {
369+
state = State::Dispatched(task_sender);
370+
break;
371+
}
372+
Err(SendError(task)) => {
373+
state = State::Pending(task);
374+
continue;
375+
}
376+
},
377+
State::Dispatched(_sender) => unreachable!(),
378+
}
379+
}
380+
381+
let task_sender = match state {
382+
State::Pending(task) => {
383+
let (task_sender, task_receiver) = sync_channel::<Task>(0);
384+
385+
spawn(move || {
386+
while let Ok(task) = task_receiver.recv_timeout(Duration::from_secs(60)) {
387+
// SAFETY: all data accessed by `task` will stay alive until it completes
388+
unsafe { (*task.0)() };
389+
}
390+
});
391+
392+
task_sender.send(task).unwrap();
393+
394+
task_sender
395+
}
396+
State::Dispatched(task_sender) => task_sender,
397+
};
398+
399+
// Wait for completion and read result
400+
result_receiver.recv().unwrap();
401+
402+
THREADS.lock().push(task_sender);
403+
404+
// SAFETY: the task completed and hence initialized `result`
405+
unsafe { result.assume_init().unwrap() }
329406
}
330407

331408
/// Evaluates a Python expression in the given context and returns the result.

0 commit comments

Comments
 (0)