diff --git a/core/client.cpp b/core/client.cpp index 9637c3a..0f7f6af 100644 --- a/core/client.cpp +++ b/core/client.cpp @@ -7,6 +7,7 @@ #include "except/exceptions.hpp" #include "utils/logging.hpp" #include "syscalls/poll.h" +#include "globals.hpp" using namespace cerb; @@ -17,6 +18,7 @@ Client::Client(int fd, Proxy* p) : ProxyConnection(fd) , _proxy(p) , _awaiting_count(0) + , _auth(false) { p->poll_add_ro(this); } @@ -193,3 +195,13 @@ void Client::push_command(util::sptr g) { this->_parsed_groups.push_back(std::move(g)); } + +bool Client::is_client_auth() +{ + return this->_auth || !cerb_global::need_auth(); +} + +void Client::set_client_auth(bool ok) +{ + this->_auth = ok; +} diff --git a/core/client.hpp b/core/client.hpp index 89b5bc8..2bed965 100644 --- a/core/client.hpp +++ b/core/client.hpp @@ -25,6 +25,7 @@ namespace cerb { int _awaiting_count; Buffer _buffer; BufferSet _output_buffer_set; + bool _auth; void _process(); void _send_buffer_set(); @@ -41,6 +42,8 @@ namespace cerb { void add_peer(Server* svr); void reactivate(util::sref cmd); void push_command(util::sptr g); + bool is_client_auth(); + void set_client_auth(bool ok); }; } diff --git a/core/command.cpp b/core/command.cpp index 93daba5..6fae9fd 100644 --- a/core/command.cpp +++ b/core/command.cpp @@ -20,16 +20,22 @@ using namespace cerb; namespace { std::string const RSP_OK_STR("+OK\r\n"); + std::string const NOAUTH_RSP("-NOAUTH Authentication required.\r\n"); std::shared_ptr const RSP_OK(new Buffer(RSP_OK_STR)); Server* select_server_for(Proxy* proxy, DataCommand* cmd, slot key_slot) { + if (!cmd->group->client->is_client_auth()) { + return nullptr; + } + Server* svr = proxy->get_server_by_slot(key_slot); if (svr == nullptr) { LOG(DEBUG) << "Cluster slot not covered " << key_slot; proxy->retry_move_ask_command_later(util::mkref(*cmd)); return nullptr; } + svr->push_client_command(util::mkref(*cmd)); return svr; } @@ -329,6 +335,10 @@ namespace { util::sptr spawn_commands(util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->msg.empty()) { return util::mkptr(new DirectCommandGroup(c, "+PONG\r\n")); } @@ -342,6 +352,43 @@ namespace { } }; + class AuthCommandParser + : public SpecialCommandParser + { + std::string msg; + public: + AuthCommandParser() = default; + + util::sptr spawn_commands(util::sref c, Buffer::iterator) + { + // 先检查是否设置 proxy auth + if (!cerb_global::need_auth()) { + return util::mkptr(new DirectCommandGroup(c, "-ERR Client sent AUTH, but no password is set\r\n")); + } + + // passwd 为空 + if (this->msg.empty()) { + return util::mkptr(new DirectCommandGroup(c, "-ERR wrong number of arguments for 'auth' command\r\n")); + } + + // passwd 与 proxy 设置的不符 + if (!cerb_global::is_auth_ok(msg)) { + // 设置 client + c->set_client_auth(false); + return util::mkptr(new DirectCommandGroup(c, "-ERR invalid password\r\n")); + } + + // 设置 client + c->set_client_auth(true); + return util::mkptr(new DirectCommandGroup(c, "+OK\r\n")); + } + + void on_str(Buffer::iterator begin, Buffer::iterator end) + { + this->msg = std::string(begin, end); + } + }; + class ProxyStatsCommandParser : public SpecialCommandParser { @@ -351,6 +398,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + return util::mkptr(new DirectCommandGroup(c, stats_string())); } @@ -365,6 +416,10 @@ namespace { util::sptr spawn_commands(util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + ::notify_each_thread_update_slot_map(); return util::mkptr(new DirectCommandGroup(c, RSP_OK_STR)); } @@ -387,6 +442,10 @@ namespace { util::sptr spawn_commands(util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->bad) { return util::mkptr(new DirectCommandGroup( c, "-ERR invalid port number\r\n")); @@ -448,6 +507,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (keys_slots.empty()) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for '" + this->command_name + "' command\r\n")); @@ -566,6 +629,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (keys_slots.empty() || !current_is_key) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for 'mset' command\r\n")); @@ -680,6 +747,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (slot_index != 2 || this->bad) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for 'rename' command\r\n")); @@ -741,6 +812,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator end) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->no_arg) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for 'subscribe' command\r\n")); @@ -796,6 +871,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator end) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->args_count != 2) { return util::mkptr(new DirectCommandGroup( c, "-ERR BLPOP/BRPOP takes exactly 2 arguments KEY TIMEOUT in proxy\r\n")); @@ -840,6 +919,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator end) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->arg_count < 3 || this->key_count != 1) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for 'eval' command\r\n")); @@ -868,6 +951,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator end) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->arg_count != 2) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong number of arguments for 'publish' command\r\n")); @@ -901,6 +988,10 @@ namespace { util::sptr spawn_commands( util::sref c, Buffer::iterator end) { + if (!c->is_client_auth()) { + return util::mkptr(new DirectCommandGroup(c, NOAUTH_RSP)); + } + if (this->_arg_count != 2 || this->_slot >= CLUSTER_SLOT_COUNT) { return util::mkptr(new DirectCommandGroup( c, "-ERR wrong arguments for 'keysinslot' command\r\n")); @@ -920,6 +1011,11 @@ namespace { { return util::mkptr(new PingCommandParser); }}, + {"AUTH", + [](Buffer::iterator, Buffer::iterator) -> CmdPtr + { + return util::mkptr(new AuthCommandParser); + }}, {"INFO", [](Buffer::iterator, Buffer::iterator) -> CmdPtr { @@ -1059,16 +1155,31 @@ namespace { void on_split_point(Iterator i) { this->_on_str = ClientCommandSplitter::on_command_head; + if (this->last_command_is_bad) { this->client->push_command(util::mkptr(new DirectCommandGroup( client, "-ERR Unknown command or command key not specified\r\n"))); - } else if (this->special_parser.nul()) { - this->client->push_command(util::mkptr(new SingleCommandGroup( - client, Buffer(this->last_command_begin, i), this->slot_calc.get_slot()))); - } else { - this->client->push_command(this->special_parser->spawn_commands(this->client, i)); - this->special_parser.reset(); } + + if (! this->last_command_is_bad && this->client->is_client_auth()) { + if (this->special_parser.nul()) { + this->client->push_command(util::mkptr(new SingleCommandGroup( + client, Buffer(this->last_command_begin, i), this->slot_calc.get_slot()))); + } else { + this->client->push_command(this->special_parser->spawn_commands(this->client, i)); + this->special_parser.reset(); + } + } + + if (! this->last_command_is_bad && ! this->client->is_client_auth()) { + if (this->special_parser.nul()) { + this->client->push_command(util::mkptr(new DirectCommandGroup(this->client, NOAUTH_RSP))); + } else { + this->client->push_command(this->special_parser->spawn_commands(this->client, i)); + this->special_parser.reset(); + } + } + this->last_command_begin = i; this->slot_calc.reset(); this->last_command_is_bad = false; diff --git a/core/globals.cpp b/core/globals.cpp index 7f895f4..ffded96 100644 --- a/core/globals.cpp +++ b/core/globals.cpp @@ -46,3 +46,20 @@ bool cerb_global::cluster_ok() { return ::cluster_ok; } + +static std::string auth_pass(""); + +void cerb_global::set_auth_pass(std::string const& pass) +{ + ::auth_pass = pass; +} + +bool cerb_global::need_auth() +{ + return ::auth_pass != ""; +} + +bool cerb_global::is_auth_ok(std::string const& pass) +{ + return ::auth_pass == pass; +} diff --git a/core/globals.hpp b/core/globals.hpp index f4593e6..04ac60d 100644 --- a/core/globals.hpp +++ b/core/globals.hpp @@ -26,6 +26,9 @@ namespace cerb_global { void set_cluster_ok(bool ok); bool cluster_ok(); + void set_auth_pass(std::string const& pass); + bool need_auth(); + bool is_auth_ok(std::string const& pass); } #endif /* __CERBERUS_GLOBALS_HPP__ */ diff --git a/main.cpp b/main.cpp index c4e4f10..a5a4cbe 100644 --- a/main.cpp +++ b/main.cpp @@ -117,6 +117,11 @@ namespace { cerb_global::set_cluster_req_full_cov(false); } + if (config.get("auth", "") != "") { + LOG(INFO) << "Proxy set need auth"; + cerb_global::set_auth_pass(config.get("auth")); + } + int slow_poll_ms = util::atoi(config.get("slow-poll-elapse-ms", "50")); if (slow_poll_ms <= 0) { LOG(ERROR) << "Invalid slow poll elapse"; diff --git a/utils/address.cpp b/utils/address.cpp index 0c40cda..1a1ebf3 100644 --- a/utils/address.cpp +++ b/utils/address.cpp @@ -11,7 +11,9 @@ Address Address::from_host_port(std::string const& addr) if (host_port.size() != 2) { throw std::runtime_error("Invalid address: " + addr); } - return Address(host_port[0], util::atoi(host_port[1].data())); + + std::vector ports(util::split_str(host_port[1], "@")); + return Address(host_port[0], util::atoi(ports[0].data())); } std::set Address::from_hosts_ports(std::string const& addrs)