From 634b66f4878dcf1de323162bea9d978d413a3c6c Mon Sep 17 00:00:00 2001 From: DarkSharpness <76582120+DarkSharpness@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:04:16 +0900 Subject: [PATCH] [Performance] Rewrite the thread_safe cache. (#40) 1. Add a cpp test on cache efficiency and correctness 2. Refactor the thread_safe_cache. In the new thread_safe_cache, we maintain only 1 unified mapping instead of 2. In addition, in our implementation, we acquire the unique_lock less often, which may bring a significant improvement in efficiency in some cases. Here's the result for `TEST(XGrammarParallelTest, CacheEfficiency)` on my machine `AMD EPYC 7R32`. End to end time (preprocessing time `2s` not included): - Old cache: `1152ms` - New cache: `217ms` gcc version: `13.2.0` --- cpp/support/thread_safe_cache.h | 87 +++++++++------------ tests/cpp/test_parallel.cc | 133 ++++++++++++++++++++++++++++++++ 2 files changed, 170 insertions(+), 50 deletions(-) create mode 100644 tests/cpp/test_parallel.cc diff --git a/cpp/support/thread_safe_cache.h b/cpp/support/thread_safe_cache.h index 46af105..617af72 100644 --- a/cpp/support/thread_safe_cache.h +++ b/cpp/support/thread_safe_cache.h @@ -7,7 +7,6 @@ #define XGRAMMAR_SUPPORT_THREAD_SAFE_CACHE_H_ #include -#include #include #include #include @@ -112,41 +111,41 @@ class ThreadSafeCache { * \return The cached or newly computed value of the key */ Value Get(const Key& key) { - // Get or create the per-key mutex - std::shared_ptr key_mutex = GetOrCreateMutex(key); + // Why we need this: + // - When adding new elements to an unordered_map, the map may be rehashed, + // - which means all the iterators may be invalidated. + // - However, cppreference says: + // - "References and pointers to either key or data stored in the container are only invalidated + // - by erasing that element, even when the corresponding iterator is invalidated."
+ // - (See https://en.cppreference.com/w/cpp/container/unordered_map) + // - Therefore, we should maintain 2 locks. + // - When we add something to the cache, we should hold the cache_mutex_. + // - When we erase something from the cache, we should hold the erase_mutex_. + + auto erase_lock = std::shared_lock(erase_mutex_); // First attempt to read from cache_ { - std::shared_lock cache_lock(cache_mutex_); + auto cache_lock = std::shared_lock(cache_mutex_); auto it = cache_.find(key); - if (it != cache_.end()) { - return it->second; // Cache hit - } - } - - // Acquire unique lock on the per-key mutex to compute the value - std::unique_lock key_lock(*key_mutex); + if (it != cache_.end()) { // Cache hit + auto& entry = it->second; // The iterator may be invalidated after releasing the lock + cache_lock.unlock(); // Therefore, we should hold the entry by reference first - // Double-checked locking - { - std::shared_lock cache_lock(cache_mutex_); - auto it = cache_.find(key); - if (it != cache_.end()) { - return it->second; + // We should not hold lock here, since this function may be blocking. + return entry.get(compute_, key); } } - // Compute the value without holding the cache lock - Value value = compute_(key); - - // Insert the value into cache_ + // Acquire exclusive lock to create (or look up) the entry; the value is computed after unlocking { - std::unique_lock cache_lock(cache_mutex_); - XGRAMMAR_DCHECK(cache_.find(key) == cache_.end()); - cache_[key] = value; - } + auto cache_lock = std::unique_lock(cache_mutex_); + auto& entry = cache_[key]; // Create a new entry + cache_lock.unlock(); // Release the lock before blocking - return value; + // We should not hold lock here, since this function may be blocking. + return entry.get(compute_, key); + } } /*! @@ -155,41 +154,29 @@ class ThreadSafeCache { * them.
*/ void Clear() { - // Acquire locks in the order: global_key_mutex_ -> cache_mutex_ - std::unique_lock global_key_lock(global_key_mutex_); - std::unique_lock cache_lock(cache_mutex_); + auto erase_lock = std::unique_lock(erase_mutex_); cache_.clear(); - key_mutexes_.clear(); } private: - /*! - * \brief Gets or creates a mutex for the given key - * \param key The key to get/create a mutex for - * \return A shared pointer to the mutex for this key - */ - std::shared_ptr GetOrCreateMutex(const Key& key) { - std::unique_lock lock(global_key_mutex_); - auto it = key_mutexes_.find(key); - if (it == key_mutexes_.end()) { - auto new_mutex = std::make_shared(); - XGRAMMAR_DCHECK(key_mutexes_.find(key) == key_mutexes_.end()); - key_mutexes_[key] = new_mutex; - return new_mutex; + struct Entry { + Value value; + std::once_flag flag; + auto get(const std::function& f, const Key& key) -> const Value& { + // block in this lambda until the value is computed + std::call_once(flag, [&] { value = f(key); }); + return value; } - return it->second; - } + }; /*! \brief The cache mapping keys to computed values */ - std::unordered_map cache_; + std::unordered_map cache_; /*! \brief The function used to compute values for uncached keys */ std::function compute_; - /*! \brief Per-key mutexes to allow parallel computation of different keys */ - std::unordered_map> key_mutexes_; - /*! \brief Mutex protecting access to key_mutexes_ */ - std::mutex global_key_mutex_; /*! \brief Reader-writer lock protecting access to cache_ */ std::shared_mutex cache_mutex_; + /*! 
\brief Reader-writer lock protecting the removal of elements */ + std::shared_mutex erase_mutex_; }; } // namespace xgrammar diff --git a/tests/cpp/test_parallel.cc b/tests/cpp/test_parallel.cc new file mode 100644 index 0000000..62a73e9 --- /dev/null +++ b/tests/cpp/test_parallel.cc @@ -0,0 +1,133 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "support/logging.h" +#include "support/thread_safe_cache.h" + +using namespace xgrammar; + +static std::atomic_size_t counter{0}; + +static_assert( + sizeof(CompiledGrammar) >= sizeof(std::size_t), + "Our test requires that CompiledGrammar is at least as large as std::size_t" +); + +// simulate a CompiledGrammar object +struct Grammar { + std::size_t uuid; + std::byte padding[sizeof(CompiledGrammar) - sizeof(std::size_t)]; +}; + +using namespace std::chrono_literals; + +TEST(XGrammarParallelTest, CacheEfficiency) { + auto cache = ThreadSafeCache{[](const std::string&) { + std::this_thread::sleep_for(1s); // simulate a slow operation + return Grammar{.uuid = counter++, .padding = {}}; + }}; + auto futures = std::vector>{}; + + static const auto kGroups = 20; + static const auto kNumThreads = int(std::thread::hardware_concurrency()) * 2; + static const auto kNumTests = kNumThreads / 2; + + futures.reserve(kNumThreads); + const auto target = std::chrono::steady_clock::now() + 1s; + + // Whatever the execution order, the cache will call the constructor only kNumTests times. + // As a consequence, the sum of the uuids must be equal to the sum of the first kNumTests + // non-negative integers (0, 1, ..., kNumTests - 1).
+ + const auto tic = std::chrono::high_resolution_clock::now(); + for (auto i = 0; i < kNumThreads; ++i) { + futures.push_back(std::async(std::launch::async, [&cache, target, i] { + std::this_thread::sleep_until(target); + auto sum = std::size_t{0}; + // Test writing to the cache concurrently + for (auto j = 0; j < kNumTests; ++j) { + const auto key = std::to_string((j + i) % kNumTests); + sum += cache.Get(key).uuid; + } + // Test reading the same keys again + for (auto j = 0; j < kNumTests * (kGroups - 1); ++j) { + const auto key = std::to_string(j % kNumTests); + sum += cache.Get(key).uuid; + } + return sum; + })); + } + + // Sum of the integers in [0, kNumTests), i.e. 0 + 1 + ... + (kNumTests - 1) + const auto kResult = kNumTests * (kNumTests - 1) / 2; + + for (auto& future : futures) { + future.wait(); + EXPECT_EQ(future.get(), kResult * kGroups); + } + const auto toc = std::chrono::high_resolution_clock::now(); + // Skip the first 2s for preparation + const auto dur = std::chrono::duration_cast(toc - tic - 2s).count(); + XGRAMMAR_LOG_INFO << "Duration: " << dur << "ms"; +} + +struct LifeSpanHook { + inline static std::unordered_set manager{}; + LifeSpanHook() { manager.insert(this); } + LifeSpanHook(const LifeSpanHook&) { manager.insert(this); } + auto operator=(const LifeSpanHook& other) -> LifeSpanHook& { + this->check(); + other.check(); + return *this; + } + ~LifeSpanHook() { EXPECT_TRUE(manager.erase(this)); } + auto check() const -> void { EXPECT_TRUE(manager.find(this) != manager.end()); } +}; + +struct TestObject : LifeSpanHook { + std::string name; + TestObject() = default; + TestObject(std::string name) : name(std::move(name)) {} + auto& operator=(std::string name) { + this->name = std::move(name); + return *this; + } + operator std::string() const { + this->check(); + return this->name; + } +}; + +TEST(XGrammarParallelTest, CacheCorrectness) { + auto cache = ThreadSafeCache{[](const std::string& key) { + std::this_thread::sleep_for(1s); // simulate a slow operation + return
key; + }}; + + auto futures = std::vector>{}; + futures.reserve(20); + + for (auto i = 0; i < 20; ++i) { + futures.push_back(std::async(std::launch::async, [&cache, i] { + return std::string(cache.Get(std::to_string(i))); + })); + } + + // Wait for the futures to block inside Get() before clearing the cache + std::this_thread::sleep_for(100ms); + + cache.Clear(); + + for (auto i = 0; i < 20; ++i) { + EXPECT_EQ(futures[i].get(), std::to_string(i)); + } +}