diff --git a/CMakeLists.txt b/CMakeLists.txt index 7900e6d..993ef1a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_TOOLCHAIN_FILE ${CMAKE_SOURCE_DIR}/toolchain.cmake) # define project -project(socket_manager LANGUAGES C CXX VERSION 0.3.1) +project(socket_manager LANGUAGES C CXX VERSION 0.3.2) # set default build type as shared option(BUILD_SHARED_LIBS "Build using shared libraries" ON) diff --git a/Cargo.toml b/Cargo.toml index a9c77b5..e6f2691 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio-socket-manager" -version = "0.3.1" +version = "0.3.2" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -8,7 +8,7 @@ edition = "2021" crate-type = ["staticlib"] [dependencies] -async-ringbuf = "0.1.3" +async-ringbuf = "0.2.0-rc.1" dashmap = { version = "5.4.0", features = ["inline"] } libc = "0.2.146" socket2 = "0.5.3" diff --git a/src/msg_sender.rs b/src/msg_sender.rs index da0e63b..41a03f1 100644 --- a/src/msg_sender.rs +++ b/src/msg_sender.rs @@ -1,8 +1,9 @@ -use async_ringbuf::ring_buffer::AsyncRbWrite; -use async_ringbuf::{AsyncHeapConsumer, AsyncHeapProducer, AsyncHeapRb}; -use std::future::poll_fn; -use std::task::Poll::{self, Pending, Ready}; -use std::task::Waker; +use async_ringbuf::halves::{AsyncCons, AsyncProd}; +use async_ringbuf::traits::{AsyncObserver, AsyncProducer, Producer, Split}; +use async_ringbuf::AsyncHeapRb; +use std::sync::Arc; +use std::task::Poll::{Pending, Ready}; +use std::task::{Poll, Waker}; use tokio::runtime::Handle; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; @@ -14,6 +15,9 @@ pub(crate) enum SendCommand { Flush, } +pub type AsyncHeapProducer = AsyncProd>>; +pub type AsyncHeapConsumer = AsyncCons>>; + pub(crate) fn make_sender(handle: Handle) -> (MsgSender, MsgRcv) { let (cmd, cmd_recv) = unbounded_channel::(); let (rings_prd, rings) = unbounded_channel::>(); @@ -55,7 +59,7 @@ fn burst_write( bytes: &[u8], ) -> BurstWriteState { loop { - let n = buf.as_mut_base().push_slice(&bytes[*offset..]); + let n = buf.push_slice(&bytes[*offset..]); if n == 0 { // no bytes read, return break BurstWriteState::Pending; @@ -84,26 +88,19 @@ impl MsgSender { // unfinished, enter into future self.handle.clone().block_on(async { loop { + self.ring_buf.wait_vacant(1).await; + // check if closed + if self.ring_buf.is_closed() { + break Err(std::io::Error::new( + std::io::ErrorKind::Other, + "connection closed", + )); + } if let BurstWriteState::Finished = burst_write(&mut offset, &mut self.ring_buf, bytes) { return Ok(()); } - poll_fn(|cx| { - unsafe { self.ring_buf.as_base().rb().register_head_waker(cx.waker()) }; - if self.ring_buf.is_closed() { - Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "connection closed", - ))) - } else if self.ring_buf.is_full() { - Pending::> - } else { - // continue to loop until pending - Ready(Ok(())) - } - }) - .await?; } }) } @@ -125,7 +122,7 @@ impl MsgSender { // allocate new ring buffer if unable to write the entire message. let new_buf_size = RING_BUFFER_SIZE.max(bytes.len() - offset); let (mut ring_buf, ring) = AsyncHeapRb::::new(new_buf_size).split(); - ring_buf.as_mut_base().push_slice(&bytes[offset..]); + ring_buf.push_slice(&bytes[offset..]); self.rings_prd.send(ring).map_err(|e| { std::io::Error::new( std::io::ErrorKind::WriteZero, @@ -145,23 +142,28 @@ impl MsgSender { return Ready(Ok(0)); } let mut offset = 0usize; + let mut waker_registered = false; loop { + // check if closed + if self.ring_buf.is_closed() { + break Ready(Err(std::io::Error::new( + std::io::ErrorKind::Other, + "connection closed", + ))); + } // attempt to write as much as possible burst_write(&mut offset, &mut self.ring_buf, bytes); if offset > 0 { break Ready(Ok(offset)); } // offset = 0, prepare to wait - unsafe { self.ring_buf.as_base().rb().register_head_waker(&waker) }; - // check the pending state ensues. - if self.ring_buf.is_closed() { - break Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "connection closed", - ))); - } else if self.ring_buf.is_full() { + if waker_registered { break Pending; } + // register waker + self.ring_buf.register_read_waker(&waker); + waker_registered = true; + // try again to ensure no missing wake } } diff --git a/src/write.rs b/src/write.rs index c9d5cb1..af7a4d7 100644 --- a/src/write.rs +++ b/src/write.rs @@ -1,7 +1,8 @@ use crate::conn::ConnConfig; use crate::msg_sender::MsgRcv; use crate::read::MIN_MSG_BUFFER_SIZE; -use async_ringbuf::AsyncHeapConsumer; +use crate::AsyncHeapConsumer; +use async_ringbuf::traits::{AsyncConsumer, AsyncObserver, Consumer, Observer}; use std::time::Duration; use tokio::io::AsyncWriteExt; use tokio::net::tcp::OwnedWriteHalf; @@ -50,12 +51,12 @@ async fn handle_writer_auto_flush( biased; // !has_data => wait for has_data // has_data => wait for write_threshold - _ = ring.wait(if !has_data {1} else {MIN_MSG_BUFFER_SIZE}) => { + _ = ring.wait_occupied(if !has_data {1} else {MIN_MSG_BUFFER_SIZE}) => { if ring.is_closed() { break 'ring; } has_data = true; - if ring.len() >= MIN_MSG_BUFFER_SIZE { + if ring.occupied_len() >= MIN_MSG_BUFFER_SIZE { flush(&mut ring, &mut write).await?; has_data = false } @@ -105,7 +106,7 @@ async fn handle_writer_no_auto_flush( tokio::select! { biased; // buf threshold - _ = ring.wait(MIN_MSG_BUFFER_SIZE) => { + _ = ring.wait_occupied(MIN_MSG_BUFFER_SIZE) => { if ring.is_closed() { break 'ring; } @@ -137,10 +138,10 @@ async fn flush( write: &mut OwnedWriteHalf, ) -> std::io::Result<()> { loop { - let (left, _) = ring_buf.as_mut_base().as_slices(); + let (left, _) = ring_buf.as_slices(); if !left.is_empty() { let count = write.write(left).await?; - unsafe { ring_buf.as_mut_base().advance(count) }; + unsafe { ring_buf.advance_read_index(count) }; } else { // both empty, break return Ok(());