Skip to content

Commit

Permalink
Thread pool2 tests and implementation finishing touches
Browse files Browse the repository at this point in the history
  • Loading branch information
COM8 committed Jan 26, 2025
1 parent ad2b5de commit ac9aa92
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 11 deletions.
37 changes: 29 additions & 8 deletions cpr/threadpool2.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "cpr/threadpool2.h"
#include <cassert>
#include <cstddef>
#include <chrono>
#include <condition_variable>
#include <memory>
#include <cstddef>
#include <functional>
#include <memory>
#include <mutex>
#include <chrono>
#include <thread>
#include <utility>

namespace cpr {
size_t ThreadPool2::DEFAULT_MAX_THREAD_COUNT = std::thread::hardware_concurrency();
Expand Down Expand Up @@ -49,15 +50,12 @@ void ThreadPool2::SetMaxThreadCount(size_t maxThreadCount) {
void ThreadPool2::Start() {
const std::unique_lock lock(controlMutex);
setState(State::RUNNING);

for (size_t i = 0; i < minThreadCount; i++) {
addThread();
}
}

void ThreadPool2::Stop() {
const std::unique_lock controlLock(controlMutex);
setState(State::STOP);
taskQueueCondVar.notify_all();

// Join all workers
const std::unique_lock workersLock{workerMutex};
Expand All @@ -70,11 +68,21 @@ void ThreadPool2::Stop() {
}
}

void ThreadPool2::Wait() {
while (true) {
if ((state != State::RUNNING && curThreadCount <= 0) || (tasks.empty() && curThreadCount <= idleThreadCount)) {
break;
}
std::this_thread::yield();
}
}

void ThreadPool2::setState(State state) {
const std::unique_lock lock(controlMutex);
if (this->state == state) {
return;
}
this->state = state;
}

void ThreadPool2::addThread() {
Expand All @@ -84,14 +92,17 @@ void ThreadPool2::addThread() {
workers.emplace_back();
workers.back().thread = std::make_unique<std::thread>(&ThreadPool2::threadFunc, this, std::ref(workers.back()));
curThreadCount++;
idleThreadCount++;
}

void ThreadPool2::threadFunc(WorkerThread& workerThread) {
while (true) {
std::cv_status result{std::cv_status::timeout};
{
std::unique_lock lock(taskQueueMutex);
result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250));
if (tasks.empty()) {
result = taskQueueCondVar.wait_for(lock, std::chrono::milliseconds(250));
}
}

if (state == State::STOP) {
Expand All @@ -109,6 +120,16 @@ void ThreadPool2::threadFunc(WorkerThread& workerThread) {
}

// Check for tasks and execute one
const std::unique_lock lock(taskQueueMutex);
if (!tasks.empty()) {
idleThreadCount--;
const std::function<void()> task = std::move(tasks.front());
tasks.pop();

// Execute the task
task();
}
idleThreadCount++;
}

workerThread.state = State::STOP;
Expand Down
37 changes: 36 additions & 1 deletion include/cpr/threadpool2.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <future>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>

namespace cpr {
Expand All @@ -16,7 +19,7 @@ class ThreadPool2 {
static size_t DEFAULT_MAX_THREAD_COUNT;

private:
enum class State : uint8_t { STOP, RUNNING, PAUSE };
enum class State : uint8_t { STOP, RUNNING };
struct WorkerThread {
std::unique_ptr<std::thread> thread{nullptr};
State state{State::RUNNING};
Expand All @@ -28,11 +31,13 @@ class ThreadPool2 {

std::mutex taskQueueMutex;
std::condition_variable taskQueueCondVar;
std::queue<std::function<void()>> tasks;

std::atomic<State> state = State::STOP;
std::atomic_size_t minThreadCount;
std::atomic_size_t curThreadCount{0};
std::atomic_size_t maxThreadCount;
std::atomic_size_t idleThreadCount{0};

std::recursive_mutex controlMutex;

Expand All @@ -55,6 +60,36 @@ class ThreadPool2 {

void Start();
void Stop();
void Wait();

/**
* Return a future, calling future.get() will wait task done and return RetType.
* Submit(fn, args...)
* Submit(std::bind(&Class::mem_fn, &obj))
* Submit(std::mem_fn(&Class::mem_fn, &obj))
**/
template <class Fn, class... Args>
auto Submit(Fn&& fn, Args&&... args) {
// Add a new worker thread in case the tasks queue is not empty and we still can add a thread
{
std::unique_lock lock(taskQueueMutex);
if (idleThreadCount < tasks.size() && curThreadCount < maxThreadCount) {
addThread();
}
}

// Add task to queue
using RetType = decltype(fn(args...));
const std::shared_ptr<std::packaged_task<RetType()>> task = std::make_shared<std::packaged_task<RetType()>>([fn = std::forward<Fn>(fn), args...]() mutable { return std::invoke(fn, args...); });
std::future<RetType> future = task->get_future();
{
std::unique_lock lock(taskQueueMutex);
tasks.emplace([task] { (*task)(); });
}

taskQueueCondVar.notify_one();
return future;
}

private:
void setState(State newState);
Expand Down
36 changes: 34 additions & 2 deletions test/threadpool2_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,40 @@

#include "cpr/threadpool2.h"

TEST(ThreadPool2Tests, StartStop) {
cpr::ThreadPool2 tp(1, 1);
TEST(ThreadPool2Tests, BasicWorkOneThread) {
std::atomic_uint32_t invCount{0};
uint32_t invCountExpected{100};

{
cpr::ThreadPool2 tp(1, 1);

for (size_t i = 0; i < invCountExpected; ++i) {
tp.Submit([&invCount]() -> void { invCount++; });
}

// Wait for the thread pool to finish its work
tp.Wait();
}

EXPECT_EQ(invCount, invCountExpected);
}

TEST(ThreadPool2Tests, BasicWorkMultipleThreads) {
std::atomic_uint32_t invCount{0};
uint32_t invCountExpected{100};

{
cpr::ThreadPool2 tp(1, 10);

for (size_t i = 0; i < invCountExpected; ++i) {
tp.Submit([&invCount]() -> void { invCount++; });
}

// Wait for the thread pool to finish its work
tp.Wait();
}

EXPECT_EQ(invCount, invCountExpected);
}

int main(int argc, char** argv) {
Expand Down

0 comments on commit ac9aa92

Please sign in to comment.