diff --git a/russh/src/channels/channel_ref.rs b/russh/src/channels/channel_ref.rs new file mode 100644 index 00000000..d924bb11 --- /dev/null +++ b/russh/src/channels/channel_ref.rs @@ -0,0 +1,35 @@ +use std::sync::Arc; + +use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; + +use crate::ChannelMsg; + +/// A handle to the [`super::Channel`]'s to be able to transmit messages +/// to it and update it's `window_size`. +#[derive(Debug)] +pub struct ChannelRef { + pub(super) sender: UnboundedSender, + pub(super) window_size: Arc>, +} + +impl ChannelRef { + pub fn new(sender: UnboundedSender) -> Self { + Self { + sender, + window_size: Default::default(), + } + } + + pub fn window_size(&self) -> &Arc> { + &self.window_size + } +} + +impl std::ops::Deref for ChannelRef { + type Target = UnboundedSender; + + fn deref(&self) -> &Self::Target { + &self.sender + } +} diff --git a/russh/src/channels/io/rx.rs b/russh/src/channels/io/rx.rs index 18f5ec35..4982b67d 100644 --- a/russh/src/channels/io/rx.rs +++ b/russh/src/channels/io/rx.rs @@ -1,10 +1,9 @@ -use std::{ - io, - pin::Pin, - task::{Context, Poll}, -}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; -use tokio::{io::AsyncRead, sync::mpsc::error::TryRecvError}; +use tokio::io::AsyncRead; +use tokio::sync::mpsc::error::TryRecvError; use super::ChannelMsg; use crate::{Channel, ChannelId}; diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs index 1d6fa076..8757ed0b 100644 --- a/russh/src/channels/io/tx.rs +++ b/russh/src/channels/io/tx.rs @@ -1,19 +1,13 @@ -use std::{ - io, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; - -use tokio::{ - io::AsyncWrite, - sync::{ - mpsc::{self, error::TrySendError}, - Mutex, - }, -}; +use std::io; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; use russh_cryptovec::CryptoVec; +use tokio::io::AsyncWrite; +use tokio::sync::mpsc::error::TrySendError; +use tokio::sync::mpsc::{self}; +use tokio::sync::Mutex; use super::ChannelMsg; use crate::ChannelId; diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index edbfff63..3e949c7f 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -1,11 +1,16 @@ +use std::sync::Arc; + use russh_cryptovec::CryptoVec; use tokio::sync::mpsc::{Sender, UnboundedReceiver}; -use log::debug; +use tokio::sync::Mutex; use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig}; pub mod io; +mod channel_ref; +pub use channel_ref::ChannelRef; + #[derive(Debug)] #[non_exhaustive] /// Possible messages that [Channel::wait] can receive. @@ -113,7 +118,7 @@ pub struct Channel> { pub(crate) sender: Sender, pub(crate) receiver: UnboundedReceiver, pub(crate) max_packet_size: u32, - pub(crate) window_size: u32, + pub(crate) window_size: Arc>, } impl> std::fmt::Debug for Channel { @@ -123,14 +128,32 @@ impl> std::fmt::Debug for Channel { } impl + Send + 'static> Channel { - pub fn id(&self) -> ChannelId { - self.id + pub(crate) fn new( + id: ChannelId, + sender: Sender, + max_packet_size: u32, + window_size: u32, + ) -> (Self, ChannelRef) { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let window_size = Arc::new(Mutex::new(window_size)); + + ( + Self { + id, + sender, + receiver: rx, + max_packet_size, + window_size: window_size.clone(), + }, + ChannelRef { + sender: tx, + window_size, + }, + ) } - /// Returns the min between the maximum packet size and the - /// remaining window size in the channel. - pub fn writable_packet_size(&self) -> usize { - self.max_packet_size.min(self.window_size) as usize + pub fn id(&self) -> ChannelId { + self.id } /// Request a pseudo-terminal with the given characteristics. @@ -266,14 +289,14 @@ impl + Send + 'static> Channel { } /// Send data to a channel. - pub async fn data(&mut self, data: R) -> Result<(), Error> { + pub async fn data(&mut self, data: R) -> Result<(), Error> { self.send_data(None, data).await } /// Send data to a channel. The number of bytes added to the /// "sending pipeline" (to be processed by the event loop) is /// returned. - pub async fn extended_data( + pub async fn extended_data( &mut self, ext: u32, data: R, @@ -281,63 +304,15 @@ impl + Send + 'static> Channel { self.send_data(Some(ext), data).await } - async fn send_data( + async fn send_data( &mut self, ext: Option, mut data: R, ) -> Result<(), Error> { - let mut total = 0; - loop { - // wait for the window to be restored. - while self.window_size == 0 { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - debug!("window adjusted: {:?}", new_size); - self.window_size = new_size; - break; - } - Some(msg) => { - debug!("unexpected channel msg: {:?}", msg); - } - None => break, - } - } - debug!( - "sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}", - self.window_size, self.max_packet_size, total - ); - let sendable = self.window_size.min(self.max_packet_size) as usize; - - debug!("sendable {:?}", sendable); - - // If we can not send anymore, continue - // and wait for server window adjustment - if sendable == 0 { - continue; - } + let (mut tx, _) = self.into_io_parts_ext(ext); - let mut c = CryptoVec::new_zeroed(sendable); - let n = data.read(&mut c[..]).await?; - total += n; - c.resize(n); - self.window_size -= n as u32; - self.send_data_packet(ext, c).await?; - if n == 0 { - break; - } else if self.window_size > 0 { - continue; - } - } - Ok(()) - } + tokio::io::copy(&mut data, &mut tx).await?; - async fn send_data_packet(&mut self, ext: Option, data: CryptoVec) -> Result<(), Error> { - self.send_msg(if let Some(ext) = ext { - ChannelMsg::ExtendedData { ext, data } - } else { - ChannelMsg::Data { data } - }) - .await?; Ok(()) } @@ -348,14 +323,7 @@ impl + Send + 'static> Channel { /// Wait for data to come. pub async fn wait(&mut self) -> Option { - match self.receiver.recv().await { - Some(ChannelMsg::WindowAdjusted { new_size }) => { - self.window_size = new_size; - Some(ChannelMsg::WindowAdjusted { new_size }) - } - Some(msg) => Some(msg), - None => None, - } + self.receiver.recv().await } async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> { @@ -426,16 +394,11 @@ impl + Send + 'static> Channel { &mut self, ext: Option, ) -> (io::ChannelTx, io::ChannelRx<'_, S>) { - use std::sync::Arc; - use tokio::sync::Mutex; - - let window_size = Arc::new(Mutex::new(self.window_size)); - ( io::ChannelTx::new( self.sender.clone(), self.id, - window_size, + self.window_size.clone(), self.max_packet_size, ext, ), diff --git a/russh/src/client/encrypted.rs b/russh/src/client/encrypted.rs index 48a737d4..b1442834 100644 --- a/russh/src/client/encrypted.rs +++ b/russh/src/client/encrypted.rs @@ -19,7 +19,6 @@ use log::{debug, error, info, trace, warn}; use russh_cryptovec::CryptoVec; use russh_keys::encoding::{Encoding, Reader}; use russh_keys::key::parse_public_key; -use tokio::sync::mpsc::unbounded_channel; use crate::client::{Handler, Msg, Prompt, Reply, Session}; use crate::key::PubKey; @@ -813,15 +812,16 @@ impl Session { id: ChannelId, msg: &OpenChannelMessage, ) -> Channel { - let (sender, receiver) = unbounded_channel(); - self.channels.insert(id, sender); - Channel { + let (channel, channel_ref) = Channel::new( id, - sender: self.inbound_channel_sender.clone(), - receiver, - max_packet_size: msg.recipient_maximum_packet_size, - window_size: msg.recipient_window_size, - } + self.inbound_channel_sender.clone(), + msg.recipient_maximum_packet_size, + msg.recipient_window_size, + ); + + self.channels.insert(id, channel_ref); + + channel } pub(crate) fn write_auth_request_if_needed(&mut self, user: &str, meth: auth::Method) -> bool { diff --git a/russh/src/client/mod.rs b/russh/src/client/mod.rs index c82fb5f7..5435d687 100644 --- a/russh/src/client/mod.rs +++ b/russh/src/client/mod.rs @@ -89,15 +89,15 @@ use russh_keys::encoding::Reader; #[cfg(feature = "openssl")] use russh_keys::key::SignatureHash; use russh_keys::key::{self, parse_public_key, PublicKey}; -use tokio; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio::net::{TcpStream, ToSocketAddrs}; use tokio::pin; use tokio::sync::mpsc::{ channel, unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender, }; +use tokio::sync::Mutex; -use crate::channels::{Channel, ChannelMsg}; +use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::cipher::{self, clear, CipherPair, OpeningKey}; use crate::key::PubKey; use crate::session::{CommonSession, EncryptedState, Exchange, Kex, KexDhDone, KexInit, NewKeys}; @@ -118,7 +118,7 @@ pub struct Session { common: CommonSession>, receiver: Receiver, sender: UnboundedSender, - channels: HashMap>, + channels: HashMap, target_window_size: u32, pending_reads: Vec, pending_len: u32, @@ -162,23 +162,23 @@ pub enum Msg { data: CryptoVec, }, ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectStreamLocal { socket_path: String, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { want_reply: bool, @@ -418,6 +418,7 @@ impl Handle { async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, + window_size_ref: Arc>, ) -> Result, crate::Error> { loop { match receiver.recv().await { @@ -426,12 +427,14 @@ impl Handle { max_packet_size, window_size, }) => { + *window_size_ref.lock().await = window_size; + return Ok(Channel { id, sender: self.sender.clone(), receiver, max_packet_size, - window_size, + window_size: window_size_ref, }); } Some(ChannelMsg::OpenFailure(reason)) => { @@ -454,11 +457,15 @@ impl Handle { /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Request an X11 channel, on which the X11 protocol may be tunneled. @@ -468,15 +475,19 @@ impl Handle { originator_port: u32, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -495,17 +506,21 @@ impl Handle { originator_port: u32, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_direct_streamlocal>( @@ -513,14 +528,18 @@ impl Handle { socket_path: S, ) -> Result, crate::Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectStreamLocal { socket_path: socket_path.into(), - sender, + channel_ref, }) .await .map_err(|_| crate::Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn tcpip_forward>( @@ -860,24 +879,24 @@ impl Session { } Msg::Signed { .. } => {} Msg::AuthInfoResponse { .. } => {} - Msg::ChannelOpenSession { sender } => { + Msg::ChannelOpenSession { channel_ref } => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenX11 { originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, - sender, + channel_ref, } => { let id = self.channel_open_direct_tcpip( &host_to_connect, @@ -885,14 +904,14 @@ impl Session { &originator_address, originator_port, )?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::ChannelOpenDirectStreamLocal { socket_path, - sender, + channel_ref, } => { let id = self.channel_open_direct_streamlocal(&socket_path)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Msg::TcpIpForward { want_reply, diff --git a/russh/src/server/encrypted.rs b/russh/src/server/encrypted.rs index abfa2555..989dbf83 100644 --- a/russh/src/server/encrypted.rs +++ b/russh/src/server/encrypted.rs @@ -21,7 +21,6 @@ use negotiation::Select; use russh_keys::encoding::{Encoding, Position, Reader}; use russh_keys::key; use russh_keys::key::Verify; -use tokio::sync::mpsc::unbounded_channel; use tokio::time::Instant; use {msg, negotiation}; @@ -676,6 +675,8 @@ impl Session { enc.flush_pending(channel_num); } if let Some(chan) = self.channels.get(&channel_num) { + *chan.window_size().lock().await = new_size; + chan.send(ChannelMsg::WindowAdjusted { new_size }) .unwrap_or(()) } @@ -1058,20 +1059,18 @@ impl Session { pending_data: std::collections::VecDeque::new(), }; - let (sender, receiver) = unbounded_channel(); - let channel = Channel { - id: sender_channel, - sender: self.sender.sender.clone(), - receiver, - max_packet_size: channel_params.recipient_maximum_packet_size, - window_size: channel_params.recipient_window_size, - }; + let (channel, reference) = Channel::new( + sender_channel, + self.sender.sender.clone(), + channel_params.recipient_maximum_packet_size, + channel_params.recipient_window_size, + ); match &msg.typ { ChannelType::Session => { let mut result = handler.channel_open_session(channel, self).await; if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); + s.channels.insert(sender_channel, reference); s.finalize_channel_open(&msg, channel_params, *allowed); } result @@ -1084,7 +1083,7 @@ impl Session { .channel_open_x11(channel, originator_address, *originator_port, self) .await; if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); + s.channels.insert(sender_channel, reference); s.finalize_channel_open(&msg, channel_params, *allowed); } result @@ -1101,7 +1100,7 @@ impl Session { ) .await; if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); + s.channels.insert(sender_channel, reference); s.finalize_channel_open(&msg, channel_params, *allowed); } result @@ -1118,7 +1117,7 @@ impl Session { ) .await; if let Ok((_, allowed, s)) = &mut result { - s.channels.insert(sender_channel, sender); + s.channels.insert(sender_channel, reference); s.finalize_channel_open(&msg, channel_params, *allowed); } result diff --git a/russh/src/server/session.rs b/russh/src/server/session.rs index 7cbcfeb3..9361e579 100644 --- a/russh/src/server/session.rs +++ b/russh/src/server/session.rs @@ -4,10 +4,11 @@ use std::sync::Arc; use log::debug; use russh_keys::encoding::{Encoding, Reader}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; -use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}; +use tokio::sync::mpsc::{unbounded_channel, Receiver, Sender, UnboundedReceiver}; +use tokio::sync::Mutex; use super::*; -use crate::channels::{Channel, ChannelMsg}; +use crate::channels::{Channel, ChannelMsg, ChannelRef}; use crate::kex::EXTENSION_SUPPORT_AS_CLIENT; use crate::msg; @@ -19,31 +20,31 @@ pub struct Session { pub(crate) target_window_size: u32, pub(crate) pending_reads: Vec, pub(crate) pending_len: u32, - pub(crate) channels: HashMap>, + pub(crate) channels: HashMap, } #[derive(Debug)] pub enum Msg { ChannelOpenSession { - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenDirectTcpIp { host_to_connect: String, port_to_connect: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenForwardedTcpIp { connected_address: String, connected_port: u32, originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, ChannelOpenX11 { originator_address: String, originator_port: u32, - sender: UnboundedSender, + channel_ref: ChannelRef, }, TcpIpForward { address: String, @@ -170,11 +171,16 @@ impl Handle { /// `confirmed` field of the corresponding `Channel`. pub async fn channel_open_session(&self) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender - .send(Msg::ChannelOpenSession { sender }) + .send(Msg::ChannelOpenSession { channel_ref }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + + self.wait_channel_confirmation(receiver, window_size_ref) + .await } /// Open a TCP/IP forwarding channel. This is usually done when a @@ -190,17 +196,21 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenDirectTcpIp { host_to_connect: host_to_connect.into(), port_to_connect, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_forwarded_tcpip, B: Into>( @@ -211,17 +221,21 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenForwardedTcpIp { connected_address: connected_address.into(), connected_port, originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } pub async fn channel_open_x11>( @@ -230,20 +244,25 @@ impl Handle { originator_port: u32, ) -> Result, Error> { let (sender, receiver) = unbounded_channel(); + let channel_ref = ChannelRef::new(sender); + let window_size_ref = channel_ref.window_size().clone(); + self.sender .send(Msg::ChannelOpenX11 { originator_address: originator_address.into(), originator_port, - sender, + channel_ref, }) .await .map_err(|_| Error::SendError)?; - self.wait_channel_confirmation(receiver).await + self.wait_channel_confirmation(receiver, window_size_ref) + .await } async fn wait_channel_confirmation( &self, mut receiver: UnboundedReceiver, + window_size_ref: Arc>, ) -> Result, Error> { loop { match receiver.recv().await { @@ -252,12 +271,14 @@ impl Handle { max_packet_size, window_size, }) => { + *window_size_ref.lock().await = window_size; + return Ok(Channel { id, sender: self.sender.clone(), receiver, max_packet_size, - window_size, + window_size: window_size_ref, }); } Some(ChannelMsg::OpenFailure(reason)) => { @@ -420,21 +441,21 @@ impl Session { Some(Msg::Channel(id, ChannelMsg::WindowAdjusted { new_size })) => { debug!("window adjusted to {:?} for channel {:?}", new_size, id); } - Some(Msg::ChannelOpenSession { sender }) => { + Some(Msg::ChannelOpenSession { channel_ref }) => { let id = self.channel_open_session()?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_direct_tcpip(&host_to_connect, port_to_connect, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenForwardedTcpIp { connected_address, connected_port, originator_address, originator_port, channel_ref }) => { let id = self.channel_open_forwarded_tcpip(&connected_address, connected_port, &originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } - Some(Msg::ChannelOpenX11 { originator_address, originator_port, sender }) => { + Some(Msg::ChannelOpenX11 { originator_address, originator_port, channel_ref }) => { let id = self.channel_open_x11(&originator_address, originator_port)?; - self.channels.insert(id, sender); + self.channels.insert(id, channel_ref); } Some(Msg::TcpIpForward { address, port }) => { self.tcpip_forward(&address, port);