Skip to content

chore: use Wake more #7342

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 11 additions & 45 deletions tokio-test/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,10 @@
//! ```

use std::future::Future;
use std::mem;
use std::ops;
use std::pin::Pin;
use std::sync::{Arc, Condvar, Mutex};
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use std::task::{Context, Poll, Wake, Waker};

use tokio_stream::Stream;

Expand Down Expand Up @@ -171,7 +170,7 @@ impl MockTask {
F: FnOnce(&mut Context<'_>) -> R,
{
self.waker.clear();
let waker = self.waker();
let waker = self.clone().into_waker();
let mut cx = Context::from_waker(&waker);

f(&mut cx)
Expand All @@ -190,11 +189,8 @@ impl MockTask {
Arc::strong_count(&self.waker)
}

fn waker(&self) -> Waker {
unsafe {
let raw = to_raw(self.waker.clone());
Waker::from_raw(raw)
}
fn into_waker(self) -> Waker {
self.waker.into()
}
}

Expand Down Expand Up @@ -226,8 +222,14 @@ impl ThreadWaker {
_ => unreachable!(),
}
}
}

fn wake(&self) {
impl Wake for ThreadWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}

fn wake_by_ref(self: &Arc<Self>) {
// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
let mut state = self.state.lock().unwrap();
let prev = *state;
Expand All @@ -247,39 +249,3 @@ impl ThreadWaker {
self.condvar.notify_one();
}
}

static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);

unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
}

unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
Arc::from_raw(raw as *const ThreadWaker)
}

unsafe fn clone(raw: *const ()) -> RawWaker {
let waker = from_raw(raw);

// Increment the ref count
mem::forget(waker.clone());

to_raw(waker)
}

unsafe fn wake(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();
}

unsafe fn wake_by_ref(raw: *const ()) {
let waker = from_raw(raw);
waker.wake();

// We don't actually own a reference to the unparker
mem::forget(waker);
}

unsafe fn drop_waker(raw: *const ()) {
let _ = from_raw(raw);
}
45 changes: 8 additions & 37 deletions tokio/src/runtime/park.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::{Arc, Condvar, Mutex};
use crate::util::{waker, Wake};

use std::sync::atomic::Ordering::SeqCst;
use std::time::Duration;
Expand Down Expand Up @@ -226,7 +227,7 @@ use crate::loom::thread::AccessError;
use std::future::Future;
use std::marker::PhantomData;
use std::rc::Rc;
use std::task::{RawWaker, RawWakerVTable, Waker};
use std::task::Waker;

/// Blocks the current thread using a condition variable.
#[derive(Debug)]
Expand Down Expand Up @@ -292,50 +293,20 @@ impl CachedParkThread {

impl UnparkThread {
pub(crate) fn into_waker(self) -> Waker {
unsafe {
let raw = unparker_to_raw_waker(self.inner);
Waker::from_raw(raw)
}
waker(self.inner)
}
}

impl Inner {
#[allow(clippy::wrong_self_convention)]
fn into_raw(this: Arc<Inner>) -> *const () {
Arc::into_raw(this) as *const ()
impl Wake for Inner {
fn wake(arc_self: Arc<Self>) {
arc_self.unpark();
}

unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
Arc::from_raw(ptr as *const Inner)
fn wake_by_ref(arc_self: &Arc<Self>) {
arc_self.unpark();
}
}

unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
RawWaker::new(
Inner::into_raw(unparker),
&RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
)
}

unsafe fn clone(raw: *const ()) -> RawWaker {
Arc::increment_strong_count(raw as *const Inner);
unparker_to_raw_waker(Inner::from_raw(raw))
}

unsafe fn drop_waker(raw: *const ()) {
drop(Inner::from_raw(raw));
}

unsafe fn wake(raw: *const ()) {
let unparker = Inner::from_raw(raw);
unparker.unpark();
}

unsafe fn wake_by_ref(raw: *const ()) {
let raw = raw as *const Inner;
(*raw).unpark();
}

#[cfg(loom)]
pub(crate) fn current_thread_park_count() -> usize {
CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))
Expand Down
7 changes: 4 additions & 3 deletions tokio/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ pub(crate) use blocking_check::check_socket_for_blocking;

pub(crate) mod metric_atomics;

mod wake;
pub(crate) use wake::{waker, Wake};

#[cfg(any(
// io driver uses `WakeList` directly
feature = "net",
Expand Down Expand Up @@ -66,9 +69,7 @@ cfg_rt! {

pub(crate) use self::rand::RngSeedGenerator;

mod wake;
pub(crate) use wake::WakerRef;
pub(crate) use wake::{waker_ref, Wake};
pub(crate) use wake::{waker_ref, WakerRef};

mod sync_wrapper;
pub(crate) use sync_wrapper::SyncWrapper;
Expand Down
53 changes: 33 additions & 20 deletions tokio/src/util/wake.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use crate::loom::sync::Arc;

use std::marker::PhantomData;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::task::{RawWaker, RawWakerVTable, Waker};

/// Simplified waking interface based on Arcs.
Expand All @@ -14,30 +12,45 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static {
fn wake_by_ref(arc_self: &Arc<Self>);
}

/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}
cfg_rt! {
use std::marker::PhantomData;
use std::ops::Deref;

/// A `Waker` that is only valid for a given lifetime.
#[derive(Debug)]
pub(crate) struct WakerRef<'a> {
waker: ManuallyDrop<Waker>,
_p: PhantomData<&'a ()>,
}

impl Deref for WakerRef<'_> {
type Target = Waker;
impl Deref for WakerRef<'_> {
type Target = Waker;

fn deref(&self) -> &Waker {
&self.waker
fn deref(&self) -> &Waker {
&self.waker
}
}
}

/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
let ptr = Arc::as_ptr(wake).cast::<()>();

let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };

WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
}
}
}

WakerRef {
waker: ManuallyDrop::new(waker),
_p: PhantomData,
/// Creates a waker from a `Arc<impl Wake>`.
pub(crate) fn waker<W: Wake>(wake: Arc<W>) -> Waker {
unsafe {
Waker::from_raw(RawWaker::new(
Arc::into_raw(wake).cast(),
waker_vtable::<W>(),
))
}
}

Expand Down
Loading