Skip to content

Commit

Permalink
Moved handling of ChannelMsg::WindowAdjusted in the Session::server_r…
Browse files Browse the repository at this point in the history
…ead_authenticated() method
  • Loading branch information
lowlevl committed Sep 18, 2023
1 parent 2a426b4 commit 6ca4556
Show file tree
Hide file tree
Showing 6 changed files with 186 additions and 146 deletions.
34 changes: 34 additions & 0 deletions russh/src/channels/channel_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::sync::Arc;

use tokio::sync::{mpsc::UnboundedSender, Mutex};

use crate::ChannelMsg;

/// A handle to the [`super::Channel`]'s to be able to transmit messages
/// to it and update it's `window_size`.
#[derive(Debug)]
pub struct ChannelRef {
pub(super) sender: UnboundedSender<ChannelMsg>,
pub(super) window_size: Arc<Mutex<u32>>,
}

impl ChannelRef {
pub fn new(sender: UnboundedSender<ChannelMsg>) -> Self {
Self {
sender,
window_size: Default::default(),
}
}

pub fn window_size(&self) -> &Arc<Mutex<u32>> {
&self.window_size
}
}

impl std::ops::Deref for ChannelRef {
type Target = UnboundedSender<ChannelMsg>;

fn deref(&self) -> &Self::Target {
&self.sender
}
}
117 changes: 41 additions & 76 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
use log::debug;
use std::sync::Arc;

use russh_cryptovec::CryptoVec;
use tokio::sync::mpsc::{Sender, UnboundedReceiver};
use tokio::sync::{
mpsc::{Sender, UnboundedReceiver},
Mutex,
};

use crate::{ChannelId, ChannelOpenFailure, ChannelStream, Error, Pty, Sig};

pub mod io;

mod channel_ref;
pub use channel_ref::ChannelRef;

#[derive(Debug)]
#[non_exhaustive]
/// Possible messages that [Channel::wait] can receive.
Expand Down Expand Up @@ -113,7 +120,7 @@ pub struct Channel<Send: From<(ChannelId, ChannelMsg)>> {
pub(crate) sender: Sender<Send>,
pub(crate) receiver: UnboundedReceiver<ChannelMsg>,
pub(crate) max_packet_size: u32,
pub(crate) window_size: u32,
pub(crate) window_size: Arc<Mutex<u32>>,
}

impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
Expand All @@ -123,14 +130,32 @@ impl<T: From<(ChannelId, ChannelMsg)>> std::fmt::Debug for Channel<T> {
}

impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
pub fn id(&self) -> ChannelId {
self.id
pub(crate) fn new(
id: ChannelId,
sender: Sender<S>,
max_packet_size: u32,
window_size: u32,
) -> (Self, ChannelRef) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let window_size = Arc::new(Mutex::new(window_size));

(
Self {
id,
sender,
receiver: rx,
max_packet_size,
window_size: window_size.clone(),
},
ChannelRef {
sender: tx,
window_size,
},
)
}

/// Returns the min between the maximum packet size and the
/// remaining window size in the channel.
pub fn writable_packet_size(&self) -> usize {
self.max_packet_size.min(self.window_size) as usize
pub fn id(&self) -> ChannelId {
self.id
}

/// Request a pseudo-terminal with the given characteristics.
Expand Down Expand Up @@ -266,78 +291,30 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
}

/// Send data to a channel.
pub async fn data<R: tokio::io::AsyncReadExt + Unpin>(&mut self, data: R) -> Result<(), Error> {
pub async fn data<R: tokio::io::AsyncRead + Unpin>(&mut self, data: R) -> Result<(), Error> {
self.send_data(None, data).await
}

/// Send data to a channel. The number of bytes added to the
/// "sending pipeline" (to be processed by the event loop) is
/// returned.
pub async fn extended_data<R: tokio::io::AsyncReadExt + Unpin>(
pub async fn extended_data<R: tokio::io::AsyncRead + Unpin>(
&mut self,
ext: u32,
data: R,
) -> Result<(), Error> {
self.send_data(Some(ext), data).await
}

async fn send_data<R: tokio::io::AsyncReadExt + Unpin>(
async fn send_data<R: tokio::io::AsyncRead + Unpin>(
&mut self,
ext: Option<u32>,
mut data: R,
) -> Result<(), Error> {
let mut total = 0;
loop {
// wait for the window to be restored.
while self.window_size == 0 {
match self.receiver.recv().await {
Some(ChannelMsg::WindowAdjusted { new_size }) => {
debug!("window adjusted: {:?}", new_size);
self.window_size = new_size;
break;
}
Some(msg) => {
debug!("unexpected channel msg: {:?}", msg);
}
None => break,
}
}
debug!(
"sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}",
self.window_size, self.max_packet_size, total
);
let sendable = self.window_size.min(self.max_packet_size) as usize;

debug!("sendable {:?}", sendable);

// If we can not send anymore, continue
// and wait for server window adjustment
if sendable == 0 {
continue;
}
let (mut tx, _) = self.into_io_parts_ext(ext);

let mut c = CryptoVec::new_zeroed(sendable);
let n = data.read(&mut c[..]).await?;
total += n;
c.resize(n);
self.window_size -= n as u32;
self.send_data_packet(ext, c).await?;
if n == 0 {
break;
} else if self.window_size > 0 {
continue;
}
}
Ok(())
}
tokio::io::copy(&mut data, &mut tx).await?;

async fn send_data_packet(&mut self, ext: Option<u32>, data: CryptoVec) -> Result<(), Error> {
self.send_msg(if let Some(ext) = ext {
ChannelMsg::ExtendedData { ext, data }
} else {
ChannelMsg::Data { data }
})
.await?;
Ok(())
}

Expand All @@ -348,14 +325,7 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {

/// Wait for data to come.
pub async fn wait(&mut self) -> Option<ChannelMsg> {
match self.receiver.recv().await {
Some(ChannelMsg::WindowAdjusted { new_size }) => {
self.window_size = new_size;
Some(ChannelMsg::WindowAdjusted { new_size })
}
Some(msg) => Some(msg),
None => None,
}
self.receiver.recv().await
}

async fn send_msg(&self, msg: ChannelMsg) -> Result<(), Error> {
Expand Down Expand Up @@ -426,16 +396,11 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
&mut self,
ext: Option<u32>,
) -> (io::ChannelTx<S>, io::ChannelRx<'_, S>) {
use std::sync::Arc;
use tokio::sync::Mutex;

let window_size = Arc::new(Mutex::new(self.window_size));

(
io::ChannelTx::new(
self.sender.clone(),
self.id,
window_size,
self.window_size.clone(),
self.max_packet_size,
ext,
),
Expand Down
18 changes: 9 additions & 9 deletions russh/src/client/encrypted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use log::{debug, error, info, trace, warn};
use russh_cryptovec::CryptoVec;
use russh_keys::encoding::{Encoding, Reader};
use russh_keys::key::parse_public_key;
use tokio::sync::mpsc::unbounded_channel;

use crate::client::{Handler, Msg, Prompt, Reply, Session};
use crate::key::PubKey;
Expand Down Expand Up @@ -813,15 +812,16 @@ impl Session {
id: ChannelId,
msg: &OpenChannelMessage,
) -> Channel<Msg> {
let (sender, receiver) = unbounded_channel();
self.channels.insert(id, sender);
Channel {
let (channel, channel_ref) = Channel::new(
id,
sender: self.inbound_channel_sender.clone(),
receiver,
max_packet_size: msg.recipient_maximum_packet_size,
window_size: msg.recipient_window_size,
}
self.inbound_channel_sender.clone(),
msg.recipient_maximum_packet_size,
msg.recipient_window_size,
);

self.channels.insert(id, channel_ref);

channel
}

pub(crate) fn write_auth_request_if_needed(&mut self, user: &str, meth: auth::Method) -> bool {
Expand Down
Loading

0 comments on commit 6ca4556

Please sign in to comment.