diff --git a/async/src/halves.rs b/async/src/halves.rs index a1a5090..13f820b 100644 --- a/async/src/halves.rs +++ b/async/src/halves.rs @@ -149,7 +149,8 @@ where self.base.rb().is_closed() } fn close(&self) { - self.base.rb().close() + self.base.rb().close(); + self.base.rb().wake_consumer(); } } impl AsyncProducer for AsyncProd @@ -181,7 +182,8 @@ where self.base.rb().is_closed() } fn close(&self) { - self.base.rb().close() + self.base.rb().close(); + self.base.rb().wake_producer(); } } impl AsyncConsumer for AsyncCons diff --git a/async/src/rb.rs b/async/src/rb.rs index 06f8eed..79c102b 100644 --- a/async/src/rb.rs +++ b/async/src/rb.rs @@ -81,7 +81,14 @@ impl AsyncConsumer for AsyncRb { self.write.register(waker); } } -impl AsyncRingBuffer for AsyncRb {} +impl AsyncRingBuffer for AsyncRb { + fn wake_consumer(&self) { + self.write.wake() + } + fn wake_producer(&self) { + self.read.wake() + } +} impl SplitRef for AsyncRb { type RefProd<'a> = AsyncProd<&'a Self> where Self: 'a; diff --git a/async/src/tests.rs b/async/src/tests.rs index 3bc2159..f96e0d4 100644 --- a/async/src/tests.rs +++ b/async/src/tests.rs @@ -1,6 +1,7 @@ use crate::{async_transfer, traits::*, AsyncHeapRb}; use core::sync::atomic::{AtomicUsize, Ordering}; use futures::task::{noop_waker_ref, AtomicWaker}; +use std::sync::Arc; use std::{vec, vec::Vec}; #[test] @@ -168,3 +169,52 @@ fn wait() { }, ); } + +#[test] +fn drop_close_prod() { + let (prod, mut cons) = AsyncHeapRb::::new(1).split(); + let stage = Arc::new(AtomicUsize::new(0)); + let stage_clone = stage.clone(); + let t0 = std::thread::spawn(move || { + execute!(async { + drop(prod); + assert_eq!(stage.fetch_add(1, Ordering::SeqCst), 0); + }); + }); + let t1 = std::thread::spawn(move || { + execute!(async { + cons.wait_occupied(1).await; + assert_eq!(stage_clone.fetch_add(1, Ordering::SeqCst), 1); + assert!(cons.is_closed()); + }); + }); + t0.join().unwrap(); + t1.join().unwrap(); +} + +#[test] +fn drop_close_cons() { + let (mut prod, mut cons) = AsyncHeapRb::::new(1).split(); + let stage = Arc::new(AtomicUsize::new(0)); + let stage_clone = stage.clone(); + let t0 = std::thread::spawn(move || { + execute!(async { + prod.push(0).await.unwrap(); + assert_eq!(stage.fetch_add(1, Ordering::SeqCst), 0); + + prod.wait_vacant(1).await; + assert_eq!(stage.fetch_add(1, Ordering::SeqCst), 3); + assert!(prod.is_closed()); + }); + }); + let t1 = std::thread::spawn(move || { + execute!(async { + cons.wait_occupied(1).await; + assert_eq!(stage_clone.fetch_add(1, Ordering::SeqCst), 1); + drop(cons); + assert_eq!(stage_clone.fetch_add(1, Ordering::SeqCst), 2); + }); + }); + t0.join().unwrap(); + t1.join().unwrap(); +} diff --git a/async/src/traits/mod.rs b/async/src/traits/mod.rs index a3c48a6..aab30cd 100644 --- a/async/src/traits/mod.rs +++ b/async/src/traits/mod.rs @@ -1,11 +1,11 @@ pub mod consumer; pub mod observer; pub mod producer; +pub mod ring_buffer; pub use consumer::AsyncConsumer; pub use observer::AsyncObserver; pub use producer::AsyncProducer; - -pub trait AsyncRingBuffer: ringbuf::traits::RingBuffer + AsyncProducer + AsyncConsumer {} +pub use ring_buffer::AsyncRingBuffer; pub use ringbuf::traits::*; diff --git a/async/src/traits/ring_buffer.rs b/async/src/traits/ring_buffer.rs index e69de29..8718a4d 100644 --- a/async/src/traits/ring_buffer.rs +++ b/async/src/traits/ring_buffer.rs @@ -0,0 +1,8 @@ +use crate::consumer::AsyncConsumer; +use crate::producer::AsyncProducer; +use ringbuf::traits::RingBuffer; + +pub trait AsyncRingBuffer: RingBuffer + AsyncProducer + AsyncConsumer { + fn wake_producer(&self); + fn wake_consumer(&self); +} diff --git a/src/transfer.rs b/src/transfer.rs index d284e3d..19f4da4 100644 --- a/src/transfer.rs +++ b/src/transfer.rs @@ -6,11 +6,7 @@ use crate::{consumer::Consumer, producer::Producer}; /// `count` is the number of items being moved, if `None` - as much as possible items will be moved. /// /// Returns number of items been moved. -pub fn transfer, P: Producer>( - src: &mut C, - dst: &mut P, - count: Option, -) -> usize { +pub fn transfer, P: Producer>(src: &mut C, dst: &mut P, count: Option) -> usize { let (src_left, src_right) = src.occupied_slices(); let (dst_left, dst_right) = dst.vacant_slices_mut(); let src_iter = src_left.iter().chain(src_right.iter());