From f4dafcf40558d7e61b045e62c4156e1ebb96767d Mon Sep 17 00:00:00 2001 From: zonyitoo Date: Wed, 10 Jul 2024 22:46:58 +0800 Subject: [PATCH] feat: allow recv empty data segment (#37) --- src/config.rs | 3 +++ src/skcp.rs | 36 +++++++++++++++++++++------------- src/stream.rs | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 13 deletions(-) diff --git a/src/config.rs b/src/config.rs index 7849132..f526996 100644 --- a/src/config.rs +++ b/src/config.rs @@ -75,6 +75,8 @@ pub struct KcpConfig { pub flush_acks_input: bool, /// Stream mode pub stream: bool, + /// Allow recv 0 byte packet. KCP Segments with 0 byte data are skipped by default. + pub allow_recv_empty_packet: bool, } impl Default for KcpConfig { @@ -87,6 +89,7 @@ impl Default for KcpConfig { flush_write: false, flush_acks_input: false, stream: false, + allow_recv_empty_packet: false, } } } diff --git a/src/skcp.rs b/src/skcp.rs index 814ecaf..7306bc2 100644 --- a/src/skcp.rs +++ b/src/skcp.rs @@ -77,6 +77,7 @@ pub struct KcpSocket { pending_sender: Option, pending_receiver: Option, closed: bool, + allow_recv_empty_packet: bool, } impl KcpSocket { @@ -112,6 +113,7 @@ impl KcpSocket { pending_sender: None, pending_receiver: None, closed: false, + allow_recv_empty_packet: c.allow_recv_empty_packet, }) } @@ -204,28 +206,36 @@ impl KcpSocket { } match self.kcp.recv(buf) { - e @ (Ok(0) | Err(KcpError::RecvQueueEmpty) | Err(KcpError::ExpectingFragment)) => { + e @ (Err(KcpError::RecvQueueEmpty) | Err(KcpError::ExpectingFragment)) => { trace!( "[RECV] rcvwnd={} peeksize={} r={:?}", self.kcp.rcv_wnd(), self.kcp.peeksize().unwrap_or(0), e ); - - if let Some(waker) = self.pending_receiver.replace(cx.waker().clone()) { - if !cx.waker().will_wake(&waker) { - waker.wake(); - } - } - - Poll::Pending } - Err(err) => Err(err).into(), + Err(err) => return Err(err).into(), Ok(n) => { - self.last_update = Instant::now(); - Ok(n).into() + if n == 0 && !self.allow_recv_empty_packet { + trace!( + "[RECV] rcvwnd={} peeksize={} r=Ok(0)", + self.kcp.rcv_wnd(), + self.kcp.peeksize().unwrap_or(0), + ); + } else { + self.last_update = Instant::now(); + return Ok(n).into(); + } } } + + if let Some(waker) = self.pending_receiver.replace(cx.waker().clone()) { + if !cx.waker().will_wake(&waker) { + waker.wake(); + } + } + + Poll::Pending } #[allow(dead_code)] @@ -255,7 +265,7 @@ impl KcpSocket { if self.pending_receiver.is_some() { if let Ok(peek) = self.kcp.peeksize() { - if peek > 0 { + if self.allow_recv_empty_packet || peek > 0 { let waker = self.pending_receiver.take().unwrap(); waker.wake(); diff --git a/src/stream.rs b/src/stream.rs index cb9cd7b..aba911f 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -203,3 +203,57 @@ impl std::os::windows::io::AsRawSocket for KcpStream { kcp_socket.udp_socket().as_raw_socket() } } + +#[cfg(test)] +mod test { + use crate::KcpListener; + + use super::*; + + #[tokio::test] + async fn test_stream_echo() { + let _ = env_logger::try_init(); + + let config = KcpConfig::default(); + let server_addr = "127.0.0.1:5555".parse::().unwrap(); + + let mut listener = KcpListener::bind(config.clone(), server_addr).await.unwrap(); + let listener_hdl = tokio::spawn(async move { + loop { + let (mut stream, peer_addr) = listener.accept().await.unwrap(); + println!("accepted {}", peer_addr); + + tokio::spawn(async move { + let mut buffer = [0u8; 8192]; + loop { + match stream.recv(&mut buffer).await { + Ok(n) => { + println!("server recv: {:?}", &buffer[..n]); + let send_n = stream.send(&buffer[..n]).await.unwrap(); + println!("server sent: {}", send_n); + } + Err(err) => { + println!("recv error: {}", err); + break; + } + } + } + }); + } + }); + + let mut stream = KcpStream::connect(&config, server_addr).await.unwrap(); + + let test_payload = b"HELLO WORLD"; + stream.send(test_payload).await.unwrap(); + println!("client sent: {:?}", test_payload); + + let mut recv_buffer = [0u8; 1024]; + let recv_n = stream.recv(&mut recv_buffer).await.unwrap(); + println!("client recv: {:?}", &recv_buffer[..recv_n]); + assert_eq!(recv_n, test_payload.len()); + assert_eq!(&recv_buffer[..recv_n], test_payload); + + listener_hdl.abort(); + } +}