From 52e5eaaec63864ff342a90ff8924b0f79c1d5416 Mon Sep 17 00:00:00 2001 From: Joe Grund Date: Tue, 12 Sep 2023 13:13:17 -0400 Subject: [PATCH] Use `ChannelMsg::WindowAdjusted` during data transfer OpenSSH server sends `CHANNEL_WINDOW_ADJUST` messages before window_size is 0. Handle these message at each turn of the loop within `Channel.send_data` Signed-off-by: Joe Grund --- russh/src/channels.rs | 40 +++++++++++++++++++++++++---------- russh/src/client/encrypted.rs | 15 ++++++++----- russh/src/server/encrypted.rs | 4 ++-- 3 files changed, 41 insertions(+), 18 deletions(-) diff --git a/russh/src/channels.rs b/russh/src/channels.rs index e5da311e..33dc3e19 100644 --- a/russh/src/channels.rs +++ b/russh/src/channels.rs @@ -1,6 +1,6 @@ -use russh_cryptovec::CryptoVec; -use tokio::sync::mpsc::{Sender, UnboundedReceiver}; use log::debug; +use russh_cryptovec::CryptoVec; +use tokio::sync::mpsc::{error::TryRecvError, Sender, UnboundedReceiver}; use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig}; @@ -290,23 +290,38 @@ impl + Send + 'static> Channel { while self.window_size == 0 { match self.receiver.recv().await { Some(ChannelMsg::WindowAdjusted { new_size }) => { - debug!("window adjusted: {:?}", new_size); + debug!("channel {} => window adjusted: {new_size}", self.id); self.window_size = new_size; break; } Some(msg) => { - debug!("unexpected channel msg: {:?}", msg); + debug!("channel {} => unexpected channel msg: {msg:?}", self.id); } None => break, } } + + // Some implementations send CHANNEL_WINDOW_ADJUST prior to + // window size being 0. Process those at each turn of the loop here. + match self.receiver.try_recv() { + Ok(ChannelMsg::WindowAdjusted { new_size }) => { + debug!("channel {} => window adjusted: {new_size}", self.id); + self.window_size = new_size; + } + Ok(msg) => { + debug!("channel {} => unexpected channel msg: {msg:?}", self.id); + } + Err(TryRecvError::Empty | TryRecvError::Disconnected) => {} + } + debug!( - "sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}", - self.window_size, self.max_packet_size, total + "channel {} => sending data, self.window_size = {}, self.max_packet_size = {}, total = {total}", + self.id, self.window_size, self.max_packet_size ); - let sendable = self.window_size.min(self.max_packet_size) as usize; - debug!("sendable {:?}", sendable); + let sendable = self.writable_packet_size(); + + debug!("channel {} => sendable {sendable}", self.id); // If we can not send anymore, continue // and wait for server window adjustment @@ -318,14 +333,16 @@ impl + Send + 'static> Channel { let n = data.read(&mut c[..]).await?; total += n; c.resize(n); - self.window_size -= n as u32; + + self.window_size = self.window_size.saturating_sub(n as u32); + self.send_data_packet(ext, c).await?; + if n == 0 { break; - } else if self.window_size > 0 { - continue; } } + Ok(()) } @@ -349,6 +366,7 @@ impl + Send + 'static> Channel { match self.receiver.recv().await { Some(ChannelMsg::WindowAdjusted { new_size }) => { self.window_size = new_size; + Some(ChannelMsg::WindowAdjusted { new_size }) } Some(msg) => Some(msg), diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index cc7df2b4..609782fe 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -613,26 +613,31 @@ impl Session { } Some(&msg::CHANNEL_WINDOW_ADJUST) => { debug!("channel_window_adjust"); + let mut r = buf.reader(1); let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?); let amount = r.read_u32().map_err(crate::Error::from)?; let mut new_size = 0; + debug!("amount: {:?}", amount); + if let Some(ref mut enc) = self.common.encrypted { if let Some(ref mut channel) = enc.channels.get_mut(&channel_num) { - channel.recipient_window_size += amount; + channel.recipient_window_size = + channel.recipient_window_size.saturating_add(amount); + new_size = channel.recipient_window_size; } else { return Err(crate::Error::WrongChannel.into()); } - } - if let Some(ref mut enc) = self.common.encrypted { - new_size -= enc.flush_pending(channel_num) as u32; + new_size = new_size.saturating_sub(enc.flush_pending(channel_num) as u32); } + if let Some(chan) = self.channels.get(&channel_num) { - let _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); + _ = chan.send(ChannelMsg::WindowAdjusted { new_size }); } + client.window_adjusted(channel_num, new_size, self).await } Some(&msg::GLOBAL_REQUEST) => { diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index abfa2555..df2cf3fe 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -47,7 +47,7 @@ impl Session { // Either this packet is a KEXINIT, in which case we start a key re-exchange. #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); + let enc = self.common.encrypted.as_mut().unwrap(); if buf.first() == Some(&msg::KEXINIT) { debug!("Received rekeying request"); // If we're not currently rekeying, but `buf` is a rekey request @@ -143,7 +143,7 @@ impl Session { }; #[allow(clippy::unwrap_used)] - let mut enc = self.common.encrypted.as_mut().unwrap(); + let enc = self.common.encrypted.as_mut().unwrap(); // If we've successfully read a packet. match enc.state { EncryptedState::WaitingAuthServiceRequest {