[Performance] Rewrite the thread_safe cache. (#40)
1. Add a C++ test covering cache efficiency and correctness.
2. Refactor thread_safe_cache.

The new thread_safe_cache maintains a single unified mapping instead of
two. In addition, the new implementation acquires the unique_lock less
often, which can bring a significant efficiency improvement in some
cases.
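For reference, the public interface is unchanged: a compute function is passed at construction, and callers use `Get` and `Clear`. A minimal usage sketch follows (the `CountChars` compute function is illustrative, not part of this commit):

```cpp
#include <string>

#include "support/thread_safe_cache.h"

// Illustrative compute function: expensive work we want to run at most
// once per key, even when many threads request the same key at once.
static int CountChars(const std::string& key) { return static_cast<int>(key.size()); }

int main() {
  xgrammar::ThreadSafeCache<std::string, int> cache{CountChars};
  int a = cache.Get("hello");  // miss: runs CountChars and caches the result
  int b = cache.Get("hello");  // hit: returns the cached value, no recomputation
  cache.Clear();               // evicts all entries
  return a == b ? 0 : 1;
}
```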

Here's the result for `TEST(XGrammarParallelTest, CacheEfficiency)` on
my machine `AMD EPYC 7R32`.

End-to-end time (excluding the `2s` of preprocessing):

- Old cache: `1152ms`
- New cache: `217ms`

gcc version: `13.2.0`
DarkSharpness authored Dec 6, 2024
1 parent fcca8f3 commit 634b66f
Showing 2 changed files with 170 additions and 50 deletions.
87 changes: 37 additions & 50 deletions cpp/support/thread_safe_cache.h
@@ -7,7 +7,6 @@
 #define XGRAMMAR_SUPPORT_THREAD_SAFE_CACHE_H_
 
 #include <functional>
-#include <memory>
 #include <mutex>
 #include <optional>
 #include <shared_mutex>
@@ -112,41 +111,41 @@ class ThreadSafeCache<Key, Value> {
    * \return The cached or newly computed value of the key
    */
   Value Get(const Key& key) {
-    // Get or create the per-key mutex
-    std::shared_ptr<std::shared_mutex> key_mutex = GetOrCreateMutex(key);
+    // Why we need two locks:
+    // - When adding new elements to an unordered_map, the map may be rehashed,
+    //   which means all the iterators may be invalidated.
+    // - However, cppreference says: "References and pointers to either key or
+    //   data stored in the container are only invalidated by erasing that
+    //   element, even when the corresponding iterator is invalidated."
+    //   (See https://en.cppreference.com/w/cpp/container/unordered_map)
+    // - Therefore, we maintain two locks:
+    //   when we add something to the cache, we hold cache_mutex_;
+    //   when we erase something from the cache, we hold erase_mutex_.
+
+    auto erase_lock = std::shared_lock(erase_mutex_);
 
     // First attempt to read from cache_
     {
-      std::shared_lock<std::shared_mutex> cache_lock(cache_mutex_);
+      auto cache_lock = std::shared_lock(cache_mutex_);
       auto it = cache_.find(key);
-      if (it != cache_.end()) {
-        return it->second;  // Cache hit
+      if (it != cache_.end()) {    // Cache hit
+        auto& entry = it->second;  // The iterator is invalidated after releasing the lock,
+        cache_lock.unlock();       // so we hold the entry by reference first.
+        // We must not hold the lock here, since this call may block.
+        return entry.get(compute_, key);
       }
     }
 
-    // Acquire unique lock on the per-key mutex to compute the value
-    std::unique_lock<std::shared_mutex> key_lock(*key_mutex);
-
-    // Double-checked locking
-    {
-      std::shared_lock<std::shared_mutex> cache_lock(cache_mutex_);
-      auto it = cache_.find(key);
-      if (it != cache_.end()) {
-        return it->second;
-      }
-    }
-
-    // Compute the value without holding the cache lock
-    Value value = compute_(key);
-
-    // Insert the value into cache_
+    // Acquire an exclusive lock to create the entry
     {
-      std::unique_lock<std::shared_mutex> cache_lock(cache_mutex_);
-      XGRAMMAR_DCHECK(cache_.find(key) == cache_.end());
-      cache_[key] = value;
+      auto cache_lock = std::unique_lock(cache_mutex_);
+      auto& entry = cache_[key];  // Create a new entry
+      cache_lock.unlock();        // Release the lock before blocking
+      // We must not hold the lock here, since this call may block.
+      return entry.get(compute_, key);
     }
-
-    return value;
   }
 
   /*!
@@ -155,41 +154,29 @@ class ThreadSafeCache<Key, Value> {
    * them.
    */
   void Clear() {
-    // Acquire locks in the order: global_key_mutex_ -> cache_mutex_
-    std::unique_lock<std::mutex> global_key_lock(global_key_mutex_);
-    std::unique_lock<std::shared_mutex> cache_lock(cache_mutex_);
+    auto erase_lock = std::unique_lock(erase_mutex_);
     cache_.clear();
-    key_mutexes_.clear();
   }
 
  private:
-  /*!
-   * \brief Gets or creates a mutex for the given key
-   * \param key The key to get/create a mutex for
-   * \return A shared pointer to the mutex for this key
-   */
-  std::shared_ptr<std::shared_mutex> GetOrCreateMutex(const Key& key) {
-    std::unique_lock<std::mutex> lock(global_key_mutex_);
-    auto it = key_mutexes_.find(key);
-    if (it == key_mutexes_.end()) {
-      auto new_mutex = std::make_shared<std::shared_mutex>();
-      XGRAMMAR_DCHECK(key_mutexes_.find(key) == key_mutexes_.end());
-      key_mutexes_[key] = new_mutex;
-      return new_mutex;
-    }
-    return it->second;
-  }
+  struct Entry {
+    Value value;
+    std::once_flag flag;
+    auto get(const std::function<Value(const Key&)>& f, const Key& key) -> const Value& {
+      // Block in this lambda until the value has been computed (exactly once)
+      std::call_once(flag, [&] { value = f(key); });
+      return value;
+    }
+  };
 
   /*! \brief The cache mapping keys to computed values */
-  std::unordered_map<Key, Value> cache_;
+  std::unordered_map<Key, Entry> cache_;
   /*! \brief The function used to compute values for uncached keys */
   std::function<Value(const Key&)> compute_;
-  /*! \brief Per-key mutexes to allow parallel computation of different keys */
-  std::unordered_map<Key, std::shared_ptr<std::shared_mutex>> key_mutexes_;
-  /*! \brief Mutex protecting access to key_mutexes_ */
-  std::mutex global_key_mutex_;
   /*! \brief Reader-writer lock protecting access to cache_ */
   std::shared_mutex cache_mutex_;
+  /*! \brief Mutex protecting removal of elements */
+  std::shared_mutex erase_mutex_;
 };
 
 }  // namespace xgrammar
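The design rests on two standard-library guarantees: references into `std::unordered_map` survive rehashing, and `std::call_once` runs its callable exactly once per `std::once_flag` no matter how many threads race on it. Below is a minimal standalone sketch (not part of the commit) demonstrating both properties:

```cpp
#include <cassert>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>

struct Entry {
  std::string value;
  std::once_flag flag;
};

int main() {
  std::unordered_map<int, Entry> map;
  Entry& first = map[0];                   // reference taken before any rehash
  for (int i = 1; i < 1000; ++i) map[i];   // force several rehashes
  assert(&first == &map[0]);               // the reference is still valid

  int calls = 0;
  auto compute = [&] { ++calls; first.value = "computed"; };
  // Both threads race on the same once_flag; one runs compute, the other
  // blocks until the computation is done.
  std::thread t1([&] { std::call_once(first.flag, compute); });
  std::thread t2([&] { std::call_once(first.flag, compute); });
  t1.join();
  t2.join();
  std::cout << calls << '\n';  // always prints 1
}
```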
133 changes: 133 additions & 0 deletions tests/cpp/test_parallel.cc
@@ -0,0 +1,133 @@
#include <gtest/gtest.h>
#include <xgrammar/xgrammar.h>

#include <atomic>
#include <chrono>
#include <cstddef>
#include <future>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>

#include "support/logging.h"
#include "support/thread_safe_cache.h"

using namespace xgrammar;

static std::atomic_size_t counter{0};

static_assert(
sizeof(CompiledGrammar) >= sizeof(std::size_t),
"Our test requires that CompiledGrammar is at least as large as std::size_t"
);

// simulate a CompiledGrammar object
struct Grammar {
std::size_t uuid;
std::byte padding[sizeof(CompiledGrammar) - sizeof(std::size_t)];
};

using namespace std::chrono_literals;

TEST(XGrammarParallelTest, CacheEfficiency) {
auto cache = ThreadSafeCache<std::string, Grammar>{[](const std::string&) {
std::this_thread::sleep_for(1s); // simulate a slow operation
return Grammar{.uuid = counter++, .padding = {}};
}};
auto futures = std::vector<std::future<std::size_t>>{};

static const auto kGroups = 20;
static const auto kNumThreads = int(std::thread::hardware_concurrency()) * 2;
static const auto kNumTests = kNumThreads / 2;

futures.reserve(kNumThreads);
const auto target = std::chrono::steady_clock::now() + 1s;

  // Regardless of the execution order, the cache must invoke the compute
  // function exactly kNumTests times (once per distinct key). As a consequence,
  // the sum of the uuids handed out must equal the sum of the first kNumTests
  // non-negative integers.

const auto tic = std::chrono::high_resolution_clock::now();
for (auto i = 0; i < kNumThreads; ++i) {
futures.push_back(std::async(std::launch::async, [&cache, target, i] {
std::this_thread::sleep_until(target);
auto sum = std::size_t{0};
// Test writing to the cache concurrently
for (auto j = 0; j < kNumTests; ++j) {
const auto key = std::to_string((j + i) % kNumTests);
sum += cache.Get(key).uuid;
}
// Test reading the same keys again
for (auto j = 0; j < kNumTests * (kGroups - 1); ++j) {
const auto key = std::to_string(j % kNumTests);
sum += cache.Get(key).uuid;
}
return sum;
}));
}

  // Sum of the integers in [0, kNumTests)
const auto kResult = kNumTests * (kNumTests - 1) / 2;

for (auto& future : futures) {
future.wait();
EXPECT_EQ(future.get(), kResult * kGroups);
}
const auto toc = std::chrono::high_resolution_clock::now();
  // Exclude the first 2s, which are spent on preparation rather than cache access
const auto dur = std::chrono::duration_cast<std::chrono::milliseconds>(toc - tic - 2s).count();
XGRAMMAR_LOG_INFO << "Duration: " << dur << "ms";
}
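As a concrete check of the arithmetic: on a machine where `std::thread::hardware_concurrency()` returns 32, `kNumThreads` is 64 and `kNumTests` is 32, so `kResult` is 32 * 31 / 2 = 496 and every future must yield 496 * 20 = 9920.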

struct LifeSpanHook {
  // Set of all live LifeSpanHook objects. Guarded by a mutex, since the cache
  // constructs, copies, and destroys values from multiple threads concurrently.
  inline static std::mutex mutex{};
  inline static std::unordered_set<const void*> manager{};
  LifeSpanHook() { track(this); }
  LifeSpanHook(const LifeSpanHook&) { track(this); }
  auto operator=(const LifeSpanHook& other) -> LifeSpanHook& {
    this->check();
    other.check();
    return *this;
  }
  ~LifeSpanHook() {
    const std::lock_guard lock{mutex};
    EXPECT_TRUE(manager.erase(this));
  }
  static auto track(const void* ptr) -> void {
    const std::lock_guard lock{mutex};
    manager.insert(ptr);
  }
  auto check() const -> void {
    const std::lock_guard lock{mutex};
    EXPECT_TRUE(manager.find(this) != manager.end());
  }
};

struct TestObject : LifeSpanHook {
std::string name;
TestObject() = default;
TestObject(std::string name) : name(std::move(name)) {}
auto& operator=(std::string name) {
this->name = std::move(name);
return *this;
}
operator std::string() const {
this->check();
return this->name;
}
};

TEST(XGrammarParallelTest, CacheCorrectness) {
auto cache = ThreadSafeCache<std::string, TestObject>{[](const std::string& key) {
std::this_thread::sleep_for(1s); // simulate a slow operation
return key;
}};

auto futures = std::vector<std::future<std::string>>{};
futures.reserve(20);

for (auto i = 0; i < 20; ++i) {
futures.push_back(std::async(std::launch::async, [&cache, i] {
return std::string(cache.Get(std::to_string(i)));
}));
}

  // Give the futures time to start and block inside the compute function
std::this_thread::sleep_for(100ms);

cache.Clear();

for (auto i = 0; i < 20; ++i) {
EXPECT_EQ(futures[i].get(), std::to_string(i));
}
}
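A note on why this passes: `Get` holds `erase_mutex_` in shared mode for its entire duration, so the `Clear()` above blocks on the unique lock until every in-flight computation has finished and each caller has copied its value out; only then is the map cleared. The `LifeSpanHook` bookkeeping would flag any access to a `TestObject` that `Clear()` had already destroyed.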
