Skip to content

Commit

Permalink
SessionChannel: move translation server request to Connection::Handle…
Browse files Browse the repository at this point in the history
…UserauthRequest()

Prepare for authentication via translation server.
  • Loading branch information
MaxKellermann committed Oct 24, 2023
1 parent 9557241 commit cdcc7e2
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 34 deletions.
51 changes: 47 additions & 4 deletions src/Connection.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,30 @@
#include "ssh/Channel.hxx"
#include "net/UniqueSocketDescriptor.hxx"

#ifdef ENABLE_TRANSLATION
#include "translation/LoginGlue.hxx"
#include "translation/Response.hxx"
#include "AllocatorPtr.hxx"
#endif // ENABLE_TRANSLATION

#include <fmt/core.h>

using std::string_view_literals::operator""sv;

#ifdef ENABLE_TRANSLATION

struct Connection::Translation {
Allocator alloc;
TranslateResponse response;

Translation(Allocator &&_alloc,
TranslateResponse &&_response) noexcept
:alloc(std::move(_alloc)),
response(std::move(_response)) {}
};

#endif // ENABLE_TRANSLATION

Connection::Connection(Instance &_instance, Listener &_listener,
UniqueSocketDescriptor _fd,
const KeyList &_host_keys)
Expand All @@ -28,6 +48,18 @@ Connection::Connection(Instance &_instance, Listener &_listener,

Connection::~Connection() noexcept = default;

#ifdef ENABLE_TRANSLATION

const TranslateResponse *
Connection::GetTranslationResponse() const noexcept
{
return translation
? &translation->response
: nullptr;
}

#endif

std::unique_ptr<SSH::Channel>
Connection::OpenChannel(std::string_view channel_type,
SSH::ChannelInit init)
Expand All @@ -38,10 +70,6 @@ Connection::OpenChannel(std::string_view channel_type,
if (channel_type == "session"sv) {
CConnection &connection = *this;
return std::make_unique<SessionChannel>(instance.GetSpawnService(),
#ifdef ENABLE_TRANSLATION
instance.GetTranslationServer(),
listener.GetTag(),
#endif
connection, init);
} else
return SSH::CConnection::OpenChannel(channel_type, init);
Expand Down Expand Up @@ -69,6 +97,21 @@ Connection::HandleUserauthRequest(std::span<const std::byte> payload)
const auto new_username = d.ReadString();
fmt::print(stderr, "Userauth '{}'\n", new_username);

#ifdef ENABLE_TRANSLATION
if (const char *translation_server = instance.GetTranslationServer()) {
Allocator alloc;
auto response = TranslateLogin(alloc, translation_server,
"ssh"sv, listener.GetTag(),
new_username, {});

if (response.status != HttpStatus{})
throw std::runtime_error{"Translation server failed"};

translation = std::make_unique<Translation>(std::move(alloc),
std::move(response));
}
#endif // ENABLE_TRANSLATION

username.assign(new_username);

SendPacket(SSH::PacketSerializer{SSH::MessageNumber::USERAUTH_SUCCESS});
Expand Down
14 changes: 14 additions & 0 deletions src/Connection.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

#include "ssh/CConnection.hxx"
#include "util/IntrusiveList.hxx"
#include "config.h"

#include <memory>

struct TranslateResponse;
class Instance;
class Listener;
class RootLogger;
Expand All @@ -25,6 +29,11 @@ class Connection final

std::string username;

#ifdef ENABLE_TRANSLATION
struct Translation;
std::unique_ptr<Translation> translation;
#endif // ENABLE_TRANSLATION

public:
Connection(Instance &_instance, Listener &_listener,
UniqueSocketDescriptor fd,
Expand All @@ -39,6 +48,11 @@ public:
return username;
}

#ifdef ENABLE_TRANSLATION
[[gnu::pure]]
const TranslateResponse *GetTranslationResponse() const noexcept;
#endif

protected:
void Destroy() noexcept override {
delete this;
Expand Down
24 changes: 4 additions & 20 deletions src/SessionChannel.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "AllocatorPtr.hxx"

#ifdef ENABLE_TRANSLATION
#include "translation/LoginGlue.hxx"
#include "translation/Response.hxx"
#endif // ENABLE_TRANSLATION

Expand All @@ -30,18 +29,10 @@
using std::string_view_literals::operator""sv;

SessionChannel::SessionChannel(SpawnService &_spawn_service,
#ifdef ENABLE_TRANSLATION
const char *_translation_server,
std::string_view _listener_tag,
#endif
SSH::CConnection &_connection,
SSH::ChannelInit init) noexcept
:SSH::Channel(_connection, init, RECEIVE_WINDOW),
spawn_service(_spawn_service),
#ifdef ENABLE_TRANSLATION
translation_server(_translation_server),
listener_tag(_listener_tag),
#endif
stdout_pipe(_connection.GetEventLoop(), BIND_THIS_METHOD(OnStdoutReady)),
stderr_pipe(_connection.GetEventLoop(), BIND_THIS_METHOD(OnStderrReady)),
tty(_connection.GetEventLoop(), BIND_THIS_METHOD(OnTtyReady))
Expand Down Expand Up @@ -127,18 +118,11 @@ SessionChannel::Exec(const char *cmd)
const char *shell = cmd != nullptr ? "/bin/sh" : "/bin/bash";

#ifdef ENABLE_TRANSLATION
if (translation_server != nullptr) {
auto response = TranslateLogin(alloc, translation_server,
"ssh"sv, listener_tag,
username, {});

if (response.status != HttpStatus{})
throw std::runtime_error{"Translation server failed"};

response.child_options.CopyTo(p);
if (const auto *tr = c.GetTranslationResponse()) {
tr->child_options.CopyTo(p);

if (response.shell != nullptr)
shell = response.shell;
if (tr->shell != nullptr)
shell = tr->shell;
} else {
#endif // ENABLE_TRANSLATION
// TODO
Expand Down
10 changes: 0 additions & 10 deletions src/SessionChannel.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include "spawn/ExitListener.hxx"
#include "event/PipeEvent.hxx"
#include "io/UniqueFileDescriptor.hxx"
#include "config.h"

#include <forward_list>
#include <memory>
Expand All @@ -22,11 +21,6 @@ class SessionChannel final : public SSH::Channel, ExitListener

SpawnService &spawn_service;

#ifdef ENABLE_TRANSLATION
const char *const translation_server;
const std::string_view listener_tag;
#endif

std::unique_ptr<ChildProcessHandle> child;

UniqueFileDescriptor stdin_pipe, slave_tty;
Expand All @@ -41,10 +35,6 @@ class SessionChannel final : public SSH::Channel, ExitListener

public:
SessionChannel(SpawnService &_spawn_service,
#ifdef ENABLE_TRANSLATION
const char *_translation_server,
std::string_view _listener_tag,
#endif
SSH::CConnection &_connection,
SSH::ChannelInit init) noexcept;

Expand Down

0 comments on commit cdcc7e2

Please sign in to comment.