diff --git a/Cargo.toml b/Cargo.toml index 91999409..56d62232 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,9 @@ wasm-bindgen-test = "0.3" [build-dependencies] rustversion = "1.0" +[patch.crates-io] +tokio = {git = "https://github.com/tokio-rs/tokio", branch = "v0.2.x"} + [profile.bench] codegen-units = 1 debug = 2 diff --git a/src/pool/thread.rs b/src/pool/thread.rs index c917d9f9..536e0b1c 100644 --- a/src/pool/thread.rs +++ b/src/pool/thread.rs @@ -79,13 +79,12 @@ impl ThreadPool { T: Send + 'a, { #[cfg(not(target_arch = "wasm32"))] - return Guard::new( - self.0 - .pool - .spawn_pinned_unchecked(task) - .map_err(JoinError::into_panic) - .map_err(Panicked::from), - ); + return self + .0 + .pool + .spawn_pinned_unchecked(task) + .map_err(JoinError::into_panic) + .map_err(Panicked::from); #[cfg(target_arch = "wasm32")] { let _self = self; @@ -104,10 +103,10 @@ impl ThreadPool { .map_err(Into::into) .remote_handle(); wasm_bindgen_futures::spawn_local(remote); - Guard::new(remote_handle.map_ok(|t| { + remote_handle.map_ok(|t| { let t: *mut dyn Send = Box::into_raw(t); *Box::from_raw(t as *mut T) - })) + }) } } } @@ -125,39 +124,6 @@ impl Clone for ThreadPool { impl UnwindSafe for ThreadPool {} impl RefUnwindSafe for ThreadPool {} -#[pin_project(PinnedDrop)] -struct Guard(#[pin] Option); -impl Guard { - fn new(f: F) -> Self { - Self(Some(f)) - } -} -impl Future for Guard -where - F: Future, -{ - type Output = F::Output; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - match self.as_mut().project().0.as_pin_mut() { - Some(fut) => { - let output = ready!(fut.poll(cx)); - self.project().0.set(None); - Poll::Ready(output) - } - None => Poll::Pending, - } - } -} -#[pinned_drop] -impl PinnedDrop for Guard { - fn drop(self: Pin<&mut Self>) { - if self.project().0.is_some() { - panic!("dropped before finished polling!"); - } - } -} - fn _assert() { let _ = assert_sync_and_send::; } @@ -166,10 +132,12 @@ fn _assert() { #[cfg(not(target_arch = "wasm32"))] mod pool { use async_channel::{bounded, Sender}; - use futures::{future::RemoteHandle, FutureExt}; + use futures::{ + future::{join_all, RemoteHandle}, FutureExt + }; use std::{any::Any, future::Future, mem, panic::AssertUnwindSafe, pin::Pin}; use tokio::{ - runtime::Handle, task::{JoinError, LocalSet} + runtime::Handle, task, task::{JoinError, JoinHandle, LocalSet} }; type Request = Box Box> + Send>; @@ -177,30 +145,34 @@ mod pool { #[derive(Debug)] pub(super) struct Pool { - sender: Sender<(Request, Sender>)>, + sender: Option>)>>, + threads: Vec>, } impl Pool { pub(super) fn new(threads: usize) -> Self { let handle = Handle::current(); let handle1 = handle.clone(); let (sender, receiver) = bounded::<(Request, Sender>)>(1); - for _ in 0..threads { - let receiver = receiver.clone(); - let handle = handle.clone(); - let _ = handle1.spawn_blocking(move || { - let local = LocalSet::new(); - handle.block_on(local.run_until(async { - while let Ok((task, sender)) = receiver.recv().await { - let _ = local.spawn_local(async move { - let (remote, remote_handle) = Pin::from(task()).remote_handle(); - let _ = sender.send(remote_handle).await; - remote.await; - }); - } - })) - }); - } - Self { sender } + let threads = (0..threads) + .map(|_| { + let receiver = receiver.clone(); + let handle = handle.clone(); + handle1.spawn_blocking(move || { + let local = LocalSet::new(); + handle.block_on(local.run_until(async { + while let Ok((task, sender)) = receiver.recv().await { + let _ = local.spawn_local(async move { + let (remote, remote_handle) = Pin::from(task()).remote_handle(); + let _ = sender.send(remote_handle).await; + remote.await; + }); + } + })) + }) + }) + .collect(); + let sender = Some(sender); + Self { sender, threads } } pub(super) fn spawn_pinned( &self, task: F, @@ -210,7 +182,7 @@ mod pool { Fut: Future + 'static, T: Send + 'static, { - let sender = self.sender.clone(); + let sender = self.sender.as_ref().unwrap().clone(); async move { let task: Request = Box::new(|| { Box::new( @@ -236,7 +208,7 @@ mod pool { Fut: Future + 'a, T: Send + 'a, { - let sender = self.sender.clone(); + let sender = self.sender.as_ref().unwrap().clone(); async move { let task: Box Box> + Send> = Box::new(|| { @@ -264,6 +236,18 @@ mod pool { } } } + impl Drop for Pool { + fn drop(&mut self) { + let _ = self.sender.take().unwrap(); + task::block_in_place(|| { + let handle = Handle::current(); + handle.block_on(join_all(mem::take(&mut self.threads))) + }) + .into_iter() + .collect::>() + .unwrap(); + } + } #[cfg(test)] mod tests {