diff --git a/maitake/src/sync/wait_cell.rs b/maitake/src/sync/wait_cell.rs index 1ff6674d..63234a3c 100644 --- a/maitake/src/sync/wait_cell.rs +++ b/maitake/src/sync/wait_cell.rs @@ -73,9 +73,6 @@ pub enum RegisterError { pub struct Wait<'a> { /// The [`WaitCell`] being waited on. cell: &'a WaitCell, - - /// Whether we have already polled once - registered: bool, } #[derive(Eq, PartialEq, Copy, Clone)] @@ -125,16 +122,12 @@ impl WaitCell { Err(actual) if test_dbg!(actual.is(State::CLOSED)) => { return Err(RegisterError::Closed); } - Err(actual) if test_dbg!(actual.is(State::WAKING)) => { + Err(actual) + if test_dbg!(actual.is(State::WAKING)) || test_dbg!(actual.is(State::WOKEN)) => + { return Err(RegisterError::Waking); } - - Err(actual) => { - debug_assert!( - actual == State::REGISTERING || actual == State::REGISTERING | State::WAKING - ); - return Err(RegisterError::Registering); - } + Err(_) => return Err(RegisterError::Registering), Ok(_) => {} } @@ -192,10 +185,7 @@ impl WaitCell { /// **Note**: The calling task's [`Waker`] is not registered until AFTER the /// first time the returned [`Wait`] future is polled. pub fn wait(&self) -> Wait<'_> { - Wait { - cell: self, - registered: false, - } + Wait { cell: self } } /// Wake the [`Waker`] stored in this cell. @@ -242,7 +232,7 @@ impl WaitCell { // TODO(eliza): could probably be made a public API... pub(crate) fn take_waker(&self, close: bool) -> Option { trace!(wait_cell = ?fmt::ptr(self), ?close, "notifying"); - let mut bits = State::WAKING; + let mut bits = State::WAKING | State::WOKEN; if close { bits.0 |= State::CLOSED.0; } @@ -314,17 +304,15 @@ impl Drop for WaitCell { impl Future for Wait<'_> { type Output = Result<(), super::Closed>; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.registered { - // We made it to "once", and got polled again, we must be ready! + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // Try to take the cell's `WOKEN` bit to see if we were previously + // waiting and then received a notification. + if test_dbg!(self.cell.fetch_and(!State::WOKEN, AcqRel)).is(State::WOKEN) { return Poll::Ready(Ok(())); } match test_dbg!(self.cell.register_wait(cx.waker())) { - Ok(_) => { - self.registered = true; - Poll::Pending - } + Ok(_) => Poll::Pending, Err(RegisterError::Registering) => { // Cell was busy parking some other task, all we can do is try again later cx.waker().wake_by_ref(); @@ -347,6 +335,7 @@ impl State { const REGISTERING: Self = Self(0b01); const WAKING: Self = Self(0b10); const CLOSED: Self = Self(0b100); + const WOKEN: Self = Self(0b1000); fn is(self, Self(state): Self) -> bool { self.0 & state == state @@ -373,7 +362,7 @@ impl fmt::Debug for State { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut has_states = false; - fmt_bits!(self, f, has_states, REGISTERING, WAKING, CLOSED); + fmt_bits!(self, f, has_states, REGISTERING, WAKING, CLOSED, WOKEN); if !has_states { if *self == Self::WAITING { @@ -395,9 +384,12 @@ mod tests { use crate::scheduler::Scheduler; use alloc::sync::Arc; + use tokio_test::{assert_pending, assert_ready_ok, task}; + #[test] fn wait_smoke() { static COMPLETED: AtomicUsize = AtomicUsize::new(0); + let _trace = crate::util::test::trace_init(); let sched = Scheduler::new(); let wait = Arc::new(WaitCell::new()); @@ -417,6 +409,25 @@ mod tests { assert_eq!(tick.completed, 1); assert_eq!(COMPLETED.load(Ordering::Relaxed), 1); } + + /// Reproduces https://github.com/hawkw/mycelium/issues/449 + #[test] + fn wait_spurious_poll() { + let _trace = crate::util::test::trace_init(); + + let cell = Arc::new(WaitCell::new()); + let mut task = task::spawn({ + let cell = cell.clone(); + async move { cell.wait().await } + }); + + assert_pending!(task.poll(), "first poll should be pending"); + assert_pending!(task.poll(), "second poll should be pending"); + + cell.wake(); + + assert_ready_ok!(task.poll(), "should have been woken"); + } } #[cfg(test)]