diff --git a/tokio-test/src/task.rs b/tokio-test/src/task.rs index 2e646d44bf8..c781d85ea91 100644 --- a/tokio-test/src/task.rs +++ b/tokio-test/src/task.rs @@ -26,11 +26,10 @@ //! ``` use std::future::Future; -use std::mem; use std::ops; use std::pin::Pin; use std::sync::{Arc, Condvar, Mutex}; -use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; +use std::task::{Context, Poll, Wake, Waker}; use tokio_stream::Stream; @@ -171,7 +170,7 @@ impl MockTask { F: FnOnce(&mut Context<'_>) -> R, { self.waker.clear(); - let waker = self.waker(); + let waker = self.clone().into_waker(); let mut cx = Context::from_waker(&waker); f(&mut cx) @@ -190,11 +189,8 @@ impl MockTask { Arc::strong_count(&self.waker) } - fn waker(&self) -> Waker { - unsafe { - let raw = to_raw(self.waker.clone()); - Waker::from_raw(raw) - } + fn into_waker(self) -> Waker { + self.waker.into() } } @@ -226,8 +222,14 @@ impl ThreadWaker { _ => unreachable!(), } } +} - fn wake(&self) { +impl Wake for ThreadWaker { + fn wake(self: Arc) { + self.wake_by_ref(); + } + + fn wake_by_ref(self: &Arc) { // First, try transitioning from IDLE -> NOTIFY, this does not require a lock. let mut state = self.state.lock().unwrap(); let prev = *state; @@ -247,39 +249,3 @@ impl ThreadWaker { self.condvar.notify_one(); } } - -static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker); - -unsafe fn to_raw(waker: Arc) -> RawWaker { - RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE) -} - -unsafe fn from_raw(raw: *const ()) -> Arc { - Arc::from_raw(raw as *const ThreadWaker) -} - -unsafe fn clone(raw: *const ()) -> RawWaker { - let waker = from_raw(raw); - - // Increment the ref count - mem::forget(waker.clone()); - - to_raw(waker) -} - -unsafe fn wake(raw: *const ()) { - let waker = from_raw(raw); - waker.wake(); -} - -unsafe fn wake_by_ref(raw: *const ()) { - let waker = from_raw(raw); - waker.wake(); - - // We don't actually own a reference to the unparker - mem::forget(waker); -} - -unsafe fn drop_waker(raw: *const ()) { - let _ = from_raw(raw); -} diff --git a/tokio/src/runtime/park.rs b/tokio/src/runtime/park.rs index 08d3e719bc4..27bcd334c45 100644 --- a/tokio/src/runtime/park.rs +++ b/tokio/src/runtime/park.rs @@ -2,6 +2,7 @@ use crate::loom::sync::atomic::AtomicUsize; use crate::loom::sync::{Arc, Condvar, Mutex}; +use crate::util::{waker, Wake}; use std::sync::atomic::Ordering::SeqCst; use std::time::Duration; @@ -226,7 +227,7 @@ use crate::loom::thread::AccessError; use std::future::Future; use std::marker::PhantomData; use std::rc::Rc; -use std::task::{RawWaker, RawWakerVTable, Waker}; +use std::task::Waker; /// Blocks the current thread using a condition variable. #[derive(Debug)] @@ -292,50 +293,20 @@ impl CachedParkThread { impl UnparkThread { pub(crate) fn into_waker(self) -> Waker { - unsafe { - let raw = unparker_to_raw_waker(self.inner); - Waker::from_raw(raw) - } + waker(self.inner) } } -impl Inner { - #[allow(clippy::wrong_self_convention)] - fn into_raw(this: Arc) -> *const () { - Arc::into_raw(this) as *const () +impl Wake for Inner { + fn wake(arc_self: Arc) { + arc_self.unpark(); } - unsafe fn from_raw(ptr: *const ()) -> Arc { - Arc::from_raw(ptr as *const Inner) + fn wake_by_ref(arc_self: &Arc) { + arc_self.unpark(); } } -unsafe fn unparker_to_raw_waker(unparker: Arc) -> RawWaker { - RawWaker::new( - Inner::into_raw(unparker), - &RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker), - ) -} - -unsafe fn clone(raw: *const ()) -> RawWaker { - Arc::increment_strong_count(raw as *const Inner); - unparker_to_raw_waker(Inner::from_raw(raw)) -} - -unsafe fn drop_waker(raw: *const ()) { - drop(Inner::from_raw(raw)); -} - -unsafe fn wake(raw: *const ()) { - let unparker = Inner::from_raw(raw); - unparker.unpark(); -} - -unsafe fn wake_by_ref(raw: *const ()) { - let raw = raw as *const Inner; - (*raw).unpark(); -} - #[cfg(loom)] pub(crate) fn current_thread_park_count() -> usize { CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst)) diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index b57c6acfe97..cf9c7db206b 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -16,6 +16,9 @@ pub(crate) use blocking_check::check_socket_for_blocking; pub(crate) mod metric_atomics; +mod wake; +pub(crate) use wake::{waker, Wake}; + #[cfg(any( // io driver uses `WakeList` directly feature = "net", @@ -66,9 +69,7 @@ cfg_rt! { pub(crate) use self::rand::RngSeedGenerator; - mod wake; - pub(crate) use wake::WakerRef; - pub(crate) use wake::{waker_ref, Wake}; + pub(crate) use wake::{waker_ref, WakerRef}; mod sync_wrapper; pub(crate) use sync_wrapper::SyncWrapper; diff --git a/tokio/src/util/wake.rs b/tokio/src/util/wake.rs index 896ec73e7b1..d583937b8ba 100644 --- a/tokio/src/util/wake.rs +++ b/tokio/src/util/wake.rs @@ -1,8 +1,6 @@ use crate::loom::sync::Arc; -use std::marker::PhantomData; use std::mem::ManuallyDrop; -use std::ops::Deref; use std::task::{RawWaker, RawWakerVTable, Waker}; /// Simplified waking interface based on Arcs. @@ -14,30 +12,45 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static { fn wake_by_ref(arc_self: &Arc); } -/// A `Waker` that is only valid for a given lifetime. -#[derive(Debug)] -pub(crate) struct WakerRef<'a> { - waker: ManuallyDrop, - _p: PhantomData<&'a ()>, -} +cfg_rt! { + use std::marker::PhantomData; + use std::ops::Deref; + + /// A `Waker` that is only valid for a given lifetime. + #[derive(Debug)] + pub(crate) struct WakerRef<'a> { + waker: ManuallyDrop, + _p: PhantomData<&'a ()>, + } -impl Deref for WakerRef<'_> { - type Target = Waker; + impl Deref for WakerRef<'_> { + type Target = Waker; - fn deref(&self) -> &Waker { - &self.waker + fn deref(&self) -> &Waker { + &self.waker + } } -} -/// Creates a reference to a `Waker` from a reference to `Arc`. -pub(crate) fn waker_ref(wake: &Arc) -> WakerRef<'_> { - let ptr = Arc::as_ptr(wake).cast::<()>(); + /// Creates a reference to a `Waker` from a reference to `Arc`. + pub(crate) fn waker_ref(wake: &Arc) -> WakerRef<'_> { + let ptr = Arc::as_ptr(wake).cast::<()>(); - let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) }; + let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::())) }; + + WakerRef { + waker: ManuallyDrop::new(waker), + _p: PhantomData, + } + } +} - WakerRef { - waker: ManuallyDrop::new(waker), - _p: PhantomData, +/// Creates a waker from a `Arc`. +pub(crate) fn waker(wake: Arc) -> Waker { + unsafe { + Waker::from_raw(RawWaker::new( + Arc::into_raw(wake).cast(), + waker_vtable::(), + )) } }