diff --git a/src/connection.cpp b/src/connection.cpp index cf6fa26527..b5ea5e2499 100644 --- a/src/connection.cpp +++ b/src/connection.cpp @@ -11,170 +11,153 @@ #include "server.h" #include "tasks.h" -Connection_ptr ConnectionManager::createConnection(boost::asio::io_context& io_context, - ConstServicePort_ptr servicePort) -{ - std::lock_guard lockClass(connectionManagerLock); +Connection::Connection(asio::io_context& ioc, std::shared_ptr service_port) : + socket_read_timer(ioc), + socket_write_timer(ioc), + service_port(std::move(service_port)), + socket(ioc), + timeConnected(time(nullptr)) +{} - auto connection = std::make_shared(io_context, servicePort); - connections.insert(connection); - return connection; -} +Connection::~Connection() { close_socket(); } -void ConnectionManager::releaseConnection(const Connection_ptr& connection) +void Connection::accept(Protocol_ptr protocol) { - std::lock_guard lockClass(connectionManagerLock); + std::lock_guard lock(connection_lock); - connections.erase(connection); -} + if (protocol) { + this->protocol = protocol; + g_dispatcher.addTask([=]() { protocol->onConnect(); }); -void ConnectionManager::closeAll() -{ - std::lock_guard lockClass(connectionManagerLock); + state = ConnectionState::GameWorldAuthentication; + } else if (state == ConnectionState::Pending) { + state = ConnectionState::RequestCharacterList; + } - for (const auto& connection : connections) { - try { - boost::system::error_code error; - connection->socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, error); - connection->socket.close(error); - } catch (boost::system::system_error&) { - } + system::error_code error; + if (auto endpoint = socket.remote_endpoint(error); !error) { + address = endpoint.address(); } - connections.clear(); -} -// Connection + try { + socket_read_timer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); + socket_read_timer.async_wait( + [thisPtr = std::weak_ptr(shared_from_this())](const system::error_code& error) { + Connection::handle_socket_timeout(thisPtr, error); + }); -Connection::Connection(boost::asio::io_context& io_context, ConstServicePort_ptr service_port) : - readTimer(io_context), - writeTimer(io_context), - service_port(std::move(service_port)), - socket(io_context), - timeConnected(time(nullptr)) -{} + // Read size of the first packet + auto bufferLength = !receivedLastChar && receivedName && state == ConnectionState::GameWorldAuthentication + ? 1 + : NetworkMessage::HEADER_LENGTH; -void Connection::close(bool force) + asio::async_read(socket, asio::buffer(msg.getBuffer(), bufferLength), + [thisPtr = shared_from_this()](const system::error_code& error, auto /*bytes_transferred*/) { + thisPtr->parse_packet_header(error); + }); + } catch (system::system_error& e) { + std::cout << "[Network error - Connection::accept] " << e.what() << std::endl; + disconnect_and_close_socket(); + } +} + +void Connection::disconnect() { // any thread - ConnectionManager::getInstance().releaseConnection(shared_from_this()); + tfs::net::disconnect(shared_from_this()); + + std::lock_guard lock(connection_lock); - std::lock_guard lockClass(connectionLock); - connectionState = CONNECTION_STATE_DISCONNECTED; + state = ConnectionState::Disconnected; if (protocol) { g_dispatcher.addTask([protocol = protocol]() { protocol->release(); }); } - if (messageQueue.empty() || force) { - closeSocket(); + if (server_messages.empty()) { + close_socket(); } else { - // will be closed by the destructor or onWriteOperation + // will be closed by the destructor or on_write_to_socket } } -void Connection::closeSocket() +void Connection::close_socket() { if (socket.is_open()) { try { - readTimer.cancel(); - writeTimer.cancel(); - boost::system::error_code error; - socket.shutdown(boost::asio::ip::tcp::socket::shutdown_both, error); + socket_read_timer.cancel(); + socket_write_timer.cancel(); + + system::error_code error; + socket.shutdown(asio::ip::tcp::socket::shutdown_both, error); socket.close(error); - } catch (boost::system::system_error& e) { - std::cout << "[Network error - Connection::closeSocket] " << e.what() << std::endl; + } catch (system::system_error& e) { + std::cout << "[Network error - Connection::close_socket] " << e.what() << std::endl; } } } -Connection::~Connection() { closeSocket(); } - -void Connection::accept(Protocol_ptr protocol) +void Connection::disconnect_and_close_socket() { - this->protocol = protocol; - g_dispatcher.addTask([=]() { protocol->onConnect(); }); - connectionState = CONNECTION_STATE_GAMEWORLD_AUTH; - accept(); + disconnect(); + close_socket(); } -void Connection::accept() +void Connection::send_message(const OutputMessage_ptr& message) { - if (connectionState == CONNECTION_STATE_PENDING) { - connectionState = CONNECTION_STATE_REQUEST_CHARLIST; - } + std::lock_guard lock(connection_lock); - std::lock_guard lockClass(connectionLock); - - boost::system::error_code error; - if (auto endpoint = socket.remote_endpoint(error); !error) { - remoteAddress = endpoint.address(); + if (state == ConnectionState::Disconnected) { + return; } - try { - readTimer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); - readTimer.async_wait( - [thisPtr = std::weak_ptr(shared_from_this())](const boost::system::error_code& error) { - Connection::handleTimeout(thisPtr, error); - }); - - // Read size of the first packet - auto bufferLength = !receivedLastChar && receivedName && connectionState == CONNECTION_STATE_GAMEWORLD_AUTH - ? 1 - : NetworkMessage::HEADER_LENGTH; - boost::asio::async_read( - socket, boost::asio::buffer(msg.getBuffer(), bufferLength), - [thisPtr = shared_from_this()](const boost::system::error_code& error, auto /*bytes_transferred*/) { - thisPtr->parseHeader(error); - }); - } catch (boost::system::system_error& e) { - std::cout << "[Network error - Connection::accept] " << e.what() << std::endl; - close(FORCE_CLOSE); + bool noPendingWrite = server_messages.empty(); + server_messages.emplace_back(message); + if (noPendingWrite) { + send_message_to_socket(message); } } -void Connection::parseHeader(const boost::system::error_code& error) +void Connection::parse_packet_header(const system::error_code& error_on_read) { - std::lock_guard lockClass(connectionLock); - readTimer.cancel(); + std::lock_guard lock(connection_lock); - if (error) { - close(FORCE_CLOSE); + socket_read_timer.cancel(); + + if (error_on_read) { + disconnect_and_close_socket(); return; - } else if (connectionState == CONNECTION_STATE_DISCONNECTED) { + } + + if (state == ConnectionState::Disconnected) { return; } uint32_t timePassed = std::max(1, (time(nullptr) - timeConnected) + 1); if ((++packetsSent / timePassed) > static_cast(getNumber(ConfigManager::MAX_PACKETS_PER_SECOND))) { - std::cout << getIP() << " disconnected for exceeding packet per second limit." << std::endl; - close(); + std::cout << socket_address() << " disconnected for exceeding packet per second limit." << std::endl; + disconnect(); return; } - if (!receivedLastChar && connectionState == CONNECTION_STATE_GAMEWORLD_AUTH) { + if (!receivedLastChar && state == ConnectionState::GameWorldAuthentication) { uint8_t* msgBuffer = msg.getBuffer(); if (!receivedName && msgBuffer[1] == 0x00) { receivedLastChar = true; - } else { - if (!receivedName) { - receivedName = true; - - accept(); - return; - } - - if (msgBuffer[0] == 0x0A) { - receivedLastChar = true; - } - + } else if (!receivedName) { + receivedName = true; + accept(); + return; + } else if (msgBuffer[0] == 0x0A) { + receivedLastChar = true; accept(); return; } } - if (receivedLastChar && connectionState == CONNECTION_STATE_GAMEWORLD_AUTH) { - connectionState = CONNECTION_STATE_GAME; + if (receivedLastChar && state == ConnectionState::GameWorldAuthentication) { + state = ConnectionState::Game; } if (timePassed > 2) { @@ -184,39 +167,42 @@ void Connection::parseHeader(const boost::system::error_code& error) uint16_t size = msg.getLengthHeader(); if (size == 0 || size >= NETWORKMESSAGE_MAXSIZE - 16) { - close(FORCE_CLOSE); + disconnect_and_close_socket(); return; } try { - readTimer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); - readTimer.async_wait( - [thisPtr = std::weak_ptr(shared_from_this())](const boost::system::error_code& error) { - Connection::handleTimeout(thisPtr, error); + socket_read_timer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); + socket_read_timer.async_wait( + [thisPtr = std::weak_ptr(shared_from_this())](const system::error_code& error) { + Connection::handle_socket_timeout(thisPtr, error); }); // Read packet content msg.setLength(size + NetworkMessage::HEADER_LENGTH); - boost::asio::async_read( - socket, boost::asio::buffer(msg.getBodyBuffer(), size), - [thisPtr = shared_from_this()](const boost::system::error_code& error, auto /*bytes_transferred*/) { - thisPtr->parsePacket(error); - }); - } catch (boost::system::system_error& e) { + + asio::async_read(socket, asio::buffer(msg.getBodyBuffer(), size), + [thisPtr = shared_from_this()](const system::error_code& error, auto /*bytes_transferred*/) { + thisPtr->parse_packet_body(error); + }); + } catch (system::system_error& e) { std::cout << "[Network error - Connection::parseHeader] " << e.what() << std::endl; - close(FORCE_CLOSE); + disconnect_and_close_socket(); } } -void Connection::parsePacket(const boost::system::error_code& error) +void Connection::parse_packet_body(const system::error_code& error) { - std::lock_guard lockClass(connectionLock); - readTimer.cancel(); + std::lock_guard lock(connection_lock); + + socket_read_timer.cancel(); if (error) { - close(FORCE_CLOSE); + disconnect_and_close_socket(); return; - } else if (connectionState == CONNECTION_STATE_DISCONNECTED) { + } + + if (state == ConnectionState::Disconnected) { return; } @@ -236,7 +222,8 @@ void Connection::parsePacket(const boost::system::error_code& error) // Game protocol has already been created at this point protocol = service_port->make_protocol(msg, shared_from_this()); if (!protocol) { - close(FORCE_CLOSE); + disconnect(); + close_socket(); return; } } else { @@ -249,86 +236,151 @@ void Connection::parsePacket(const boost::system::error_code& error) } try { - readTimer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); - readTimer.async_wait( - [thisPtr = std::weak_ptr(shared_from_this())](const boost::system::error_code& error) { - Connection::handleTimeout(thisPtr, error); + socket_read_timer.expires_after(std::chrono::seconds(CONNECTION_READ_TIMEOUT)); + socket_read_timer.async_wait( + [thisPtr = std::weak_ptr(shared_from_this())](const system::error_code& error) { + Connection::handle_socket_timeout(thisPtr, error); }); // Wait to the next packet - boost::asio::async_read( - socket, boost::asio::buffer(msg.getBuffer(), NetworkMessage::HEADER_LENGTH), - [thisPtr = shared_from_this()](const boost::system::error_code& error, auto /*bytes_transferred*/) { - thisPtr->parseHeader(error); - }); - } catch (boost::system::system_error& e) { + asio::async_read(socket, asio::buffer(msg.getBuffer(), NetworkMessage::HEADER_LENGTH), + [thisPtr = shared_from_this()](const system::error_code& error, auto /*bytes_transferred*/) { + thisPtr->parse_packet_header(error); + }); + } catch (system::system_error& e) { std::cout << "[Network error - Connection::parsePacket] " << e.what() << std::endl; - close(FORCE_CLOSE); - } -} - -void Connection::send(const OutputMessage_ptr& msg) -{ - std::lock_guard lockClass(connectionLock); - if (connectionState == CONNECTION_STATE_DISCONNECTED) { - return; - } - - bool noPendingWrite = messageQueue.empty(); - messageQueue.emplace_back(msg); - if (noPendingWrite) { - internalSend(msg); + disconnect_and_close_socket(); } } -void Connection::internalSend(const OutputMessage_ptr& msg) +void Connection::send_message_to_socket(const OutputMessage_ptr& msg) { protocol->onSendMessage(msg); + try { - writeTimer.expires_after(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT)); - writeTimer.async_wait( - [thisPtr = std::weak_ptr(shared_from_this())](const boost::system::error_code& error) { - Connection::handleTimeout(thisPtr, error); + socket_write_timer.expires_after(std::chrono::seconds(CONNECTION_WRITE_TIMEOUT)); + socket_write_timer.async_wait( + [thisPtr = std::weak_ptr(shared_from_this())](const system::error_code& error) { + Connection::handle_socket_timeout(thisPtr, error); }); - boost::asio::async_write( - socket, boost::asio::buffer(msg->getOutputBuffer(), msg->getLength()), - [thisPtr = shared_from_this()](const boost::system::error_code& error, auto /*bytes_transferred*/) { - thisPtr->onWriteOperation(error); - }); - } catch (boost::system::system_error& e) { - std::cout << "[Network error - Connection::internalSend] " << e.what() << std::endl; - close(FORCE_CLOSE); + asio::async_write(socket, asio::buffer(msg->getOutputBuffer(), msg->getLength()), + [thisPtr = shared_from_this()](const system::error_code& error, auto /*bytes_transferred*/) { + thisPtr->on_write_to_socket(error); + }); + } catch (system::system_error& e) { + std::cout << "[Network error - Connection::send_message_to_socket] " << e.what() << std::endl; + disconnect_and_close_socket(); } } -void Connection::onWriteOperation(const boost::system::error_code& error) +void Connection::on_write_to_socket(const system::error_code& error) { - std::lock_guard lockClass(connectionLock); - writeTimer.cancel(); - messageQueue.pop_front(); + std::lock_guard lock(connection_lock); + socket_write_timer.cancel(); + + server_messages.pop_front(); if (error) { - messageQueue.clear(); - close(FORCE_CLOSE); + server_messages.clear(); + disconnect_and_close_socket(); return; } - if (!messageQueue.empty()) { - internalSend(messageQueue.front()); - } else if (connectionState == CONNECTION_STATE_DISCONNECTED) { - closeSocket(); + if (!server_messages.empty()) { + send_message_to_socket(server_messages.front()); + } else if (state == ConnectionState::Disconnected) { + close_socket(); } } -void Connection::handleTimeout(ConnectionWeak_ptr connectionWeak, const boost::system::error_code& error) +void Connection::handle_socket_timeout(ConnectionWeak_ptr connection_weak, const system::error_code& error) { - if (error == boost::asio::error::operation_aborted) { + if (error == asio::error::operation_aborted) { // The timer has been cancelled manually return; } - if (auto connection = connectionWeak.lock()) { - connection->close(FORCE_CLOSE); + if (auto connection = connection_weak.lock()) { + connection->disconnect_and_close_socket(); + } +} + +namespace { + +std::unordered_set connections; +std::mutex connections_lock; + +struct ConnectionBlock +{ + uint64_t last_attempt; + uint64_t block_time = 0; + uint32_t count = 1; +}; + +std::map connections_block; +std::recursive_mutex connections_block_lock; + +} // namespace + +Connection_ptr tfs::net::create_connection(asio::io_context ioc, std::shared_ptr service_port) +{ + std::lock_guard lock(connections_lock); + + auto connection = std::make_shared(ioc, service_port); + connections.insert(connection); + return connection; +} + +void tfs::net::disconnect(const Connection_ptr& connection) +{ + std::lock_guard lock(connections_lock); + + connections.erase(connection); +} + +void tfs::net::disconnect_all() +{ + std::lock_guard lock(connections_lock); + + for (const auto& connection : connections) { + connection->close_socket(); + } + + connections.clear(); +} + +bool tfs::net::has_connection_blocked(const Connection::SocketAddress& socket_address) +{ + std::lock_guard lock{connections_block_lock}; + + uint64_t current_time = OTSYS_TIME(); + + auto it = connections_block.find(socket_address); + if (it == connections_block.end()) { + connections_block.emplace(socket_address, ConnectionBlock{.last_attempt = current_time}); + return false; + } + + auto& connection_block = it->second; + if (connection_block.block_time > current_time) { + connection_block.block_time += 250; + return true; + } + + int64_t time_diff = current_time - connection_block.last_attempt; + connection_block.last_attempt = current_time; + + if (time_diff <= 5000) { + if (++connection_block.count > 5) { + connection_block.count = 0; + if (time_diff <= 500) { + connection_block.block_time = current_time + 3000; + return true; + } + } + } else { + connection_block.count = 1; } + return false; } diff --git a/src/connection.h b/src/connection.h index b2e076f282..6c78aff42c 100644 --- a/src/connection.h +++ b/src/connection.h @@ -6,120 +6,98 @@ #include "networkmessage.h" -enum ConnectionState_t -{ - CONNECTION_STATE_DISCONNECTED, - CONNECTION_STATE_REQUEST_CHARLIST, - CONNECTION_STATE_GAMEWORLD_AUTH, - CONNECTION_STATE_GAME, - CONNECTION_STATE_PENDING -}; - -enum checksumMode_t -{ - CHECKSUM_DISABLED, - CHECKSUM_ADLER, - CHECKSUM_SEQUENCE -}; - static constexpr int32_t CONNECTION_WRITE_TIMEOUT = 30; static constexpr int32_t CONNECTION_READ_TIMEOUT = 30; class Protocol; -using Protocol_ptr = std::shared_ptr; class OutputMessage; -using OutputMessage_ptr = std::shared_ptr; class Connection; +class ServiceBase; +class ServicePort; + +using Protocol_ptr = std::shared_ptr; +using OutputMessage_ptr = std::shared_ptr; using Connection_ptr = std::shared_ptr; using ConnectionWeak_ptr = std::weak_ptr; -class ServiceBase; using Service_ptr = std::shared_ptr; -class ServicePort; using ServicePort_ptr = std::shared_ptr; -using ConstServicePort_ptr = std::shared_ptr; -class ConnectionManager -{ -public: - static ConnectionManager& getInstance() - { - static ConnectionManager instance; - return instance; - } - - Connection_ptr createConnection(boost::asio::io_context& io_context, ConstServicePort_ptr servicePort); - void releaseConnection(const Connection_ptr& connection); - void closeAll(); - -private: - ConnectionManager() = default; - - std::unordered_set connections; - std::mutex connectionManagerLock; -}; +namespace asio = boost::asio; +namespace system = boost::system; class Connection : public std::enable_shared_from_this { public: - using Address = boost::asio::ip::address; - // non-copyable - Connection(const Connection&) = delete; - Connection& operator=(const Connection&) = delete; + using SocketAddress = boost::asio::ip::address; - enum + enum ConnectionState { - FORCE_CLOSE = true + Pending, + RequestCharacterList, + GameWorldAuthentication, + Game, + Disconnected, }; - Connection(boost::asio::io_context& io_context, ConstServicePort_ptr service_port); - ~Connection(); + enum ChecksumMode + { + Disabled, + Adler, + Sequence + }; - friend class ConnectionManager; + Connection(asio::io_context& ioc, std::shared_ptr service_port); + ~Connection(); - void close(bool force = false); - // Used by protocols that require server to send first - void accept(Protocol_ptr protocol); - void accept(); + // non-copyable + Connection(const Connection&) = delete; + Connection& operator=(const Connection&) = delete; - void send(const OutputMessage_ptr& msg); + void accept(Protocol_ptr protocol = nullptr); + void disconnect(); + void close_socket(); + void disconnect_and_close_socket(); + void send_message(const OutputMessage_ptr& msg); - const Address& getIP() const { return remoteAddress; }; + const SocketAddress& socket_address() const { return address; }; private: - void parseHeader(const boost::system::error_code& error); - void parsePacket(const boost::system::error_code& error); - - void onWriteOperation(const boost::system::error_code& error); - - static void handleTimeout(ConnectionWeak_ptr connectionWeak, const boost::system::error_code& error); + asio::ip::tcp::socket& getSocket() { return socket; } + void parse_packet_header(const system::error_code& error); + void parse_packet_body(const system::error_code& error); - void closeSocket(); - void internalSend(const OutputMessage_ptr& msg); + void send_message_to_socket(const OutputMessage_ptr& msg); + void on_write_to_socket(const system::error_code& error); + static void handle_socket_timeout(ConnectionWeak_ptr connection_weak, const system::error_code& error); - boost::asio::ip::tcp::socket& getSocket() { return socket; } + std::shared_ptr service_port; friend class ServicePort; NetworkMessage msg; - - boost::asio::steady_timer readTimer; - boost::asio::steady_timer writeTimer; - - std::recursive_mutex connectionLock; - - std::list messageQueue; - - ConstServicePort_ptr service_port; Protocol_ptr protocol; - - boost::asio::ip::tcp::socket socket; - Address remoteAddress; + std::list server_messages; + ConnectionState state = ConnectionState::Pending; time_t timeConnected; - uint32_t packetsSent = 0; - - ConnectionState_t connectionState = CONNECTION_STATE_PENDING; + uint32_t packets_sent = 0; bool receivedFirst = false; bool receivedName = false; bool receivedLastChar = false; + + asio::steady_timer socket_read_timer; + asio::steady_timer socket_write_timer; + asio::ip::tcp::socket socket; + asio::ip::address address; + + std::recursive_mutex connection_lock; }; +namespace tfs::net { + +Connection_ptr create_connection(asio::io_context ioc, std::shared_ptr service_port); +void disconnect(const Connection_ptr& connection); +void disconnect_all(); +bool has_connection_blocked(const Connection::SocketAddress& socket_address); + +} // namespace tfs::net + #endif // FS_CONNECTION_H diff --git a/src/server.cpp b/src/server.cpp index b7ce9e7b44..4bf50e0365 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -12,55 +12,12 @@ namespace { -struct ConnectBlock -{ - uint64_t lastAttempt; - uint64_t blockTime = 0; - uint32_t count = 1; -}; - -bool acceptConnection(const Connection::Address& clientIP) -{ - static std::recursive_mutex mu; - std::lock_guard lock{mu}; - - uint64_t currentTime = OTSYS_TIME(); - - static std::map ipConnectMap; - auto it = ipConnectMap.find(clientIP); - if (it == ipConnectMap.end()) { - ipConnectMap.emplace(clientIP, ConnectBlock{.lastAttempt = currentTime}); - return true; - } - - ConnectBlock& connectBlock = it->second; - if (connectBlock.blockTime > currentTime) { - connectBlock.blockTime += 250; - return false; - } - - int64_t timeDiff = currentTime - connectBlock.lastAttempt; - connectBlock.lastAttempt = currentTime; - if (timeDiff <= 5000) { - if (++connectBlock.count > 5) { - connectBlock.count = 0; - if (timeDiff <= 500) { - connectBlock.blockTime = currentTime + 3000; - return false; - } - } - } else { - connectBlock.count = 1; - } - return true; -} - -boost::asio::ip::address getListenAddress() +Connection::SocketAddress getListenAddress() { if (getBoolean(ConfigManager::BIND_ONLY_GLOBAL_ADDRESS)) { - return boost::asio::ip::make_address(getString(ConfigManager::IP)); + return asio::ip::make_address(getString(ConfigManager::IP)); } - return boost::asio::ip::address_v6::any(); + return asio::ip::address_v6::any(); } void openAcceptor(std::weak_ptr weak_service, uint16_t port) @@ -93,8 +50,8 @@ void ServiceManager::stop() for (auto& servicePortIt : acceptors) { try { - boost::asio::post(io_context, [servicePort = servicePortIt.second]() { servicePort->onStopServer(); }); - } catch (boost::system::system_error& e) { + asio::post(io_context, [servicePort = servicePortIt.second]() { servicePort->onStopServer(); }); + } catch (system::system_error& e) { std::cout << "[ServiceManager::stop] Network Error: " << e.what() << std::endl; } } @@ -102,7 +59,7 @@ void ServiceManager::stop() acceptors.clear(); death_timer.expires_after(std::chrono::seconds(3)); - death_timer.async_wait([this](const boost::system::error_code&) { die(); }); + death_timer.async_wait([this](const system::error_code&) { die(); }); } ServicePort::~ServicePort() { close(); } @@ -130,50 +87,58 @@ void ServicePort::accept() return; } - auto connection = ConnectionManager::getInstance().createConnection(io_context, shared_from_this()); - acceptor->async_accept(connection->getSocket(), - [=, thisPtr = shared_from_this()](const boost::system::error_code& error) { - thisPtr->onAccept(connection, error); - }); + auto connection = tfs::net::create_connection(io_context, shared_from_this()); + acceptor->async_accept(connection->getSocket(), [=, thisPtr = shared_from_this()](const system::error_code& error) { + thisPtr->onAccept(connection, error); + }); } -void ServicePort::onAccept(Connection_ptr connection, const boost::system::error_code& error) +void ServicePort::onAccept(Connection_ptr connection, const system::error_code& error) { - if (!error) { - if (services.empty()) { - return; - } - - const auto& remote_ip = connection->getIP(); - if (acceptConnection(remote_ip)) { - Service_ptr service = services.front(); - if (service->is_single_socket()) { - connection->accept(service->make_protocol(connection)); - } else { - connection->accept(); - } - } else { - connection->close(Connection::FORCE_CLOSE); - } + if (error == asio::error::operation_aborted) { + return; + } - accept(); - } else if (error != boost::asio::error::operation_aborted) { + if (error != asio::error::operation_aborted) { if (!pendingStart) { close(); + pendingStart = true; + g_scheduler.addEvent(createSchedulerTask( 15000, [serverPort = this->serverPort, service = std::weak_ptr(shared_from_this())]() { openAcceptor(service, serverPort); })); } + return; } + + if (services.empty()) { + return; + } + + const auto& socket_address = connection->socket_address(); + if (tfs::net::has_connection_blocked(socket_address)) { + connection->disconnect_and_close_socket(); + accept(); + return; + } + + auto& service = services.front(); + if (service->is_single_socket()) { + connection->accept(service->make_protocol(connection)); + } else { + connection->accept(); + } + + accept(); } Protocol_ptr ServicePort::make_protocol(NetworkMessage& msg, const Connection_ptr& connection) const { - uint8_t protocolID = msg.getByte(); + auto protocol_id = msg.getByte(); for (auto& service : services) { - if (protocolID != service->get_protocol_identifier()) { + if (protocol_id != service->get_protocol_identifier()) { continue; } return service->make_protocol(connection); @@ -185,7 +150,7 @@ void ServicePort::onStopServer() { close(); } void ServicePort::open(uint16_t port) { - namespace ip = boost::asio::ip; + namespace ip = asio::ip; close(); @@ -200,7 +165,7 @@ void ServicePort::open(uint16_t port) ip::v6_only option; acceptor->get_option(option); if (option) { - boost::system::error_code err; + system::error_code err; acceptor->set_option(ip::v6_only{false}, err); if (err) { std::cout << "[Warning - ServicePort::open] Enabling IPv4 support failed: " << err.message() @@ -211,7 +176,7 @@ void ServicePort::open(uint16_t port) acceptor->set_option(ip::tcp::no_delay{true}); accept(); - } catch (boost::system::system_error& e) { + } catch (system::system_error& e) { std::cout << "[ServicePort::open] Error: " << e.what() << std::endl; pendingStart = true; @@ -224,17 +189,18 @@ void ServicePort::open(uint16_t port) void ServicePort::close() { if (acceptor && acceptor->is_open()) { - boost::system::error_code error; + system::error_code error; acceptor->close(error); } } -bool ServicePort::add_service(const Service_ptr& new_svc) +bool ServicePort::add_service(const Service_ptr& service) { - if (std::any_of(services.begin(), services.end(), [](const Service_ptr& svc) { return svc->is_single_socket(); })) { + if (std::any_of(services.begin(), services.end(), + [](const Service_ptr& service) { return service->is_single_socket(); })) { return false; } - services.push_back(new_svc); + services.push_back(service); return true; } diff --git a/src/server.h b/src/server.h index 5e55b39ac4..df01d968cd 100644 --- a/src/server.h +++ b/src/server.h @@ -7,6 +7,9 @@ #include "connection.h" #include "signals.h" +namespace asio = boost::asio; +namespace system = boost::system; + class ServiceBase { public: @@ -35,7 +38,7 @@ class Service final : public ServiceBase class ServicePort : public std::enable_shared_from_this { public: - explicit ServicePort(boost::asio::io_context& io_context) : io_context(io_context) {} + explicit ServicePort(asio::io_context& io_context) : io_context(io_context) {} ~ServicePort(); // non-copyable @@ -51,13 +54,13 @@ class ServicePort : public std::enable_shared_from_this Protocol_ptr make_protocol(NetworkMessage& msg, const Connection_ptr& connection) const; void onStopServer(); - void onAccept(Connection_ptr connection, const boost::system::error_code& error); + void onAccept(Connection_ptr connection, const system::error_code& error); private: void accept(); - boost::asio::io_context& io_context; - std::unique_ptr acceptor; + asio::io_context& io_context; + std::unique_ptr acceptor; std::vector services; uint16_t serverPort = 0; @@ -87,9 +90,9 @@ class ServiceManager std::unordered_map acceptors; - boost::asio::io_context io_context; + asio::io_context io_context; Signals signals{io_context}; - boost::asio::steady_timer death_timer{io_context}; + asio::steady_timer death_timer{io_context}; bool running = false; };