diff --git a/Cargo.toml b/Cargo.toml index 39a2d77..4970af4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tokio_kcp" -version = "0.8.2" +version = "0.9.0" authors = ["Matrix <113445886@qq.com>", "Y. T. Chung "] description = "A kcp implementation for tokio" license = "MIT" @@ -13,7 +13,7 @@ edition = "2018" [dependencies] bytes = "1.1" futures = "0.3" -kcp = "0.4.16" +kcp = "0.5.0" log = "0.4" tokio = { version = "1.11", features = ["net", "sync", "rt", "macros", "time"] } byte_string = "1" diff --git a/src/session.rs b/src/session.rs index a3ddc46..8a3ebb3 100644 --- a/src/session.rs +++ b/src/session.rs @@ -92,7 +92,8 @@ impl KcpSession { Ok(n) => { let input_buffer = &input_buffer[..n]; let input_conv = kcp::get_conv(input_buffer); - trace!("[SESSION] UDP recv {} bytes, conv: {}, going to input {:?}", n, input_conv, ByteStr::new(input_buffer)); + trace!("[SESSION] UDP recv {} bytes, conv: {}, going to input {:?}", + n, input_conv, ByteStr::new(input_buffer)); let mut socket = session.socket.lock(); @@ -108,7 +109,8 @@ impl KcpSession { } Ok(false) => {} Err(err) => { - error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", n, err, ByteStr::new(input_buffer)); + error!("[SESSION] UDP input {} bytes error: {}, input buffer {:?}", + n, err, ByteStr::new(input_buffer)); } } } @@ -120,8 +122,11 @@ impl KcpSession { if let Some(input_buffer) = input_opt { let mut socket = session.socket.lock(); match socket.input(&input_buffer) { - Ok(..) => { - trace!("[SESSION] UDP input {} bytes from channel {:?}", input_buffer.len(), ByteStr::new(&input_buffer)); + Ok(waked) => { + // trace!("[SESSION] UDP input {} bytes from channel {:?}", + // input_buffer.len(), ByteStr::new(&input_buffer)); + trace!("[SESSION] UDP input {} bytes from channel, waked? {} sender/receiver", + input_buffer.len(), waked); } Err(err) => { error!("[SESSION] UDP input {} bytes from channel failed, error: {}, input buffer {:?}", @@ -177,6 +182,11 @@ impl KcpSession { } } + // If window is full, flush it immediately + if socket.need_flush() { + let _ = socket.flush(); + } + match socket.update() { Ok(next_next) => Instant::from_std(next_next), Err(err) => { diff --git a/src/skcp.rs b/src/skcp.rs index 46f7546..e0d4c94 100644 --- a/src/skcp.rs +++ b/src/skcp.rs @@ -203,16 +203,27 @@ impl KcpSocket { } match self.kcp.recv(buf) { - Ok(0) | Err(KcpError::RecvQueueEmpty) | Err(KcpError::ExpectingFragment) => { + e @ (Ok(0) | Err(KcpError::RecvQueueEmpty) | Err(KcpError::ExpectingFragment)) => { + trace!( + "[RECV] rcvwnd={} peeksize={} r={:?}", + self.kcp.rcv_wnd(), + self.kcp.peeksize().unwrap_or(0), + e + ); + if let Some(waker) = self.pending_receiver.replace(cx.waker().clone()) { if !cx.waker().will_wake(&waker) { waker.wake(); } } + Poll::Pending } Err(err) => Err(err).into(), - Ok(n) => Ok(n).into(), + Ok(n) => { + self.last_update = Instant::now(); + Ok(n).into() + } } } @@ -302,6 +313,11 @@ impl KcpSocket { pub fn last_update_time(&self) -> Instant { self.last_update } + + pub fn need_flush(&self) -> bool { + (self.kcp.wait_snd() >= self.kcp.snd_wnd() as usize || self.kcp.wait_snd() >= self.kcp.rmt_wnd() as usize) + && !self.kcp.waiting_conv() + } } #[cfg(test)] diff --git a/src/stream.rs b/src/stream.rs index 8a2d9e2..2eb406c 100644 --- a/src/stream.rs +++ b/src/stream.rs @@ -90,7 +90,6 @@ impl KcpStream { match ready!(kcp.poll_recv(cx, buf)) { Ok(n) => { trace!("[CLIENT] recv directly {} bytes", n); - self.session.notify(); return Ok(n).into(); } Err(KcpError::UserBufTooSmall) => {} @@ -108,7 +107,6 @@ impl KcpStream { Ok(0) => return Ok(0).into(), Ok(n) => { trace!("[CLIENT] recv buffered {} bytes", n); - self.session.notify(); self.recv_buffer_pos = 0; self.recv_buffer_cap = n; } diff --git a/src/utils.rs b/src/utils.rs index 5ab83a9..93cd5ec 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -4,5 +4,6 @@ use std::time::{SystemTime, UNIX_EPOCH}; pub fn now_millis() -> u32 { let start = SystemTime::now(); let since_the_epoch = start.duration_since(UNIX_EPOCH).expect("time went afterwards"); - (since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64 / 1_000_000) as u32 + // (since_the_epoch.as_secs() * 1000 + since_the_epoch.subsec_millis() as u64) as u32 + since_the_epoch.as_millis() as u32 }