Skip to content

Commit

Permalink
ssh/Channel: add send window counter
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxKellermann committed Oct 23, 2023
1 parent 76b6db8 commit 474ffb7
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 12 deletions.
58 changes: 48 additions & 10 deletions src/SessionChannel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ SessionChannel::SetEnv(std::string_view name, std::string_view value) noexcept
env.emplace_front(fmt::format("{}={}", name, value));
}

void
SessionChannel::OnWindowAdjust(std::size_t nbytes)
{
if (GetSendWindow() == 0)
/* re-schedule all read events, because we are now
allowed to send data again */
ScheduleRead();

Channel::OnWindowAdjust(nbytes);
}

void
SessionChannel::OnData(std::span<const std::byte> payload)
{
Expand Down Expand Up @@ -130,8 +141,6 @@ SessionChannel::Exec(const char *cmd)
if (tty.IsDefined()) {
p.stdin_fd = p.stdout_fd = p.stderr_fd = slave_tty.Release();
p.tty = true;

tty.ScheduleRead();
} else {
UniqueFileDescriptor stdin_r, stdout_r, stdout_w, stderr_r, stderr_w;
if (!UniqueFileDescriptor::CreatePipe(stdin_r, stdin_pipe) ||
Expand All @@ -144,11 +153,12 @@ SessionChannel::Exec(const char *cmd)
p.SetStderr(std::move(stderr_w));

stdout_pipe.Open(stdout_r.Release());
stdout_pipe.ScheduleRead();
stderr_pipe.Open(stderr_r.Release());
stderr_pipe.ScheduleRead();
}

if (GetSendWindow() > 0)
ScheduleRead();

if (const char *mount_home = p.ns.mount.GetMountHome()) {
p.SetEnv("HOME", mount_home);
p.chdir = mount_home;
Expand Down Expand Up @@ -218,9 +228,19 @@ void
SessionChannel::OnTtyReady([[maybe_unused]] unsigned events) noexcept
{
std::byte buffer[4096];
auto nbytes = tty.GetFileDescriptor().Read(buffer);
std::span<std::byte> dest{buffer};

if (GetSendWindow() < dest.size()) {
dest = dest.first(GetSendWindow());
assert(!dest.empty());
}

auto nbytes = tty.GetFileDescriptor().Read(dest);
if (nbytes > 0) {
SendData(std::span{buffer}.first(nbytes));
SendData(dest.first(nbytes));

if (GetSendWindow() == 0)
CancelRead();
} else {
tty.Close();
SendEof();
Expand All @@ -232,9 +252,18 @@ void
SessionChannel::OnStdoutReady([[maybe_unused]] unsigned events) noexcept
{
std::byte buffer[4096];
auto nbytes = stdout_pipe.GetFileDescriptor().Read(buffer);
std::span<std::byte> dest{buffer};
if (GetSendWindow() < dest.size()) {
dest = dest.first(GetSendWindow());
assert(!dest.empty());
}

auto nbytes = stdout_pipe.GetFileDescriptor().Read(dest);
if (nbytes > 0) {
SendData(std::span{buffer}.first(nbytes));
SendData(dest.first(nbytes));

if (GetSendWindow() == 0)
CancelRead();
} else {
stdout_pipe.Close();
SendEof();
Expand All @@ -246,9 +275,18 @@ void
SessionChannel::OnStderrReady([[maybe_unused]] unsigned events) noexcept
{
std::byte buffer[4096];
auto nbytes = stderr_pipe.GetFileDescriptor().Read(buffer);
std::span<std::byte> dest{buffer};
if (GetSendWindow() < dest.size()) {
dest = dest.first(GetSendWindow());
assert(!dest.empty());
}

auto nbytes = stderr_pipe.GetFileDescriptor().Read(dest);
if (nbytes > 0) {
SendStderr(std::span{buffer}.first(nbytes));
SendStderr(dest.first(nbytes));

if (GetSendWindow() == 0)
CancelRead();
} else {
stderr_pipe.Close();
CloseIfInactive();
Expand Down
16 changes: 16 additions & 0 deletions src/SessionChannel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public:
~SessionChannel() noexcept override;

/* virtual methods from class SSH::Channel */
void OnWindowAdjust(std::size_t nbytes) override;
void OnData(std::span<const std::byte> payload) override;
void OnEof() override;
bool OnRequest(std::string_view request_type,
Expand All @@ -68,6 +69,21 @@ private:

void Exec(const char *cmd);

void CancelRead() noexcept {
stdout_pipe.CancelRead();
stderr_pipe.CancelRead();
tty.CancelRead();
}

void ScheduleRead() noexcept {
if (stdout_pipe.IsDefined())
stdout_pipe.ScheduleRead();
if (stderr_pipe.IsDefined())
stderr_pipe.ScheduleRead();
if (tty.IsDefined())
tty.ScheduleRead();
}

void OnTtyReady(unsigned events) noexcept;
void OnStdoutReady(unsigned events) noexcept;
void OnStderrReady(unsigned events) noexcept;
Expand Down
3 changes: 3 additions & 0 deletions src/ssh/CConnection.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ CConnection::CloseChannel(Channel &channel) noexcept
const ChannelInit init{
.local_channel = local_channel,
.peer_channel = peer_channel,
.send_window = channel.GetSendWindow(),
};

channels[local_channel] = new TombstoneChannel(*this, init);
Expand Down Expand Up @@ -101,6 +102,7 @@ CConnection::HandleChannelOpen(std::span<const std::byte> payload)

const auto channel_type = d.ReadString();
const uint_least32_t peer_channel = d.ReadU32();
const uint_least32_t initial_window_size = d.ReadU32();

const uint_least32_t local_channel = AllocateChannelIndex();
if (local_channel >= channels.size()) {
Expand All @@ -113,6 +115,7 @@ CConnection::HandleChannelOpen(std::span<const std::byte> payload)
const ChannelInit init{
.local_channel = local_channel,
.peer_channel = peer_channel,
.send_window = initial_window_size,
};

auto channel = OpenChannel(channel_type, init);
Expand Down
11 changes: 10 additions & 1 deletion src/ssh/Channel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,29 @@ Channel::Close() noexcept
void
Channel::SendData(std::span<const std::byte> src)
{
assert(src.size() < send_window);

PacketSerializer s{MessageNumber::CHANNEL_DATA};
s.WriteU32(GetPeerChannel());
s.WriteLengthEncoded(src);
connection.SendPacket(std::move(s));

send_window -= src.size();
}

void
Channel::SendExtendedData(ChannelExtendedDataType data_type,
std::span<const std::byte> src)
{
assert(src.size() < send_window);

PacketSerializer s{MessageNumber::CHANNEL_EXTENDED_DATA};
s.WriteU32(GetPeerChannel());
s.WriteU32(static_cast<uint_least32_t>(data_type));
s.WriteLengthEncoded(src);
connection.SendPacket(std::move(s));

send_window -= src.size();
}

void
Expand Down Expand Up @@ -77,8 +85,9 @@ Channel::SerializeOpenConfirmation([[maybe_unused]] Serializer &s) const
}

void
Channel::OnWindowAdjust([[maybe_unused]] std::size_t nbytes)
Channel::OnWindowAdjust(std::size_t nbytes)
{
send_window += nbytes;
}

void
Expand Down
15 changes: 14 additions & 1 deletion src/ssh/Channel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,27 @@ class Serializer;
*/
struct ChannelInit {
uint_least32_t local_channel, peer_channel;

std::size_t send_window;
};

class Channel {
CConnection &connection;

const uint_least32_t local_channel, peer_channel;

/**
* How much data are we allowed to send? If this reaches
* zero, then we need to wait for CHANNEL_WINDOW_ADJUST.
*/
std::size_t send_window;

public:
Channel(CConnection &_connection, ChannelInit init) noexcept
:connection(_connection),
local_channel(init.local_channel),
peer_channel(init.peer_channel) {}
peer_channel(init.peer_channel),
send_window(init.send_window) {}

virtual ~Channel() noexcept = default;

Expand All @@ -48,6 +57,10 @@ public:
return peer_channel;
}

std::size_t GetSendWindow() const noexcept {
return send_window;
}

void Close() noexcept;

void SendData(std::span<const std::byte> src);
Expand Down

0 comments on commit 474ffb7

Please sign in to comment.