From 6cce703064e027474acc1be723f523c5890a198d Mon Sep 17 00:00:00 2001 From: black-binary Date: Thu, 17 Dec 2020 20:45:06 +0000 Subject: [PATCH] Add poll_send method for Sender --- .github/workflows/build-and-test.yaml | 4 +- .github/workflows/lint.yaml | 2 +- .github/workflows/security.yaml | 2 +- Cargo.toml | 6 +- src/lib.rs | 104 +++++++++++++++++++++++++ tests/bounded.rs | 105 ++++++++++++++++++++++++++ tests/unbounded.rs | 105 ++++++++++++++++++++++++++ 7 files changed, 321 insertions(+), 7 deletions(-) diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml index 19d1ebb..a6cac3c 100644 --- a/.github/workflows/build-and-test.yaml +++ b/.github/workflows/build-and-test.yaml @@ -19,11 +19,11 @@ jobs: - name: Set current week of the year in environnement if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macOS') - run: echo "::set-env name=CURRENT_WEEK::$(date +%V)" + run: echo "CURRENT_WEEK=$(date +%V)" >> $GITHUB_ENV - name: Set current week of the year in environnement if: startsWith(matrix.os, 'windows') - run: echo "::set-env name=CURRENT_WEEK::$(Get-Date -UFormat %V)" + run: echo "::set-env name=CURRENT_WEEK::$(date +%V)" - name: Install latest ${{ matrix.rust }} uses: actions-rs/toolchain@v1 diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 7e9bd98..928c4aa 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v2 - name: Set current week of the year in environnement - run: echo "::set-env name=CURRENT_WEEK::$(date +%V)" + run: echo "CURRENT_WEEK=$(date +%V)" >> $GITHUB_ENV - uses: actions-rs/toolchain@v1 with: diff --git a/.github/workflows/security.yaml b/.github/workflows/security.yaml index 8f722e7..672e506 100644 --- a/.github/workflows/security.yaml +++ b/.github/workflows/security.yaml @@ -13,7 +13,7 @@ jobs: - uses: actions/checkout@v2 - name: Set current week of the year in environnement - run: echo "::set-env name=CURRENT_WEEK::$(date +%V)" + run: echo "CURRENT_WEEK=$(date +%V)" >> $GITHUB_ENV - uses: actions-rs/audit-check@v1 with: diff --git a/Cargo.toml b/Cargo.toml index a71fa02..e4dd41c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,10 @@ readme = "README.md" [dependencies] concurrent-queue = "1.2.2" -event-listener = "2.4.0" -futures-core = "0.3.5" +event-listener = "2.5.1" +futures-core = "0.3.8" [dev-dependencies] blocking = "0.6.0" easy-parallel = "3.1.0" -futures-lite = "1.11.0" +futures-lite = "1.11.2" diff --git a/src/lib.rs b/src/lib.rs index af36dec..160a339 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,9 @@ struct Channel { /// Stream operations while the channel is empty and not closed. stream_ops: Event, + /// Sink operations while the channel is empty and not closed. + sink_ops: Event, + /// The number of currently active `Sender`s. sender_count: AtomicUsize, @@ -112,12 +115,15 @@ pub fn bounded(cap: usize) -> (Sender, Receiver) { send_ops: Event::new(), recv_ops: Event::new(), stream_ops: Event::new(), + sink_ops: Event::new(), sender_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1), }); let s = Sender { channel: channel.clone(), + listener: None, + sending_msg: None, }; let r = Receiver { channel, @@ -151,12 +157,15 @@ pub fn unbounded() -> (Sender, Receiver) { send_ops: Event::new(), recv_ops: Event::new(), stream_ops: Event::new(), + sink_ops: Event::new(), sender_count: AtomicUsize::new(1), receiver_count: AtomicUsize::new(1), }); let s = Sender { channel: channel.clone(), + listener: None, + sending_msg: None, }; let r = Receiver { channel, @@ -174,6 +183,11 @@ pub fn unbounded() -> (Sender, Receiver) { pub struct Sender { /// Inner channel state. channel: Arc>, + + /// Listens for a recv or close event to unblock this stream. + listener: Option, + + sending_msg: Option, } impl Sender { @@ -421,6 +435,91 @@ impl Sender { pub fn sender_count(&self) -> usize { self.channel.sender_count.load(Ordering::SeqCst) } + + /// Attempts to send a message into the channel. + /// This method takes the message inside the `message` argument and buffer it if the channel is full. + /// This method returns `Pending` if the channel is full and `Ready(SendError)` if it is closed. + /// # Panics + /// Panics if call this method with `None` message in the first call. + /// # Examples + /// + /// ``` + /// use async_channel::{bounded, SendError}; + /// use futures_lite::future; + /// use std::task::Poll; + + /// future::block_on(async { + /// future::poll_fn(|cx| -> Poll<()> { + /// let (mut s, r) = bounded::(1); + /// assert_eq!(s.poll_send(cx, &mut Some(1)), Poll::Ready(Ok(()))); + /// assert_eq!(s.poll_send(cx, &mut Some(2)), Poll::Pending); + /// drop(r); + /// assert_eq!( + /// s.poll_send(cx, &mut Some(3)), + /// Poll::Ready(Err(SendError(3))) + /// ); + /// Poll::Ready(()) + /// }) + /// .await; + /// }); + /// ``` + pub fn poll_send( + &mut self, + cx: &mut Context<'_>, + msg: &mut Option, + ) -> Poll>> { + // take() the message when calling this function for the first time. + + if let Some(msg) = msg.take() { + self.sending_msg = Some(msg); + } + + loop { + // If this sink is listening for events, first wait for a notification. + if let Some(listener) = &mut self.listener { + futures_core::ready!(Pin::new(listener).poll(cx)); + self.listener = None; + } + + loop { + let message = self.sending_msg.take().unwrap(); + // Attempt to send the item immediately + match self.try_send(message) { + Ok(_) => { + // Great! The item has been sent sucessfully. + // The stream is not blocked on an event - drop the listener. + self.listener = None; + return Poll::Ready(Ok(())); + } + Err(e) => match e { + TrySendError::Full(item) => { + // The channel is full now. + // Store the item back to the struct for the next loop or polling. + self.sending_msg = Some(item); + } + TrySendError::Closed(item) => { + // The channel is closed. + // The stream is not blocked on an event - drop the listener. + self.listener = None; + return Poll::Ready(Err(SendError(item))); + } + }, + } + + // Receiving failed - now start listening for notifications or wait for one. + match &mut self.listener { + Some(_) => { + // Create a listener and try sending the message again. + break; + } + None => { + // Go back to the outer loop to poll the listener. + self.listener = Some(self.channel.sink_ops.listen()); + } + } + } + } + } } impl Drop for Sender { @@ -449,6 +548,8 @@ impl Clone for Sender { Sender { channel: self.channel.clone(), + listener: None, + sending_msg: None, } } } @@ -497,6 +598,8 @@ impl Receiver { // message or gets canceled, it will notify another blocked send operation. self.channel.send_ops.notify(1); + self.channel.sink_ops.notify(usize::MAX); + Ok(msg) } Err(PopError::Empty) => Err(TryRecvError::Empty), @@ -725,6 +828,7 @@ impl Drop for Receiver { if self.channel.receiver_count.fetch_sub(1, Ordering::AcqRel) == 1 { self.channel.close(); } + self.channel.sink_ops.notify(usize::MAX); } } diff --git a/tests/bounded.rs b/tests/bounded.rs index dfc96b3..ce50155 100644 --- a/tests/bounded.rs +++ b/tests/bounded.rs @@ -383,3 +383,108 @@ fn mpmc_stream() { assert_eq!(c.load(Ordering::SeqCst), THREADS); } } + +#[test] +fn poll_send() { + let (mut s, r) = bounded::(1); + + Parallel::new() + .add(|| { + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(7u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(8u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(9u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(10u32))) + .await + .unwrap(); + }); + }) + .add(|| { + sleep(ms(1500)); + assert_eq!(future::block_on(r.recv()), Ok(7)); + assert_eq!(future::block_on(r.recv()), Ok(8)); + assert_eq!(future::block_on(r.recv()), Ok(9)); + }) + .run(); +} + +#[test] +fn spsc_poll_send() { + const COUNT: usize = 25_000; + + let (s, r) = bounded::(3); + + Parallel::new() + .add({ + let mut r = r.clone(); + move || { + for _ in 0..COUNT { + future::block_on(r.next()).unwrap(); + } + } + }) + .add(|| { + let s = s.clone(); + for i in 0..COUNT { + let mut s = s.clone(); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(i))) + .await + .unwrap(); + }); + } + }) + .run(); +} + +#[test] +fn mpmc_poll_send() { + const COUNT: usize = 25_000; + const THREADS: usize = 4; + + let (s, r) = bounded::(3); + let v = (0..COUNT).map(|_| AtomicUsize::new(0)).collect::>(); + let v = &v; + + Parallel::new() + .each(0..THREADS, { + let mut r = r.clone(); + move |_| { + for _ in 0..COUNT { + let n = future::block_on(r.next()).unwrap(); + v[n].fetch_add(1, Ordering::SeqCst); + } + } + }) + .each(0..THREADS, |_| { + let s = s.clone(); + for i in 0..COUNT { + let mut s = s.clone(); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(i))) + .await + .unwrap(); + }); + } + }) + .run(); + + for c in v { + assert_eq!(c.load(Ordering::SeqCst), THREADS); + } +} diff --git a/tests/unbounded.rs b/tests/unbounded.rs index 50ed50b..758fb31 100644 --- a/tests/unbounded.rs +++ b/tests/unbounded.rs @@ -298,3 +298,108 @@ fn mpmc_stream() { assert_eq!(c.load(Ordering::SeqCst), THREADS); } } + +#[test] +fn poll_send() { + let (mut s, r) = unbounded(); + + Parallel::new() + .add(|| { + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(7u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(8u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(9u32))) + .await + .unwrap(); + }); + sleep(ms(1000)); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(10u32))) + .await + .unwrap(); + }); + }) + .add(|| { + sleep(ms(1500)); + assert_eq!(future::block_on(r.recv()), Ok(7)); + assert_eq!(future::block_on(r.recv()), Ok(8)); + assert_eq!(future::block_on(r.recv()), Ok(9)); + }) + .run(); +} + +#[test] +fn spsc_poll_send() { + const COUNT: usize = 25_000; + + let (s, r) = unbounded(); + + Parallel::new() + .add({ + let mut r = r.clone(); + move || { + for _ in 0..COUNT { + future::block_on(r.next()).unwrap(); + } + } + }) + .add(|| { + let s = s.clone(); + for i in 0..COUNT { + let mut s = s.clone(); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(i))) + .await + .unwrap(); + }); + } + }) + .run(); +} + +#[test] +fn mpmc_poll_send() { + const COUNT: usize = 25_000; + const THREADS: usize = 4; + + let (s, r) = unbounded::(); + let v = (0..COUNT).map(|_| AtomicUsize::new(0)).collect::>(); + let v = &v; + + Parallel::new() + .each(0..THREADS, { + let mut r = r.clone(); + move |_| { + for _ in 0..COUNT { + let n = future::block_on(r.next()).unwrap(); + v[n].fetch_add(1, Ordering::SeqCst); + } + } + }) + .each(0..THREADS, |_| { + let s = s.clone(); + for i in 0..COUNT { + let mut s = s.clone(); + future::block_on(async { + future::poll_fn(|cx| s.poll_send(cx, &mut Some(i))) + .await + .unwrap(); + }); + } + }) + .run(); + + for c in v { + assert_eq!(c.load(Ordering::SeqCst), THREADS); + } +}