Skip to content

Commit

Permalink
Add support for ExtendedData to io::{ChannelTx, ChannelRx}
Browse files Browse the repository at this point in the history
  • Loading branch information
Léon ROUX committed Sep 13, 2023
1 parent 06f7c68 commit c797274
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 31 deletions.
39 changes: 16 additions & 23 deletions russh/src/channels/io/rx.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use tokio::{
io::AsyncRead,
sync::{mpsc::error::TryRecvError, Mutex},
};
use tokio::{io::AsyncRead, sync::mpsc::error::TryRecvError};

use super::ChannelMsg;
use crate::{Channel, ChannelId};
Expand All @@ -21,18 +17,18 @@ where
channel: &'i mut Channel<S>,
buffer: Option<ChannelMsg>,

window_size: Arc<Mutex<u32>>,
ext: Option<u32>,
}

impl<'i, S> ChannelRx<'i, S>
where
S: From<(ChannelId, ChannelMsg)>,
{
pub fn new(channel: &'i mut Channel<S>, window_size: Arc<Mutex<u32>>) -> Self {
pub fn new(channel: &'i mut Channel<S>, ext: Option<u32>) -> Self {
Self {
channel,
buffer: None,
window_size,
ext,
}
}
}
Expand Down Expand Up @@ -60,8 +56,8 @@ where
},
};

match &msg {
ChannelMsg::Data { data } => {
match (&msg, self.ext) {
(ChannelMsg::Data { data }, None) => {
if buf.remaining() >= data.len() {
buf.put_slice(data);

Expand All @@ -73,22 +69,19 @@ where
Poll::Pending
}
}
ChannelMsg::WindowAdjusted { new_size } => {
let buffer = match self.window_size.try_lock() {
Ok(mut window_size) => {
*window_size = *new_size;

None
}
Err(_) => Some(msg),
};
(ChannelMsg::ExtendedData { data, ext }, Some(target)) if *ext == target => {
if buf.remaining() >= data.len() {
buf.put_slice(data);

self.buffer = buffer;
Poll::Ready(Ok(()))
} else {
self.buffer = Some(msg);

cx.waker().wake_by_ref();
Poll::Pending
cx.waker().wake_by_ref();
Poll::Pending
}
}
ChannelMsg::Eof => {
(ChannelMsg::Eof, _) => {
self.channel.receiver.close();

Poll::Ready(Ok(()))
Expand Down
13 changes: 9 additions & 4 deletions russh/src/channels/io/tx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct ChannelTx<S> {

window_size: Arc<Mutex<u32>>,
max_packet_size: u32,
ext: Option<u32>,
}

impl<S> ChannelTx<S> {
Expand All @@ -32,12 +33,14 @@ impl<S> ChannelTx<S> {
id: ChannelId,
window_size: Arc<Mutex<u32>>,
max_packet_size: u32,
ext: Option<u32>,
) -> Self {
Self {
sender,
id,
window_size,
max_packet_size,
ext,
}
}
}
Expand Down Expand Up @@ -73,10 +76,12 @@ where
*window_size -= writable as u32;
drop(window_size);

match self
.sender
.try_send((self.id, ChannelMsg::Data { data }).into())
{
let msg = match self.ext {
None => ChannelMsg::Data { data },
Some(ext) => ChannelMsg::ExtendedData { data, ext },
};

match self.sender.try_send((self.id, msg).into()) {
Ok(_) => Poll::Ready(Ok(writable)),
Err(TrySendError::Closed(_)) => Poll::Ready(Ok(0)),
Err(TrySendError::Full(_)) => {
Expand Down
19 changes: 15 additions & 4 deletions russh/src/channels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,19 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
stream
}

/// Setup the [`Channel`] to be able to send messages through [`io::ChannelTx`],
/// and receiving them through [`io::ChannelRx`].
/// Setup the [`Channel`] to be able to send and receive [`ChannelMsg::Data`]
/// through [`io::ChannelTx`] and [`io::ChannelRx`].
pub fn into_io_parts(&mut self) -> (io::ChannelTx<S>, io::ChannelRx<'_, S>) {
self.into_io_parts_ext(None)
}

/// Setup the [`Channel`] to be able to send and receive [`ChannelMsg::Data`]
/// or [`ChannelMsg::ExtendedData`] through [`io::ChannelTx`] and [`io::ChannelRx`]
/// depending on the `ext` parameter.
pub fn into_io_parts_ext(
&mut self,
ext: Option<u32>,
) -> (io::ChannelTx<S>, io::ChannelRx<'_, S>) {
use std::sync::Arc;
use tokio::sync::Mutex;

Expand All @@ -425,10 +435,11 @@ impl<S: From<(ChannelId, ChannelMsg)> + Send + 'static> Channel<S> {
io::ChannelTx::new(
self.sender.clone(),
self.id,
window_size.clone(),
window_size,
self.max_packet_size,
ext,
),
io::ChannelRx::new(self, window_size),
io::ChannelRx::new(self, ext),
)
}
}

0 comments on commit c797274

Please sign in to comment.