Skip to content

Commit 20874fe

Browse files
committed
chore: replace RawWaker with Wake
`Wake` allows for creating a `Waker` without unsafe code.
1 parent 0cf95f0 commit 20874fe

File tree

4 files changed

+56
-105
lines changed

4 files changed

+56
-105
lines changed

tokio-test/src/task.rs

Lines changed: 11 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
//! ```
2727
2828
use std::future::Future;
29-
use std::mem;
3029
use std::ops;
3130
use std::pin::Pin;
3231
use std::sync::{Arc, Condvar, Mutex};
33-
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
32+
use std::task::{Context, Poll, Wake, Waker};
3433

3534
use tokio_stream::Stream;
3635

@@ -171,7 +170,7 @@ impl MockTask {
171170
F: FnOnce(&mut Context<'_>) -> R,
172171
{
173172
self.waker.clear();
174-
let waker = self.waker();
173+
let waker = self.clone().into_waker();
175174
let mut cx = Context::from_waker(&waker);
176175

177176
f(&mut cx)
@@ -190,11 +189,8 @@ impl MockTask {
190189
Arc::strong_count(&self.waker)
191190
}
192191

193-
fn waker(&self) -> Waker {
194-
unsafe {
195-
let raw = to_raw(self.waker.clone());
196-
Waker::from_raw(raw)
197-
}
192+
fn into_waker(self) -> Waker {
193+
self.waker.into()
198194
}
199195
}
200196

@@ -226,8 +222,14 @@ impl ThreadWaker {
226222
_ => unreachable!(),
227223
}
228224
}
225+
}
229226

230-
fn wake(&self) {
227+
impl Wake for ThreadWaker {
228+
fn wake(self: Arc<Self>) {
229+
self.wake_by_ref();
230+
}
231+
232+
fn wake_by_ref(self: &Arc<Self>) {
231233
// First, try transitioning from IDLE -> NOTIFY, this does not require a lock.
232234
let mut state = self.state.lock().unwrap();
233235
let prev = *state;
@@ -247,39 +249,3 @@ impl ThreadWaker {
247249
self.condvar.notify_one();
248250
}
249251
}
250-
251-
static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker);
252-
253-
unsafe fn to_raw(waker: Arc<ThreadWaker>) -> RawWaker {
254-
RawWaker::new(Arc::into_raw(waker) as *const (), &VTABLE)
255-
}
256-
257-
unsafe fn from_raw(raw: *const ()) -> Arc<ThreadWaker> {
258-
Arc::from_raw(raw as *const ThreadWaker)
259-
}
260-
261-
unsafe fn clone(raw: *const ()) -> RawWaker {
262-
let waker = from_raw(raw);
263-
264-
// Increment the ref count
265-
mem::forget(waker.clone());
266-
267-
to_raw(waker)
268-
}
269-
270-
unsafe fn wake(raw: *const ()) {
271-
let waker = from_raw(raw);
272-
waker.wake();
273-
}
274-
275-
unsafe fn wake_by_ref(raw: *const ()) {
276-
let waker = from_raw(raw);
277-
waker.wake();
278-
279-
// We don't actually own a reference to the unparker
280-
mem::forget(waker);
281-
}
282-
283-
unsafe fn drop_waker(raw: *const ()) {
284-
let _ = from_raw(raw);
285-
}

tokio/src/runtime/park.rs

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use crate::loom::sync::atomic::AtomicUsize;
44
use crate::loom::sync::{Arc, Condvar, Mutex};
5+
use crate::util::{waker, Wake};
56

67
use std::sync::atomic::Ordering::SeqCst;
78
use std::time::Duration;
@@ -226,7 +227,7 @@ use crate::loom::thread::AccessError;
226227
use std::future::Future;
227228
use std::marker::PhantomData;
228229
use std::rc::Rc;
229-
use std::task::{RawWaker, RawWakerVTable, Waker};
230+
use std::task::Waker;
230231

231232
/// Blocks the current thread using a condition variable.
232233
#[derive(Debug)]
@@ -292,50 +293,20 @@ impl CachedParkThread {
292293

293294
impl UnparkThread {
294295
pub(crate) fn into_waker(self) -> Waker {
295-
unsafe {
296-
let raw = unparker_to_raw_waker(self.inner);
297-
Waker::from_raw(raw)
298-
}
296+
waker(self.inner)
299297
}
300298
}
301299

302-
impl Inner {
303-
#[allow(clippy::wrong_self_convention)]
304-
fn into_raw(this: Arc<Inner>) -> *const () {
305-
Arc::into_raw(this) as *const ()
300+
impl Wake for Inner {
301+
fn wake(arc_self: Arc<Self>) {
302+
arc_self.unpark();
306303
}
307304

308-
unsafe fn from_raw(ptr: *const ()) -> Arc<Inner> {
309-
Arc::from_raw(ptr as *const Inner)
305+
fn wake_by_ref(arc_self: &Arc<Self>) {
306+
arc_self.unpark();
310307
}
311308
}
312309

313-
unsafe fn unparker_to_raw_waker(unparker: Arc<Inner>) -> RawWaker {
314-
RawWaker::new(
315-
Inner::into_raw(unparker),
316-
&RawWakerVTable::new(clone, wake, wake_by_ref, drop_waker),
317-
)
318-
}
319-
320-
unsafe fn clone(raw: *const ()) -> RawWaker {
321-
Arc::increment_strong_count(raw as *const Inner);
322-
unparker_to_raw_waker(Inner::from_raw(raw))
323-
}
324-
325-
unsafe fn drop_waker(raw: *const ()) {
326-
drop(Inner::from_raw(raw));
327-
}
328-
329-
unsafe fn wake(raw: *const ()) {
330-
let unparker = Inner::from_raw(raw);
331-
unparker.unpark();
332-
}
333-
334-
unsafe fn wake_by_ref(raw: *const ()) {
335-
let raw = raw as *const Inner;
336-
(*raw).unpark();
337-
}
338-
339310
#[cfg(loom)]
340311
pub(crate) fn current_thread_park_count() -> usize {
341312
CURRENT_THREAD_PARK_COUNT.with(|count| count.load(SeqCst))

tokio/src/util/mod.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ pub(crate) use blocking_check::check_socket_for_blocking;
1616

1717
pub(crate) mod metric_atomics;
1818

19+
mod wake;
20+
pub(crate) use wake::{waker, Wake};
21+
1922
#[cfg(any(
2023
// io driver uses `WakeList` directly
2124
feature = "net",
@@ -66,9 +69,7 @@ cfg_rt! {
6669

6770
pub(crate) use self::rand::RngSeedGenerator;
6871

69-
mod wake;
70-
pub(crate) use wake::WakerRef;
71-
pub(crate) use wake::{waker_ref, Wake};
72+
pub(crate) use wake::{waker_ref, WakerRef};
7273

7374
mod sync_wrapper;
7475
pub(crate) use sync_wrapper::SyncWrapper;

tokio/src/util/wake.rs

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use crate::loom::sync::Arc;
22

3-
use std::marker::PhantomData;
43
use std::mem::ManuallyDrop;
5-
use std::ops::Deref;
64
use std::task::{RawWaker, RawWakerVTable, Waker};
75

86
/// Simplified waking interface based on Arcs.
@@ -14,30 +12,45 @@ pub(crate) trait Wake: Send + Sync + Sized + 'static {
1412
fn wake_by_ref(arc_self: &Arc<Self>);
1513
}
1614

17-
/// A `Waker` that is only valid for a given lifetime.
18-
#[derive(Debug)]
19-
pub(crate) struct WakerRef<'a> {
20-
waker: ManuallyDrop<Waker>,
21-
_p: PhantomData<&'a ()>,
22-
}
15+
cfg_rt! {
16+
use std::marker::PhantomData;
17+
use std::ops::Deref;
18+
19+
/// A `Waker` that is only valid for a given lifetime.
20+
#[derive(Debug)]
21+
pub(crate) struct WakerRef<'a> {
22+
waker: ManuallyDrop<Waker>,
23+
_p: PhantomData<&'a ()>,
24+
}
2325

24-
impl Deref for WakerRef<'_> {
25-
type Target = Waker;
26+
impl Deref for WakerRef<'_> {
27+
type Target = Waker;
2628

27-
fn deref(&self) -> &Waker {
28-
&self.waker
29+
fn deref(&self) -> &Waker {
30+
&self.waker
31+
}
2932
}
30-
}
3133

32-
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
33-
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
34-
let ptr = Arc::as_ptr(wake).cast::<()>();
34+
/// Creates a reference to a `Waker` from a reference to `Arc<impl Wake>`.
35+
pub(crate) fn waker_ref<W: Wake>(wake: &Arc<W>) -> WakerRef<'_> {
36+
let ptr = Arc::as_ptr(wake).cast::<()>();
3537

36-
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
38+
let waker = unsafe { Waker::from_raw(RawWaker::new(ptr, waker_vtable::<W>())) };
39+
40+
WakerRef {
41+
waker: ManuallyDrop::new(waker),
42+
_p: PhantomData,
43+
}
44+
}
45+
}
3746

38-
WakerRef {
39-
waker: ManuallyDrop::new(waker),
40-
_p: PhantomData,
47+
/// Creates a waker from a `Arc<impl Wake>`.
48+
pub(crate) fn waker<W: Wake>(wake: Arc<W>) -> Waker {
49+
unsafe {
50+
Waker::from_raw(RawWaker::new(
51+
Arc::into_raw(wake).cast(),
52+
waker_vtable::<W>(),
53+
))
4154
}
4255
}
4356

0 commit comments

Comments
 (0)