Skip to content

Commit 924ab8c

Browse files
committed
Re-use threads when calling closures after releasing GIL.
1 parent 3e2dac8 commit 924ab8c

File tree

1 file changed

+90
-3
lines changed

1 file changed

+90
-3
lines changed

src/marker.rs

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ use crate::{ffi, FromPyPointer, IntoPy, Py, PyObject, PyTypeCheck, PyTypeInfo};
5353
use std::ffi::{CStr, CString};
5454
use std::marker::PhantomData;
5555
use std::os::raw::c_int;
56-
use std::thread;
5756

5857
/// A marker token that represents holding the GIL.
5958
///
@@ -316,16 +315,104 @@ impl<'py> Python<'py> {
316315
F: Send + FnOnce() -> T,
317316
T: Send,
318317
{
318+
use std::mem::transmute;
319+
use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
320+
use std::sync::mpsc::{sync_channel, SendError, SyncSender};
321+
use std::thread::{spawn, Result};
322+
use std::time::Duration;
323+
324+
use parking_lot::{const_mutex, Mutex};
325+
326+
use crate::impl_::panic::PanicTrap;
327+
319328
// Use a guard pattern to handle reacquiring the GIL,
320329
// so that the GIL will be reacquired even if `f` panics.
321330
// The `Send` bound on the closure prevents the user from
322331
// transferring the `Python` token into the closure.
323332
let _guard = unsafe { SuspendGIL::new() };
324333

325334
// To close soundness loopholes w.r.t. `send_wrapper` or `scoped-tls`,
326-
// we run the closure on a newly created thread so that it cannot
335+
// we run the closure on a separate thread so that it cannot
327336
// access thread-local storage from the current thread.
328-
thread::scope(|s| s.spawn(f).join().unwrap())
337+
338+
// 1. Construct a task
339+
struct Task(*mut dyn FnMut());
340+
unsafe impl Send for Task {}
341+
342+
let (result_sender, result_receiver) = sync_channel::<Result<T>>(0);
343+
344+
let mut f = Some(f);
345+
346+
let mut task = || {
347+
let f = f.take().unwrap();
348+
349+
let result = catch_unwind(AssertUnwindSafe(f));
350+
351+
result_sender.send(result).unwrap();
352+
};
353+
354+
// SAFETY: the current thread will block until the closure has returned
355+
let task = Task(unsafe { transmute(&mut task as &mut dyn FnMut()) });
356+
357+
// 2. Dispatch task to waiting thread, spawn new thread if necessary
358+
let trap = PanicTrap::new(
359+
"allow_threads panicked while stack data was accessed by another thread which is a bug",
360+
);
361+
362+
static THREADS: Mutex<Vec<SyncSender<Task>>> = const_mutex(Vec::new());
363+
364+
enum State {
365+
Pending(Task),
366+
Dispatched(SyncSender<Task>),
367+
}
368+
369+
let mut state = State::Pending(task);
370+
371+
while let Some(task_sender) = THREADS.lock().pop() {
372+
match state {
373+
State::Pending(task) => match task_sender.send(task) {
374+
Ok(()) => {
375+
state = State::Dispatched(task_sender);
376+
break;
377+
}
378+
Err(SendError(task)) => {
379+
state = State::Pending(task);
380+
continue;
381+
}
382+
},
383+
State::Dispatched(_task_sender) => unreachable!(),
384+
}
385+
}
386+
387+
let task_sender = match state {
388+
State::Pending(task) => {
389+
let (task_sender, task_receiver) = sync_channel::<Task>(0);
390+
391+
spawn(move || {
392+
while let Ok(task) = task_receiver.recv_timeout(Duration::from_secs(60)) {
393+
// SAFETY: all data accessed by `task` will stay alive until it completes
394+
unsafe { (*task.0)() };
395+
}
396+
});
397+
398+
task_sender.send(task).unwrap();
399+
400+
task_sender
401+
}
402+
State::Dispatched(task_sender) => task_sender,
403+
};
404+
405+
// 3. Wait for completion and check result
406+
let result = result_receiver.recv().unwrap();
407+
408+
trap.disarm();
409+
410+
THREADS.lock().push(task_sender);
411+
412+
match result {
413+
Ok(result) => result,
414+
Err(payload) => resume_unwind(payload),
415+
}
329416
}
330417

331418
/// Evaluates a Python expression in the given context and returns the result.

0 commit comments

Comments
 (0)