From 554f36f072de63fd1fe86b5a1adefbf9e98349f9 Mon Sep 17 00:00:00 2001 From: hhvrc Date: Fri, 15 Nov 2024 15:04:15 +0100 Subject: [PATCH] Fix and abstract away HTTP RateLimiter --- include/RateLimiter.h | 38 ++++++++++ src/RateLimiter.cpp | 121 ++++++++++++++++++++++++++++++++ src/http/HTTPRequestManager.cpp | 116 +++--------------------------- 3 files changed, 170 insertions(+), 105 deletions(-) create mode 100644 include/RateLimiter.h create mode 100644 src/RateLimiter.cpp diff --git a/include/RateLimiter.h b/include/RateLimiter.h new file mode 100644 index 00000000..13654ae1 --- /dev/null +++ b/include/RateLimiter.h @@ -0,0 +1,38 @@ +#pragma once + +#include "Common.h" +#include "SimpleMutex.h" + +#include +#include + +namespace OpenShock { + class RateLimiter { + DISABLE_COPY(RateLimiter); + DISABLE_MOVE(RateLimiter); + + public: + RateLimiter(); + ~RateLimiter(); + + void addLimit(uint32_t durationMs, uint16_t count); + void clearLimits(); + + bool tryRequest(); + void clearRequests(); + + void blockFor(int64_t blockForMs); + + private: + struct Limit { + int64_t durationMs; + uint16_t count; + }; + + OpenShock::SimpleMutex m_mutex; + int64_t m_nextSlot; + int64_t m_nextCleanup; + std::vector m_limits; + std::vector m_requests; + }; +} // namespace OpenShock diff --git a/src/RateLimiter.cpp b/src/RateLimiter.cpp new file mode 100644 index 00000000..b53c4fad --- /dev/null +++ b/src/RateLimiter.cpp @@ -0,0 +1,121 @@ +#include + +#include "RateLimiter.h" + +#include "Time.h" + +#include + +const char* const TAG = "RateLimiter"; + +OpenShock::RateLimiter::RateLimiter() + : m_mutex() + , m_nextSlot(0) + , m_nextCleanup(0) + , m_limits() + , m_requests() +{ +} + +OpenShock::RateLimiter::~RateLimiter() +{ +} + +void OpenShock::RateLimiter::addLimit(uint32_t durationMs, uint16_t count) +{ + m_mutex.lock(portMAX_DELAY); + + // Insert sorted + m_limits.insert(std::upper_bound(m_limits.begin(), m_limits.end(), durationMs, [](int64_t durationMs, const Limit& limit) { return durationMs < limit.durationMs; }), {durationMs, count}); + + m_nextSlot = 0; + m_nextCleanup = 0; + + m_mutex.unlock(); +} + +void OpenShock::RateLimiter::clearLimits() +{ + m_mutex.lock(portMAX_DELAY); + + m_limits.clear(); + + m_mutex.unlock(); +} + +bool OpenShock::RateLimiter::tryRequest() +{ + int64_t now = OpenShock::millis(); + + OpenShock::ScopedLock lock__(&m_mutex); + + if (m_limits.empty()) { + return true; + } + if (m_requests.empty()) { + m_requests.push_back(now); + return true; + } + + if (m_nextCleanup <= now) { + int64_t longestLimit = m_limits.back().durationMs; + int64_t expiresAt = now - longestLimit; + + auto nextToExpire = std::find_if(m_requests.begin(), m_requests.end(), [expiresAt](int64_t requestedAtMs) { return requestedAtMs > expiresAt; }); + if (nextToExpire != m_requests.end()) { + m_requests.erase(m_requests.begin(), nextToExpire); + } + + m_nextCleanup = m_requests.front() + longestLimit; + } + + if (m_nextSlot > now) { + return false; + } + + // Check if we've exceeded any limits, starting with the highest limit first + for (std::size_t i = m_limits.size(); i > 0;) { + const auto& limit = m_limits[--i]; + + // Calculate the window start time + int64_t windowStart = now - limit.durationMs; + + // Check how many requests are inside the limit window + std::size_t insideWindow = 0; + for (int64_t request : m_requests) { + if (request > windowStart) { + insideWindow++; + } + } + + // If the window is full, set the wait time until its available, and reject the request + if (insideWindow >= limit.count) { + m_nextSlot = m_requests.back() + limit.durationMs; + return false; + } + } + + // Add the request + m_requests.push_back(now); + + return true; +} +void OpenShock::RateLimiter::clearRequests() +{ + m_mutex.lock(portMAX_DELAY); + + m_requests.clear(); + + m_mutex.unlock(); +} + +void OpenShock::RateLimiter::blockFor(int64_t blockForMs) +{ + int64_t blockUntil = OpenShock::millis() + blockForMs; + + m_mutex.lock(portMAX_DELAY); + + m_nextSlot = std::max(m_nextSlot, blockUntil); + + m_mutex.unlock(); +} diff --git a/src/http/HTTPRequestManager.cpp b/src/http/HTTPRequestManager.cpp index 9ed9d704..4804d53d 100644 --- a/src/http/HTTPRequestManager.cpp +++ b/src/http/HTTPRequestManager.cpp @@ -4,6 +4,7 @@ const char* const TAG = "HTTPRequestManager"; #include "Common.h" #include "Logging.h" +#include "RateLimiter.h" #include "SimpleMutex.h" #include "Time.h" #include "util/StringUtils.h" @@ -22,100 +23,8 @@ using namespace std::string_view_literals; const std::size_t HTTP_BUFFER_SIZE = 4096LLU; const int HTTP_DOWNLOAD_SIZE_LIMIT = 200 * 1024 * 1024; // 200 MB -struct RateLimit { - RateLimit() - : m_mutex() - , m_blockUntilMs(0) - , m_limits() - , m_requests() - { - } - - void addLimit(uint32_t durationMs, uint16_t count) - { - m_mutex.lock(portMAX_DELAY); - - // Insert sorted - m_limits.insert(std::upper_bound(m_limits.begin(), m_limits.end(), durationMs, [](int64_t durationMs, const Limit& limit) { return durationMs > limit.durationMs; }), {durationMs, count}); - - m_mutex.unlock(); - } - - void clearLimits() - { - m_mutex.lock(portMAX_DELAY); - - m_limits.clear(); - - m_mutex.unlock(); - } - - bool tryRequest() - { - int64_t now = OpenShock::millis(); - - OpenShock::ScopedLock lock__(&m_mutex); - - if (m_blockUntilMs > now) { - return false; - } - - // Remove all requests that are older than the biggest limit - while (!m_requests.empty() && m_requests.front() < now - m_limits.back().durationMs) { - m_requests.erase(m_requests.begin()); - } - - // Check if we've exceeded any limits - auto it = std::find_if(m_limits.begin(), m_limits.end(), [this](const RateLimit::Limit& limit) { return m_requests.size() >= limit.count; }); - if (it != m_limits.end()) { - m_blockUntilMs = now + it->durationMs; - return false; - } - - // Add the request - m_requests.push_back(now); - - return true; - } - void clearRequests() - { - m_mutex.lock(portMAX_DELAY); - - m_requests.clear(); - - m_mutex.unlock(); - } - - void blockUntil(int64_t blockUntilMs) - { - m_mutex.lock(portMAX_DELAY); - - m_blockUntilMs = blockUntilMs; - - m_mutex.unlock(); - } - - uint32_t requestsSince(int64_t sinceMs) - { - OpenShock::ScopedLock lock__(&m_mutex); - - return std::count_if(m_requests.begin(), m_requests.end(), [sinceMs](int64_t requestMs) { return requestMs >= sinceMs; }); - } - -private: - struct Limit { - int64_t durationMs; - uint16_t count; - }; - - OpenShock::SimpleMutex m_mutex; - int64_t m_blockUntilMs; - std::vector m_limits; - std::vector m_requests; -}; - -static OpenShock::SimpleMutex s_rateLimitsMutex = {}; -static std::unordered_map> s_rateLimits = {}; +static OpenShock::SimpleMutex s_rateLimitsMutex = {}; +static std::unordered_map> s_rateLimits = {}; using namespace OpenShock; @@ -156,9 +65,9 @@ std::string_view _getDomain(std::string_view url) return url; } -std::shared_ptr _rateLimitFactory(std::string_view domain) +std::shared_ptr _rateLimiterFactory(std::string_view domain) { - auto rateLimit = std::make_shared(); + auto rateLimit = std::make_shared(); // Add default limits rateLimit->addLimit(1000, 5); // 5 per second @@ -173,7 +82,7 @@ std::shared_ptr _rateLimitFactory(std::string_view domain) return rateLimit; } -std::shared_ptr _getRateLimiter(std::string_view url) +std::shared_ptr _getRateLimiter(std::string_view url) { auto domain = std::string(_getDomain(url)); if (domain.empty()) { @@ -184,7 +93,7 @@ std::shared_ptr _getRateLimiter(std::string_view url) auto it = s_rateLimits.find(domain); if (it == s_rateLimits.end()) { - s_rateLimits.emplace(domain, _rateLimitFactory(domain)); + s_rateLimits.emplace(domain, _rateLimiterFactory(domain)); it = s_rateLimits.find(domain); } @@ -469,7 +378,7 @@ HTTP::Response _doGetStream( std::string_view url, const std::map& headers, const std::vector& acceptedCodes, - std::shared_ptr rateLimiter, + std::shared_ptr rateLimiter, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, uint32_t timeoutMs @@ -509,11 +418,8 @@ HTTP::Response _doGetStream( retryAfter = 15; } - // Get the block-until time - int64_t blockUntilMs = OpenShock::millis() + retryAfter * 1000; - - // Apply the block-until time - rateLimiter->blockUntil(blockUntilMs); + // Apply the block-for time + rateLimiter->blockFor(retryAfter * 1000); return {HTTP::RequestResult::RateLimited, responseCode, 0}; } @@ -563,7 +469,7 @@ HTTP::Response _doGetStream( HTTP::Response HTTP::Download(std::string_view url, const std::map& headers, HTTP::GotContentLengthCallback contentLengthCallback, HTTP::DownloadCallback downloadCallback, const std::vector& acceptedCodes, uint32_t timeoutMs) { - std::shared_ptr rateLimiter = _getRateLimiter(url); + std::shared_ptr rateLimiter = _getRateLimiter(url); if (rateLimiter == nullptr) { return {RequestResult::InvalidURL, 0, 0}; }