From 474ffb77931d698e1f01d46c754de8919b15763a Mon Sep 17 00:00:00 2001 From: Max Kellermann Date: Mon, 23 Oct 2023 10:01:31 +0200 Subject: [PATCH] ssh/Channel: add send window counter --- src/SessionChannel.cxx | 58 ++++++++++++++++++++++++++++++++++------- src/SessionChannel.hxx | 16 ++++++++++++ src/ssh/CConnection.cxx | 3 +++ src/ssh/Channel.cxx | 11 +++++++- src/ssh/Channel.hxx | 15 ++++++++++- 5 files changed, 91 insertions(+), 12 deletions(-) diff --git a/src/SessionChannel.cxx b/src/SessionChannel.cxx index 9445742..2daebbb 100644 --- a/src/SessionChannel.cxx +++ b/src/SessionChannel.cxx @@ -71,6 +71,17 @@ SessionChannel::SetEnv(std::string_view name, std::string_view value) noexcept env.emplace_front(fmt::format("{}={}", name, value)); } +void +SessionChannel::OnWindowAdjust(std::size_t nbytes) +{ + if (GetSendWindow() == 0) + /* re-schedule all read events, because we are now + allowed to send data again */ + ScheduleRead(); + + Channel::OnWindowAdjust(nbytes); +} + void SessionChannel::OnData(std::span payload) { @@ -130,8 +141,6 @@ SessionChannel::Exec(const char *cmd) if (tty.IsDefined()) { p.stdin_fd = p.stdout_fd = p.stderr_fd = slave_tty.Release(); p.tty = true; - - tty.ScheduleRead(); } else { UniqueFileDescriptor stdin_r, stdout_r, stdout_w, stderr_r, stderr_w; if (!UniqueFileDescriptor::CreatePipe(stdin_r, stdin_pipe) || @@ -144,11 +153,12 @@ SessionChannel::Exec(const char *cmd) p.SetStderr(std::move(stderr_w)); stdout_pipe.Open(stdout_r.Release()); - stdout_pipe.ScheduleRead(); stderr_pipe.Open(stderr_r.Release()); - stderr_pipe.ScheduleRead(); } + if (GetSendWindow() > 0) + ScheduleRead(); + if (const char *mount_home = p.ns.mount.GetMountHome()) { p.SetEnv("HOME", mount_home); p.chdir = mount_home; @@ -218,9 +228,19 @@ void SessionChannel::OnTtyReady([[maybe_unused]] unsigned events) noexcept { std::byte buffer[4096]; - auto nbytes = tty.GetFileDescriptor().Read(buffer); + std::span dest{buffer}; + + if (GetSendWindow() < dest.size()) { + dest = dest.first(GetSendWindow()); + assert(!dest.empty()); + } + + auto nbytes = tty.GetFileDescriptor().Read(dest); if (nbytes > 0) { - SendData(std::span{buffer}.first(nbytes)); + SendData(dest.first(nbytes)); + + if (GetSendWindow() == 0) + CancelRead(); } else { tty.Close(); SendEof(); @@ -232,9 +252,18 @@ void SessionChannel::OnStdoutReady([[maybe_unused]] unsigned events) noexcept { std::byte buffer[4096]; - auto nbytes = stdout_pipe.GetFileDescriptor().Read(buffer); + std::span dest{buffer}; + if (GetSendWindow() < dest.size()) { + dest = dest.first(GetSendWindow()); + assert(!dest.empty()); + } + + auto nbytes = stdout_pipe.GetFileDescriptor().Read(dest); if (nbytes > 0) { - SendData(std::span{buffer}.first(nbytes)); + SendData(dest.first(nbytes)); + + if (GetSendWindow() == 0) + CancelRead(); } else { stdout_pipe.Close(); SendEof(); @@ -246,9 +275,18 @@ void SessionChannel::OnStderrReady([[maybe_unused]] unsigned events) noexcept { std::byte buffer[4096]; - auto nbytes = stderr_pipe.GetFileDescriptor().Read(buffer); + std::span dest{buffer}; + if (GetSendWindow() < dest.size()) { + dest = dest.first(GetSendWindow()); + assert(!dest.empty()); + } + + auto nbytes = stderr_pipe.GetFileDescriptor().Read(dest); if (nbytes > 0) { - SendStderr(std::span{buffer}.first(nbytes)); + SendStderr(dest.first(nbytes)); + + if (GetSendWindow() == 0) + CancelRead(); } else { stderr_pipe.Close(); CloseIfInactive(); diff --git a/src/SessionChannel.hxx b/src/SessionChannel.hxx index 4f668f4..cb18f4a 100644 --- a/src/SessionChannel.hxx +++ b/src/SessionChannel.hxx @@ -49,6 +49,7 @@ public: ~SessionChannel() noexcept override; /* virtual methods from class SSH::Channel */ + void OnWindowAdjust(std::size_t nbytes) override; void OnData(std::span payload) override; void OnEof() override; bool OnRequest(std::string_view request_type, @@ -68,6 +69,21 @@ private: void Exec(const char *cmd); + void CancelRead() noexcept { + stdout_pipe.CancelRead(); + stderr_pipe.CancelRead(); + tty.CancelRead(); + } + + void ScheduleRead() noexcept { + if (stdout_pipe.IsDefined()) + stdout_pipe.ScheduleRead(); + if (stderr_pipe.IsDefined()) + stderr_pipe.ScheduleRead(); + if (tty.IsDefined()) + tty.ScheduleRead(); + } + void OnTtyReady(unsigned events) noexcept; void OnStdoutReady(unsigned events) noexcept; void OnStderrReady(unsigned events) noexcept; diff --git a/src/ssh/CConnection.cxx b/src/ssh/CConnection.cxx index e7743ba..86613b4 100644 --- a/src/ssh/CConnection.cxx +++ b/src/ssh/CConnection.cxx @@ -52,6 +52,7 @@ CConnection::CloseChannel(Channel &channel) noexcept const ChannelInit init{ .local_channel = local_channel, .peer_channel = peer_channel, + .send_window = channel.GetSendWindow(), }; channels[local_channel] = new TombstoneChannel(*this, init); @@ -101,6 +102,7 @@ CConnection::HandleChannelOpen(std::span payload) const auto channel_type = d.ReadString(); const uint_least32_t peer_channel = d.ReadU32(); + const uint_least32_t initial_window_size = d.ReadU32(); const uint_least32_t local_channel = AllocateChannelIndex(); if (local_channel >= channels.size()) { @@ -113,6 +115,7 @@ CConnection::HandleChannelOpen(std::span payload) const ChannelInit init{ .local_channel = local_channel, .peer_channel = peer_channel, + .send_window = initial_window_size, }; auto channel = OpenChannel(channel_type, init); diff --git a/src/ssh/Channel.cxx b/src/ssh/Channel.cxx index 8a1e07f..1d1e1dc 100644 --- a/src/ssh/Channel.cxx +++ b/src/ssh/Channel.cxx @@ -20,21 +20,29 @@ Channel::Close() noexcept void Channel::SendData(std::span src) { + assert(src.size() < send_window); + PacketSerializer s{MessageNumber::CHANNEL_DATA}; s.WriteU32(GetPeerChannel()); s.WriteLengthEncoded(src); connection.SendPacket(std::move(s)); + + send_window -= src.size(); } void Channel::SendExtendedData(ChannelExtendedDataType data_type, std::span src) { + assert(src.size() < send_window); + PacketSerializer s{MessageNumber::CHANNEL_EXTENDED_DATA}; s.WriteU32(GetPeerChannel()); s.WriteU32(static_cast(data_type)); s.WriteLengthEncoded(src); connection.SendPacket(std::move(s)); + + send_window -= src.size(); } void @@ -77,8 +85,9 @@ Channel::SerializeOpenConfirmation([[maybe_unused]] Serializer &s) const } void -Channel::OnWindowAdjust([[maybe_unused]] std::size_t nbytes) +Channel::OnWindowAdjust(std::size_t nbytes) { + send_window += nbytes; } void diff --git a/src/ssh/Channel.hxx b/src/ssh/Channel.hxx index 3997f27..0d41e89 100644 --- a/src/ssh/Channel.hxx +++ b/src/ssh/Channel.hxx @@ -21,6 +21,8 @@ class Serializer; */ struct ChannelInit { uint_least32_t local_channel, peer_channel; + + std::size_t send_window; }; class Channel { @@ -28,11 +30,18 @@ class Channel { const uint_least32_t local_channel, peer_channel; + /** + * How much data are we allowed to send? If this reaches + * zero, then we need to wait for CHANNEL_WINDOW_ADJUST. + */ + std::size_t send_window; + public: Channel(CConnection &_connection, ChannelInit init) noexcept :connection(_connection), local_channel(init.local_channel), - peer_channel(init.peer_channel) {} + peer_channel(init.peer_channel), + send_window(init.send_window) {} virtual ~Channel() noexcept = default; @@ -48,6 +57,10 @@ public: return peer_channel; } + std::size_t GetSendWindow() const noexcept { + return send_window; + } + void Close() noexcept; void SendData(std::span src);