diff --git a/src/Connection.cxx b/src/Connection.cxx index a427960..15bd695 100644 --- a/src/Connection.cxx +++ b/src/Connection.cxx @@ -12,10 +12,30 @@ #include "ssh/Channel.hxx" #include "net/UniqueSocketDescriptor.hxx" +#ifdef ENABLE_TRANSLATION +#include "translation/LoginGlue.hxx" +#include "translation/Response.hxx" +#include "AllocatorPtr.hxx" +#endif // ENABLE_TRANSLATION + #include using std::string_view_literals::operator""sv; +#ifdef ENABLE_TRANSLATION + +struct Connection::Translation { + Allocator alloc; + TranslateResponse response; + + Translation(Allocator &&_alloc, + TranslateResponse &&_response) noexcept + :alloc(std::move(_alloc)), + response(std::move(_response)) {} +}; + +#endif // ENABLE_TRANSLATION + Connection::Connection(Instance &_instance, Listener &_listener, UniqueSocketDescriptor _fd, const KeyList &_host_keys) @@ -28,6 +48,18 @@ Connection::Connection(Instance &_instance, Listener &_listener, Connection::~Connection() noexcept = default; +#ifdef ENABLE_TRANSLATION + +const TranslateResponse * +Connection::GetTranslationResponse() const noexcept +{ + return translation + ? &translation->response + : nullptr; +} + +#endif + std::unique_ptr Connection::OpenChannel(std::string_view channel_type, SSH::ChannelInit init) @@ -38,10 +70,6 @@ Connection::OpenChannel(std::string_view channel_type, if (channel_type == "session"sv) { CConnection &connection = *this; return std::make_unique(instance.GetSpawnService(), -#ifdef ENABLE_TRANSLATION - instance.GetTranslationServer(), - listener.GetTag(), -#endif connection, init); } else return SSH::CConnection::OpenChannel(channel_type, init); @@ -69,6 +97,21 @@ Connection::HandleUserauthRequest(std::span payload) const auto new_username = d.ReadString(); fmt::print(stderr, "Userauth '{}'\n", new_username); +#ifdef ENABLE_TRANSLATION + if (const char *translation_server = instance.GetTranslationServer()) { + Allocator alloc; + auto response = TranslateLogin(alloc, translation_server, + "ssh"sv, listener.GetTag(), + new_username, {}); + + if (response.status != HttpStatus{}) + throw std::runtime_error{"Translation server failed"}; + + translation = std::make_unique(std::move(alloc), + std::move(response)); + } +#endif // ENABLE_TRANSLATION + username.assign(new_username); SendPacket(SSH::PacketSerializer{SSH::MessageNumber::USERAUTH_SUCCESS}); diff --git a/src/Connection.hxx b/src/Connection.hxx index d64f8f5..69cbf6f 100644 --- a/src/Connection.hxx +++ b/src/Connection.hxx @@ -6,7 +6,11 @@ #include "ssh/CConnection.hxx" #include "util/IntrusiveList.hxx" +#include "config.h" +#include + +struct TranslateResponse; class Instance; class Listener; class RootLogger; @@ -25,6 +29,11 @@ class Connection final std::string username; +#ifdef ENABLE_TRANSLATION + struct Translation; + std::unique_ptr translation; +#endif // ENABLE_TRANSLATION + public: Connection(Instance &_instance, Listener &_listener, UniqueSocketDescriptor fd, @@ -39,6 +48,11 @@ public: return username; } +#ifdef ENABLE_TRANSLATION + [[gnu::pure]] + const TranslateResponse *GetTranslationResponse() const noexcept; +#endif + protected: void Destroy() noexcept override { delete this; diff --git a/src/SessionChannel.cxx b/src/SessionChannel.cxx index 9ca9708..18c600c 100644 --- a/src/SessionChannel.cxx +++ b/src/SessionChannel.cxx @@ -15,7 +15,6 @@ #include "AllocatorPtr.hxx" #ifdef ENABLE_TRANSLATION -#include "translation/LoginGlue.hxx" #include "translation/Response.hxx" #endif // ENABLE_TRANSLATION @@ -30,18 +29,10 @@ using std::string_view_literals::operator""sv; SessionChannel::SessionChannel(SpawnService &_spawn_service, -#ifdef ENABLE_TRANSLATION - const char *_translation_server, - std::string_view _listener_tag, -#endif SSH::CConnection &_connection, SSH::ChannelInit init) noexcept :SSH::Channel(_connection, init, RECEIVE_WINDOW), spawn_service(_spawn_service), -#ifdef ENABLE_TRANSLATION - translation_server(_translation_server), - listener_tag(_listener_tag), -#endif stdout_pipe(_connection.GetEventLoop(), BIND_THIS_METHOD(OnStdoutReady)), stderr_pipe(_connection.GetEventLoop(), BIND_THIS_METHOD(OnStderrReady)), tty(_connection.GetEventLoop(), BIND_THIS_METHOD(OnTtyReady)) @@ -127,18 +118,11 @@ SessionChannel::Exec(const char *cmd) const char *shell = cmd != nullptr ? "/bin/sh" : "/bin/bash"; #ifdef ENABLE_TRANSLATION - if (translation_server != nullptr) { - auto response = TranslateLogin(alloc, translation_server, - "ssh"sv, listener_tag, - username, {}); - - if (response.status != HttpStatus{}) - throw std::runtime_error{"Translation server failed"}; - - response.child_options.CopyTo(p); + if (const auto *tr = c.GetTranslationResponse()) { + tr->child_options.CopyTo(p); - if (response.shell != nullptr) - shell = response.shell; + if (tr->shell != nullptr) + shell = tr->shell; } else { #endif // ENABLE_TRANSLATION // TODO diff --git a/src/SessionChannel.hxx b/src/SessionChannel.hxx index aad45d1..fc40600 100644 --- a/src/SessionChannel.hxx +++ b/src/SessionChannel.hxx @@ -8,7 +8,6 @@ #include "spawn/ExitListener.hxx" #include "event/PipeEvent.hxx" #include "io/UniqueFileDescriptor.hxx" -#include "config.h" #include #include @@ -22,11 +21,6 @@ class SessionChannel final : public SSH::Channel, ExitListener SpawnService &spawn_service; -#ifdef ENABLE_TRANSLATION - const char *const translation_server; - const std::string_view listener_tag; -#endif - std::unique_ptr child; UniqueFileDescriptor stdin_pipe, slave_tty; @@ -41,10 +35,6 @@ class SessionChannel final : public SSH::Channel, ExitListener public: SessionChannel(SpawnService &_spawn_service, -#ifdef ENABLE_TRANSLATION - const char *_translation_server, - std::string_view _listener_tag, -#endif SSH::CConnection &_connection, SSH::ChannelInit init) noexcept;