From ac9aa92ce3a8bf63618fcf9f0b53f04cd74f02ad Mon Sep 17 00:00:00 2001 From: Fabian Sauter Date: Sun, 26 Jan 2025 15:15:12 +0100 Subject: [PATCH] Thread pool2 tests and implementation finishing touches --- cpr/threadpool2.cpp | 37 +++++++++++++++++++++++++++++-------- include/cpr/threadpool2.h | 37 ++++++++++++++++++++++++++++++++++++- test/threadpool2_tests.cpp | 36 ++++++++++++++++++++++++++++++++++-- 3 files changed, 99 insertions(+), 11 deletions(-) diff --git a/cpr/threadpool2.cpp b/cpr/threadpool2.cpp index bf03c6cc3..18430f274 100644 --- a/cpr/threadpool2.cpp +++ b/cpr/threadpool2.cpp @@ -1,12 +1,13 @@ #include "cpr/threadpool2.h" #include -#include +#include #include -#include +#include #include +#include #include -#include #include +#include namespace cpr { size_t ThreadPool2::DEFAULT_MAX_THREAD_COUNT = std::thread::hardware_concurrency(); @@ -49,15 +50,12 @@ void ThreadPool2::SetMaxThreadCount(size_t maxThreadCount) { void ThreadPool2::Start() { const std::unique_lock lock(controlMutex); setState(State::RUNNING); - - for (size_t i = 0; i < minThreadCount; i++) { - addThread(); - } } void ThreadPool2::Stop() { const std::unique_lock controlLock(controlMutex); setState(State::STOP); + taskQueueCondVar.notify_all(); // Join all workers const std::unique_lock workersLock{workerMutex}; @@ -70,11 +68,21 @@ void ThreadPool2::Stop() { } } +void ThreadPool2::Wait() { + while (true) { + if ((state != State::RUNNING && curThreadCount <= 0) || (tasks.empty() && curThreadCount <= idleThreadCount)) { + break; + } + std::this_thread::yield(); + } +} + void ThreadPool2::setState(State state) { const std::unique_lock lock(controlMutex); if (this->state == state) { return; } + this->state = state; } void ThreadPool2::addThread() { @@ -84,6 +92,7 @@ void ThreadPool2::addThread() { workers.emplace_back(); workers.back().thread = std::make_unique(&ThreadPool2::threadFunc, this, std::ref(workers.back())); curThreadCount++; + idleThreadCount++; } void ThreadPool2::threadFunc(WorkerThread& workerThread) { @@ -91,7 +100,9 @@ void ThreadPool2::threadFunc(WorkerThread& workerThread) { std::cv_status result{std::cv_status::timeout}; { std::unique_lock lock(taskQueueMutex); - result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250)); + if (tasks.empty()) { + result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250)); + } } if (state == State::STOP) { @@ -109,6 +120,16 @@ void ThreadPool2::threadFunc(WorkerThread& workerThread) { } // Check for tasks and execute one + const std::unique_lock lock(taskQueueMutex); + if (!tasks.empty()) { + idleThreadCount--; + const std::function task = std::move(tasks.front()); + tasks.pop(); + + // Execute the task + task(); + } + idleThreadCount++; } workerThread.state = State::STOP; diff --git a/include/cpr/threadpool2.h b/include/cpr/threadpool2.h index a3d2831e0..b42735a74 100644 --- a/include/cpr/threadpool2.h +++ b/include/cpr/threadpool2.h @@ -4,9 +4,12 @@ #include #include #include +#include +#include #include #include #include +#include #include namespace cpr { @@ -16,7 +19,7 @@ class ThreadPool2 { static size_t DEFAULT_MAX_THREAD_COUNT; private: - enum class State : uint8_t { STOP, RUNNING, PAUSE }; + enum class State : uint8_t { STOP, RUNNING }; struct WorkerThread { std::unique_ptr thread{nullptr}; State state{State::RUNNING}; @@ -28,11 +31,13 @@ class ThreadPool2 { std::mutex taskQueueMutex; std::condition_variable taskQueueCondVar; + std::queue> tasks; std::atomic state = State::STOP; std::atomic_size_t minThreadCount; std::atomic_size_t curThreadCount{0}; std::atomic_size_t maxThreadCount; + std::atomic_size_t idleThreadCount{0}; std::recursive_mutex controlMutex; @@ -55,6 +60,36 @@ class ThreadPool2 { void Start(); void Stop(); + void Wait(); + + /** + * Return a future, calling future.get() will wait task done and return RetType. + * Submit(fn, args...) + * Submit(std::bind(&Class::mem_fn, &obj)) + * Submit(std::mem_fn(&Class::mem_fn, &obj)) + **/ + template + auto Submit(Fn&& fn, Args&&... args) { + // Add a new worker thread in case the tasks queue is not empty and we still can add a thread + { + std::unique_lock lock(taskQueueMutex); + if (idleThreadCount < tasks.size() && curThreadCount < maxThreadCount) { + addThread(); + } + } + + // Add task to queue + using RetType = decltype(fn(args...)); + const std::shared_ptr> task = std::make_shared>([fn = std::forward(fn), args...]() mutable { return std::invoke(fn, args...); }); + std::future future = task->get_future(); + { + std::unique_lock lock(taskQueueMutex); + tasks.emplace([task] { (*task)(); }); + } + + taskQueueCondVar.notify_one(); + return future; + } private: void setState(State newState); diff --git a/test/threadpool2_tests.cpp b/test/threadpool2_tests.cpp index 35810b1fd..be57200ed 100644 --- a/test/threadpool2_tests.cpp +++ b/test/threadpool2_tests.cpp @@ -4,8 +4,40 @@ #include "cpr/threadpool2.h" -TEST(ThreadPool2Tests, StartStop) { - cpr::ThreadPool2 tp(1, 1); +TEST(ThreadPool2Tests, BasicWorkOneThread) { + std::atomic_uint32_t invCount{0}; + uint32_t invCountExpected{100}; + + { + cpr::ThreadPool2 tp(1, 1); + + for (size_t i = 0; i < invCountExpected; ++i) { + tp.Submit([&invCount]() -> void { invCount++; }); + } + + // Wait for the thread pool to finish its work + tp.Wait(); + } + + EXPECT_EQ(invCount, invCountExpected); +} + +TEST(ThreadPool2Tests, BasicWorkMultipleThreads) { + std::atomic_uint32_t invCount{0}; + uint32_t invCountExpected{100}; + + { + cpr::ThreadPool2 tp(1, 10); + + for (size_t i = 0; i < invCountExpected; ++i) { + tp.Submit([&invCount]() -> void { invCount++; }); + } + + // Wait for the thread pool to finish its work + tp.Wait(); + } + + EXPECT_EQ(invCount, invCountExpected); } int main(int argc, char** argv) {