From c797274f9dc38540e6be50582c11f8b8eb9e205e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9on=20ROUX?= Date: Wed, 13 Sep 2023 14:33:18 +0200 Subject: [PATCH] Add support for ExtendedData to io::{ChannelTx, ChannelRx} --- russh/src/channels/io/rx.rs | 39 +++++++++++++++---------------------- russh/src/channels/io/tx.rs | 13 +++++++++---- russh/src/channels/mod.rs | 19 ++++++++++++++---- 3 files changed, 40 insertions(+), 31 deletions(-) diff --git a/russh/src/channels/io/rx.rs b/russh/src/channels/io/rx.rs index eda22772..18f5ec35 100644 --- a/russh/src/channels/io/rx.rs +++ b/russh/src/channels/io/rx.rs @@ -1,14 +1,10 @@ use std::{ io, pin::Pin, - sync::Arc, task::{Context, Poll}, }; -use tokio::{ - io::AsyncRead, - sync::{mpsc::error::TryRecvError, Mutex}, -}; +use tokio::{io::AsyncRead, sync::mpsc::error::TryRecvError}; use super::ChannelMsg; use crate::{Channel, ChannelId}; @@ -21,18 +17,18 @@ where channel: &'i mut Channel, buffer: Option, - window_size: Arc>, + ext: Option, } impl<'i, S> ChannelRx<'i, S> where S: From<(ChannelId, ChannelMsg)>, { - pub fn new(channel: &'i mut Channel, window_size: Arc>) -> Self { + pub fn new(channel: &'i mut Channel, ext: Option) -> Self { Self { channel, buffer: None, - window_size, + ext, } } } @@ -60,8 +56,8 @@ where }, }; - match &msg { - ChannelMsg::Data { data } => { + match (&msg, self.ext) { + (ChannelMsg::Data { data }, None) => { if buf.remaining() >= data.len() { buf.put_slice(data); @@ -73,22 +69,19 @@ where Poll::Pending } } - ChannelMsg::WindowAdjusted { new_size } => { - let buffer = match self.window_size.try_lock() { - Ok(mut window_size) => { - *window_size = *new_size; - - None - } - Err(_) => Some(msg), - }; + (ChannelMsg::ExtendedData { data, ext }, Some(target)) if *ext == target => { + if buf.remaining() >= data.len() { + buf.put_slice(data); - self.buffer = buffer; + Poll::Ready(Ok(())) + } else { + self.buffer = Some(msg); - cx.waker().wake_by_ref(); - Poll::Pending + cx.waker().wake_by_ref(); + Poll::Pending + } } - ChannelMsg::Eof => { + (ChannelMsg::Eof, _) => { self.channel.receiver.close(); Poll::Ready(Ok(())) diff --git a/russh/src/channels/io/tx.rs b/russh/src/channels/io/tx.rs index f603b749..1d6fa076 100644 --- a/russh/src/channels/io/tx.rs +++ b/russh/src/channels/io/tx.rs @@ -24,6 +24,7 @@ pub struct ChannelTx { window_size: Arc>, max_packet_size: u32, + ext: Option, } impl ChannelTx { @@ -32,12 +33,14 @@ impl ChannelTx { id: ChannelId, window_size: Arc>, max_packet_size: u32, + ext: Option, ) -> Self { Self { sender, id, window_size, max_packet_size, + ext, } } } @@ -73,10 +76,12 @@ where *window_size -= writable as u32; drop(window_size); - match self - .sender - .try_send((self.id, ChannelMsg::Data { data }).into()) - { + let msg = match self.ext { + None => ChannelMsg::Data { data }, + Some(ext) => ChannelMsg::ExtendedData { data, ext }, + }; + + match self.sender.try_send((self.id, msg).into()) { Ok(_) => Poll::Ready(Ok(writable)), Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)), Err(TrySendError::Full(_)) => { diff --git a/russh/src/channels/mod.rs b/russh/src/channels/mod.rs index a3a6a358..441c3102 100644 --- a/russh/src/channels/mod.rs +++ b/russh/src/channels/mod.rs @@ -413,9 +413,19 @@ impl + Send + 'static> Channel { stream } - /// Setup the [`Channel`] to be able to send messages through [`io::ChannelTx`], - /// and receiving them through [`io::ChannelRx`]. + /// Setup the [`Channel`] to be able to send and receive [`ChannelMsg::Data`] + /// through [`io::ChannelTx`] and [`io::ChannelRx`]. pub fn into_io_parts(&mut self) -> (io::ChannelTx, io::ChannelRx<'_, S>) { + self.into_io_parts_ext(None) + } + + /// Setup the [`Channel`] to be able to send and receive [`ChannelMsg::Data`] + /// or [`ChannelMsg::ExtendedData`] through [`io::ChannelTx`] and [`io::ChannelRx`] + /// depending on the `ext` parameter. + pub fn into_io_parts_ext( + &mut self, + ext: Option, + ) -> (io::ChannelTx, io::ChannelRx<'_, S>) { use std::sync::Arc; use tokio::sync::Mutex; @@ -425,10 +435,11 @@ impl + Send + 'static> Channel { io::ChannelTx::new( self.sender.clone(), self.id, - window_size.clone(), + window_size, self.max_packet_size, + ext, ), - io::ChannelRx::new(self, window_size), + io::ChannelRx::new(self, ext), ) } }