Skip to content

Latest commit

 

History

History
319 lines (270 loc) · 6.81 KB

0.手写叠加线程池的DAG.md

File metadata and controls

319 lines (270 loc) · 6.81 KB
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

using namespace std;

/// Pool of worker threads that drain one shared FIFO queue of tasks.
///
/// Tasks are handed to workers in the order they were queued. Destroying the
/// pool lets the workers finish every task that is still queued and then
/// joins them. A task that throws is not caught here.
class ThreadPool {
 public:
  /// Starts `size` workers; a zero or negative `size` starts none.
  ThreadPool(int size = 1) {
    if (size > 0) pool_.reserve(static_cast<std::size_t>(size));
    for (int i = 0; i < size; i++) {
      AddThread();
    }
  }

  /// Signals shutdown, wakes every worker and waits for all of them.
  ~ThreadPool() {
    {
      std::lock_guard<std::mutex> lck(mtx_);
      is_stop_ = true;
    }

    cv_.notify_all();

    for (auto& worker : pool_) {
      if (worker.joinable()) worker.join();
    }
  }

  /// Adds one more worker thread to the pool.
  void AddThread() {
    pool_.emplace_back([this]() {
      while (true) {
        std::function<void()> func;
        {
          std::unique_lock<std::mutex> lck(mtx_);
          cv_.wait(lck, [this]() { return !task_.empty() || is_stop_; });

          // Only leave once the queue is empty, so tasks queued before
          // shutdown still run.
          if (is_stop_ && task_.empty()) return;

          func = std::move(task_.front());
          task_.pop_front();
        }
        // Run the task without holding the lock so other workers can proceed.
        func();
      }
    });
  }

  /// Queues the call `f(args...)` and wakes one worker.
  ///
  /// The callable and its arguments are bound (copied or moved) at the time of
  /// the call, as done by std::bind.
  template <typename F, typename... Args>
  void PutTask(F&& f, Args&&... args) {
    auto func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
    {
      std::lock_guard<std::mutex> lck(mtx_);
      // Move the bound object into the queue instead of copying it again.
      task_.emplace_back(std::move(func));
    }
    cv_.notify_one();
  }

 private:
  std::deque<std::function<void()>> task_;  // pending tasks, guarded by mtx_
  std::vector<std::thread> pool_;           // worker threads
  bool is_stop_{false};                     // shutdown flag, guarded by mtx_
  std::condition_variable cv_;              // signalled on new task / shutdown
  std::mutex mtx_;
};

// Protects the queue inside LogQueue and serialises its console output.
std::mutex g_log_mtx;

/// FIFO of log lines that producers push and a printer drains one at a time.
///
/// Both Push and Cout take g_log_mtx themselves, so a caller must not invoke
/// them on a thread that already holds g_log_mtx (the mutex is not recursive).
class LogQueue {
 public:
  /// Appends `str` to the queue, taking ownership of its contents.
  void Push(std::string&& str) {
    // Producers run on their own threads, so the queue has to be locked here
    // just like in Cout.
    std::lock_guard<std::mutex> log_lock(g_log_mtx);
    log_que_.push(std::move(str));
  }

  /// Prints and removes the oldest entry, or prints a placeholder line when
  /// the queue is empty.
  void Cout() {
    std::lock_guard<std::mutex> log_lock(g_log_mtx);
    if (log_que_.empty()) {
      std::cout << "logque is empty" << std::endl;
      return;
    }
    std::string log = std::move(log_que_.front());
    log_que_.pop();
    std::cout << log << std::endl;
  }

 private:
  std::queue<std::string> log_que_{};
};

// Process-wide log queue shared by all modules and the printer thread.
LogQueue g_log_que;

/// Base class for one node of the dependency graph.
///
/// A module has a name, the names of the modules it depends on, and a success
/// flag that another thread may poll while the module runs. Subclasses may
/// store a background thread in commit_log_thread_; it is joined by
/// ClearState() and by the destructor.
class Module {
 public:
  /// @param name  identifier of this module.
  /// @param deps  names of the modules this one depends on.
  Module(std::string name, std::vector<std::string> deps)
      : name_(std::move(name)), deps_(std::move(deps)) {}

  /// Joins a still-running commit thread so destroying the std::thread does
  /// not terminate the program.
  virtual ~Module() { JoinCommitLogThread(); }

  // Background threads capture `this`, so the object must stay where it is.
  Module(const Module&) = delete;
  Module& operator=(const Module&) = delete;

  /// Runs the module's work; implemented by subclasses.
  virtual void Execute() = 0;

  const std::string& Name() const { return name_; }
  const std::vector<std::string>& Deps() const { return deps_; }

  /// Returns true once SetSucc() has been called since the last ClearState().
  bool CheckSucc() const { return is_succ_.load(); }

  /// Marks the module as successfully executed.
  void SetSucc() { is_succ_.store(true); }

  /// Resets the success flag and waits for the commit thread, if any.
  void ClearState() {
    is_succ_.store(false);
    JoinCommitLogThread();
  }

 protected:
  // Optional background thread owned by the subclass; null until one is set.
  std::unique_ptr<std::thread> commit_log_thread_;

 private:
  // Joins commit_log_thread_ when it exists and is joinable; the pointer is
  // null if the subclass never started a thread.
  void JoinCommitLogThread() {
    if (commit_log_thread_ && commit_log_thread_->joinable()) {
      commit_log_thread_->join();
    }
  }

  std::string name_;
  std::vector<std::string> deps_;
  // Atomic because it is set by the executing thread and read by others.
  std::atomic<bool> is_succ_{false};
};

class ModuleA : public Module {
 public:
  ModuleA(string name, vector<string> deps) : Module(name, deps) {}
  void Execute() override {
    // compute
    this_thread::sleep_for(std::chrono::milliseconds(1000));
    SetSucc();

    // commit log
    AsyncCommitLog();
  }

  void AsyncCommitLog() {
    lock_guard<mutex> log_lock(g_log_mtx);
    commit_log_thread_ = make_unique<thread>([this] {
      stringstream ss;
      ss << "Executing: " << Name() << " Tid: " << this_thread::get_id()
         << endl;
      g_log_que.Push(ss.str());
    });
  }
};

class ModuleB : public Module {
 public:
  ModuleB(string name, vector<string> deps) : Module(name, deps) {}
  void Execute() override {
    // compute
    this_thread::sleep_for(std::chrono::milliseconds(1000));
    SetSucc();

    // commit log
    AsyncCommitLog();
  }

  void AsyncCommitLog() {
    lock_guard<mutex> log_lock(g_log_mtx);
    commit_log_thread_ = make_unique<thread>([this] {
      stringstream ss;
      ss << "Executing: " << Name() << " Tid: " << this_thread::get_id()
         << endl;
      g_log_que.Push(ss.str());
    });
  }
};

class ModuleC : public Module {
 public:
  ModuleC(string name, vector<string> deps) : Module(name, deps) {}
  void Execute() override {
    // compute
    this_thread::sleep_for(std::chrono::milliseconds(1000));
    SetSucc();

    // commit log
    AsyncCommitLog();
  }

  void AsyncCommitLog() {
    lock_guard<mutex> log_lock(g_log_mtx);
    commit_log_thread_ = make_unique<thread>([this] {
      stringstream ss;
      ss << "Executing: " << Name() << " Tid: " << this_thread::get_id()
         << endl;
      g_log_que.Push(ss.str());
    });
  }
};

class ModuleD : public Module {
 public:
  ModuleD(string name, vector<string> deps) : Module(name, deps) {}
  void Execute() override {
    // compute
    this_thread::sleep_for(std::chrono::milliseconds(1000));
    SetSucc();

    // commit log
    AsyncCommitLog();
  }

  void AsyncCommitLog() {
    lock_guard<mutex> log_lock(g_log_mtx);
    commit_log_thread_ = make_unique<thread>([this] {
      stringstream ss;
      ss << "Executing: " << Name() << " Tid: " << this_thread::get_id()
         << endl;
      g_log_que.Push(ss.str());
    });
  }
};

class ModuleE : public Module {
 public:
  ModuleE(string name, vector<string> deps) : Module(name, deps) {}
  void Execute() override {
    // compute
    this_thread::sleep_for(std::chrono::milliseconds(1000));
    SetSucc();

    // commit log
    AsyncCommitLog();
  }

  void AsyncCommitLog() {
    lock_guard<mutex> log_lock(g_log_mtx);
    commit_log_thread_ = make_unique<thread>([this] {
      stringstream ss;
      ss << "Executing: " << Name() << " Tid: " << this_thread::get_id()
         << endl;
      g_log_que.Push(ss.str());
    });
  }

 private:
  bool is_end_{false};
};

class Executor {
 public:
  void AddModule(Module* mod) { modules_[mod->Name()] = mod; }

  void ExecuteAll(ThreadPool& tp) {
    unordered_map<string, bool> visited;
    for (auto& mod_pair : modules_) {
      if (!visited[mod_pair.first]) {
        Execute(mod_pair.second, tp, visited);
      }
    }
  }

 private:
  void Execute(Module* mod, ThreadPool& tp,
               unordered_map<string, bool>& visited) {
    visited[mod->Name()] = true;
    for (auto& dep : mod->Deps()) {
      if (!visited[dep]) {
        Execute(modules_[dep], tp, visited);
      }
    }

    while (true) {
      bool is_deps_suss{true};
      for (auto& dep : mod->Deps()) {
        is_deps_suss = is_deps_suss && modules_[dep]->CheckSucc();
      }
      if (is_deps_suss) break;
    }

    tp.PutTask([mod]() { mod->Execute(); });
  }

  unordered_map<string, Module*> modules_;
};

int main() {
  ModuleA a("A", {});
  ModuleB b("B", {});
  ModuleC c("C", {"A", "B"});
  ModuleD d("D", {"C"});
  ModuleE e("E", {"C", "D"});

  Executor executor;
  executor.AddModule(&a);
  executor.AddModule(&b);
  executor.AddModule(&c);
  executor.AddModule(&d);
  executor.AddModule(&e);

  bool is_program_finish{false};

  thread log_que_thread([&g_log_que, &is_program_finish] {
    while (!is_program_finish) {
      g_log_que.Cout();
      this_thread::sleep_for(std::chrono::milliseconds(1000));
    }
  });

  const int num_threads = 2;
  ThreadPool tp(num_threads);

  int graph_times = 2;
  while (graph_times--) {
    executor.ExecuteAll(tp);
    this_thread::sleep_for(std::chrono::milliseconds(3000));

    a.ClearState();
    b.ClearState();
    c.ClearState();
    d.ClearState();
    e.ClearState();
  }

  {
    is_program_finish = true;
    if (log_que_thread.joinable()) log_que_thread.join();
  }

  return 0;
}