From 161e48501b50e61f1e4495ede9b02964881214ca Mon Sep 17 00:00:00 2001 From: alecmocatta Date: Sat, 15 Aug 2020 13:28:38 +0100 Subject: [PATCH] tests mostly passing --- Cargo.toml | 3 + src/lib.rs | 9 +- src/pool/thread.rs | 286 +++++++++++++++++++++++++++++---------------- 3 files changed, 193 insertions(+), 105 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 690140b1..06aab52d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,9 @@ wasm-bindgen-test = "0.3" [build-dependencies] rustversion = "1.0" +[patch.crates-io] +tokio = {git = "https://github.com/tokio-rs/tokio", branch = "v0.2.x"} + [profile.bench] codegen-units = 1 debug = 2 diff --git a/src/lib.rs b/src/lib.rs index 498a34b2..d45733b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,12 +21,13 @@ clippy::pedantic, )] #![allow( - clippy::module_name_repetitions, - clippy::similar_names, clippy::if_not_else, - clippy::must_use_candidate, + clippy::inline_always, clippy::missing_errors_doc, - clippy::missing_safety_doc + clippy::missing_safety_doc, + clippy::module_name_repetitions, + clippy::must_use_candidate, + clippy::similar_names )] #![deny(unsafe_code)] diff --git a/src/pool/thread.rs b/src/pool/thread.rs index c917d9f9..ef3c8a76 100644 --- a/src/pool/thread.rs +++ b/src/pool/thread.rs @@ -1,5 +1,6 @@ -use futures::{ready, TryFutureExt}; -use pin_project::{pin_project, pinned_drop}; +use derive_new::new; +use futures::TryFutureExt; +use pin_project::pin_project; use std::{ future::Future, io, panic::{RefUnwindSafe, UnwindSafe}, pin::Pin, sync::Arc, task::{Context, Poll} }; @@ -45,19 +46,22 @@ impl ThreadPool { pub fn threads(&self) -> usize { self.0.logical_cores * self.0.tasks_per_core } - pub fn spawn(&self, task: F) -> impl Future> + Send + pub fn spawn( + &self, task: F, + ) -> JoinGuard> + Send> where F: FnOnce() -> Fut + Send + 'static, Fut: Future + 'static, T: Send + 'static, { #[cfg(not(target_arch = "wasm32"))] - return self - .0 - .pool - .spawn_pinned(task) - .map_err(JoinError::into_panic) - .map_err(Panicked::from); + return JoinGuard::new( + self.0 + .pool + .spawn_pinned(task) + .map_err(JoinError::into_panic) + .map_err(Panicked::from), + ); #[cfg(target_arch = "wasm32")] { let _self = self; @@ -66,20 +70,20 @@ impl ThreadPool { .map_err(Into::into) .remote_handle(); wasm_bindgen_futures::spawn_local(remote); - remote_handle + JoinGuard::new(remote_handle) } } #[allow(unsafe_code)] pub unsafe fn spawn_unchecked<'a, F, Fut, T>( &self, task: F, - ) -> impl Future> + Send + 'a + ) -> JoinGuard> + Send + 'a> where F: FnOnce() -> Fut + Send + 'a, Fut: Future + 'a, T: Send + 'a, { #[cfg(not(target_arch = "wasm32"))] - return Guard::new( + return JoinGuard::new( self.0 .pool .spawn_pinned_unchecked(task) @@ -104,7 +108,7 @@ impl ThreadPool { .map_err(Into::into) .remote_handle(); wasm_bindgen_futures::spawn_local(remote); - Guard::new(remote_handle.map_ok(|t| { + JoinGuard::new(remote_handle.map_ok(|t| { let t: *mut dyn Send = Box::into_raw(t); *Box::from_raw(t as *mut T) })) @@ -125,36 +129,30 @@ impl Clone for ThreadPool { impl UnwindSafe for ThreadPool {} impl RefUnwindSafe for ThreadPool {} -#[pin_project(PinnedDrop)] -struct Guard(#[pin] Option); -impl Guard { - fn new(f: F) -> Self { - Self(Some(f)) +#[pin_project] +#[derive(new)] +pub struct JoinGuard(#[pin] F) +where + F: Future; +#[cfg(not(target_arch = "wasm32"))] +impl JoinGuard> +where + F: Future, +{ + #[inline(always)] + pub fn cancel<'a>(self: Pin<&'a mut Self>) -> impl Future + 'a { + self.project().0.cancel() } } -impl Future for Guard +impl Future for JoinGuard where F: Future, { type Output = F::Output; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - match self.as_mut().project().0.as_pin_mut() { - Some(fut) => { - let output = ready!(fut.poll(cx)); - self.project().0.set(None); - Poll::Ready(output) - } - None => Poll::Pending, - } - } -} -#[pinned_drop] -impl PinnedDrop for Guard { - fn drop(self: Pin<&mut Self>) { - if self.project().0.is_some() { - panic!("dropped before finished polling!"); - } + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.project().0.poll(cx) } } @@ -166,101 +164,187 @@ fn _assert() { #[cfg(not(target_arch = "wasm32"))] mod pool { use async_channel::{bounded, Sender}; - use futures::{future::RemoteHandle, FutureExt}; - use std::{any::Any, future::Future, mem, panic::AssertUnwindSafe, pin::Pin}; + use futures::{ + future::{join_all, AbortHandle, Abortable, Aborted, Fuse, FusedFuture, RemoteHandle}, FutureExt + }; + use pin_project::{pin_project, pinned_drop}; + use std::{ + any::Any, future::Future, mem, panic::AssertUnwindSafe, pin::Pin, task::{Context, Poll} + }; use tokio::{ - runtime::Handle, task::{JoinError, LocalSet} + runtime::Handle, task, task::{JoinError, JoinHandle, LocalSet} }; type Request = Box Box> + Send>; - type Response = Result, Box>; + type Response = Result, Box>, Aborted>; #[derive(Debug)] pub(super) struct Pool { - sender: Sender<(Request, Sender>)>, + sender: Option>)>>, + threads: Vec>, } impl Pool { pub(super) fn new(threads: usize) -> Self { let handle = Handle::current(); let handle1 = handle.clone(); let (sender, receiver) = bounded::<(Request, Sender>)>(1); - for _ in 0..threads { - let receiver = receiver.clone(); - let handle = handle.clone(); - let _ = handle1.spawn_blocking(move || { - let local = LocalSet::new(); - handle.block_on(local.run_until(async { - while let Ok((task, sender)) = receiver.recv().await { - let _ = local.spawn_local(async move { - let (remote, remote_handle) = Pin::from(task()).remote_handle(); - let _ = sender.send(remote_handle).await; - remote.await; - }); - } - })) - }); - } - Self { sender } + let threads = (0..threads) + .map(|_| { + let receiver = receiver.clone(); + let handle = handle.clone(); + handle1.spawn_blocking(move || { + let local = LocalSet::new(); + handle.block_on(local.run_until(async { + while let Ok((task, sender)) = receiver.recv().await { + let _ = local.spawn_local(async move { + let (remote, remote_handle) = Pin::from(task()).remote_handle(); + let _ = sender.send(remote_handle).await; + remote.await; + }); + } + })) + }) + }) + .collect(); + let sender = Some(sender); + Self { sender, threads } } pub(super) fn spawn_pinned( &self, task: F, - ) -> impl Future> + Send + ) -> JoinGuard> + Send> where F: FnOnce() -> Fut + Send + 'static, Fut: Future + 'static, T: Send + 'static, { - let sender = self.sender.clone(); - async move { - let task: Request = Box::new(|| { - Box::new( - AssertUnwindSafe(task().map(|t| Box::new(t) as Box)) - .catch_unwind(), - ) - }); - let (sender_, receiver) = bounded::>(1); - sender.send((task, sender_)).await.unwrap(); - let res = receiver.recv().await; - let res = res.unwrap().await; - #[allow(deprecated)] - res.map(|x| *Box::::downcast(x).unwrap()) - .map_err(JoinError::panic) - } + let sender = self.sender.as_ref().unwrap().clone(); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + JoinGuard::new( + async move { + let task: Request = Box::new(|| { + Box::new(Abortable::new( + AssertUnwindSafe(task().map(|t| Box::new(t) as Box)) + .catch_unwind(), + abort_registration, + )) + }); + let (sender_, receiver) = bounded::>(1); + sender.send((task, sender_)).await.unwrap(); + let res = receiver.recv().await; + let res = res.unwrap().await; + #[allow(deprecated)] + match res { + Ok(Ok(res)) => Ok(*Box::::downcast(res).unwrap()), + Ok(Err(panic)) => Err(JoinError::panic(panic)), + Err(Aborted) => Err(JoinError::cancelled()), + } + }, + abort_handle, + ) } #[allow(unsafe_code)] pub(super) unsafe fn spawn_pinned_unchecked<'a, F, Fut, T>( &self, task: F, - ) -> impl Future> + Send + 'a + ) -> JoinGuard> + Send + 'a> where F: FnOnce() -> Fut + Send + 'a, Fut: Future + 'a, T: Send + 'a, { - let sender = self.sender.clone(); - async move { - let task: Box Box> + Send> = - Box::new(|| { - Box::new( - AssertUnwindSafe(task().map(|t| { - let t: Box = Box::new(t); - let t: Box = mem::transmute(t); - t - })) - .catch_unwind(), - ) - }); - let task: Box Box> + Send> = - mem::transmute(task); - let (sender_, receiver) = bounded::>(1); - sender.send((task, sender_)).await.unwrap(); - let res = receiver.recv().await; - let res = res.unwrap().await; - #[allow(deprecated)] - res.map(|t| { - let t: *mut dyn Any = Box::into_raw(t); - *Box::from_raw(t as *mut T) - }) - .map_err(JoinError::panic) + let sender = self.sender.as_ref().unwrap().clone(); + let (abort_handle, abort_registration) = AbortHandle::new_pair(); + JoinGuard::new( + async move { + let task: Box Box> + Send> = + Box::new(|| { + Box::new(Abortable::new( + AssertUnwindSafe(task().map(|t| { + let t: Box = Box::new(t); + let t: Box = mem::transmute(t); + t + })) + .catch_unwind(), + abort_registration, + )) + }); + let task: Box Box> + Send> = + mem::transmute(task); + let (sender_, receiver) = bounded::>(1); + sender.send((task, sender_)).await.unwrap(); + let res = receiver.recv().await; + let res = res.unwrap().await; + #[allow(deprecated)] + match res { + Ok(Ok(res)) => { + let t: *mut dyn Any = Box::into_raw(res); + Ok(*Box::from_raw(t as *mut T)) + } + Ok(Err(panic)) => Err(JoinError::panic(panic)), + Err(Aborted) => Err(JoinError::cancelled()), + } + }, + abort_handle, + ) + } + } + impl Drop for Pool { + fn drop(&mut self) { + let _ = self.sender.take().unwrap(); + task::block_in_place(|| { + Handle::current().block_on(join_all(mem::take(&mut self.threads))) + }) + .into_iter() + .collect::>() + .unwrap(); + } + } + + #[pin_project(PinnedDrop)] + pub(crate) struct JoinGuard(#[pin] Fuse, AbortHandle) + where + F: Future; + impl JoinGuard + where + F: Future, + { + #[inline(always)] + fn new(f: F, abort_handle: AbortHandle) -> Self { + Self(f.fuse(), abort_handle) + } + #[inline(always)] + pub(crate) fn cancel<'a>(mut self: Pin<&'a mut Self>) -> impl Future + 'a { + futures::future::poll_fn(move |cx| { + self.1.abort(); + let self_ = self.as_mut().project().0; + if !self_.is_terminated() { + self_.poll(cx).map(drop) + } else { + Poll::Ready(()) + } + }) + } + } + impl Future for JoinGuard + where + F: Future, + { + type Output = F::Output; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + self.project().0.poll(cx) + } + } + #[pinned_drop] + impl PinnedDrop for JoinGuard + where + F: Future, + { + fn drop(self: Pin<&mut Self>) { + let self_ = self.project(); + self_.1.abort(); + if !self_.0.is_terminated() { + let _ = task::block_in_place(|| Handle::current().block_on(self_.0)); } } }