-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add celerity blockchain for task divergence checking #217
base: master
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#pragma once | ||
|
||
#include "types.h" | ||
|
||
namespace celerity::detail { | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/* | ||
* @brief Defines an interface for a communicator that can be used to communicate between nodes. | ||
* | ||
* This interface is used to abstract away the communication between nodes. This allows us to use different communication backends during testing and | ||
* runtime. For example, we can use MPI for the runtime and a custom implementation for testing. | ||
*/ | ||
class communicator { | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
public: | ||
communicator() = default; | ||
communicator(const communicator&) = delete; | ||
communicator(communicator&&) noexcept = default; | ||
|
||
communicator& operator=(const communicator&) = delete; | ||
communicator& operator=(communicator&&) noexcept = default; | ||
|
||
virtual ~communicator() = default; | ||
|
||
template <typename S> | ||
void allgather_inplace(S* sendrecvbuf, const int sendrecvcount) { | ||
allgather_inplace_impl(reinterpret_cast<std::byte*>(sendrecvbuf), sendrecvcount * sizeof(S)); | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
template <typename S, typename R> | ||
void allgather(const S* sendbuf, const int sendcount, R* recvbuf, const int recvcount) { | ||
allgather_impl(reinterpret_cast<const std::byte*>(sendbuf), sendcount * sizeof(S), reinterpret_cast<std::byte*>(recvbuf), recvcount * sizeof(R)); | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
void barrier() { barrier_impl(); } | ||
|
||
size_t get_num_nodes() { return num_nodes_impl(); } | ||
|
||
node_id get_local_nid() { return local_nid_impl(); } | ||
|
||
protected: | ||
virtual void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) = 0; | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
virtual void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) = 0; | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
virtual void barrier_impl() = 0; | ||
virtual size_t num_nodes_impl() = 0; | ||
virtual node_id local_nid_impl() = 0; | ||
}; | ||
} // namespace celerity::detail |
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,137 @@ | ||||||||||
#pragma once | ||||||||||
|
||||||||||
#include <atomic> | ||||||||||
#include <mutex> | ||||||||||
#include <thread> | ||||||||||
#include <vector> | ||||||||||
|
||||||||||
#include "communicator.h" | ||||||||||
#include "recorders.h" | ||||||||||
|
||||||||||
namespace celerity::detail::divergence_checker_detail { | ||||||||||
using task_hash = size_t; | ||||||||||
using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>; | ||||||||||
|
||||||||||
/** | ||||||||||
* @brief Stores the hashes of tasks for each node. | ||||||||||
* | ||||||||||
* The data is stored densely so it can easily be exchanged through MPI collective operations. | ||||||||||
*/ | ||||||||||
struct per_node_task_hashes { | ||||||||||
public: | ||||||||||
per_node_task_hashes(const size_t max_hash_count, const size_t num_nodes) : m_data(max_hash_count * num_nodes), m_max_hash_count(max_hash_count){}; | ||||||||||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
const task_hash& operator()(const node_id nid, const size_t i) const { return m_data.at(nid * m_max_hash_count + i); } | ||||||||||
task_hash* data() { return m_data.data(); } | ||||||||||
|
||||||||||
private: | ||||||||||
std::vector<task_hash> m_data; | ||||||||||
size_t m_max_hash_count; | ||||||||||
}; | ||||||||||
|
||||||||||
/** | ||||||||||
* @brief This class checks for divergences of tasks between nodes. | ||||||||||
* | ||||||||||
* It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence. | ||||||||||
* When a divergence is found, the task record for the diverging task is printed and the program is terminated. | ||||||||||
* Additionally it will also print a warning when a deadlock is suspected. | ||||||||||
*/ | ||||||||||
|
||||||||||
class divergence_block_chain { | ||||||||||
friend struct divergence_block_chain_testspy; | ||||||||||
|
||||||||||
public: | ||||||||||
divergence_block_chain(task_recorder& task_recorder, std::unique_ptr<communicator> comm) | ||||||||||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
: m_local_nid(comm->get_local_nid()), m_num_nodes(comm->get_num_nodes()), m_per_node_hash_counts(comm->get_num_nodes()), | ||||||||||
m_communicator(std::move(comm)) { | ||||||||||
task_recorder.add_callback([this](const task_record& task) { add_new_task(task); }); | ||||||||||
} | ||||||||||
|
||||||||||
divergence_block_chain(const divergence_block_chain&) = delete; | ||||||||||
divergence_block_chain(divergence_block_chain&&) = delete; | ||||||||||
|
||||||||||
~divergence_block_chain() = default; | ||||||||||
|
||||||||||
divergence_block_chain& operator=(const divergence_block_chain&) = delete; | ||||||||||
divergence_block_chain& operator=(divergence_block_chain&&) = delete; | ||||||||||
|
||||||||||
bool check_for_divergence(); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Needs a comment on what a |
||||||||||
|
||||||||||
private: | ||||||||||
node_id m_local_nid; | ||||||||||
size_t m_num_nodes; | ||||||||||
|
||||||||||
std::vector<task_hash> m_local_hashes; | ||||||||||
std::vector<task_record> m_task_records; | ||||||||||
size_t m_tasks_checked = 0; | ||||||||||
size_t m_hashes_added = 0; | ||||||||||
task_hash m_last_hash = 0; | ||||||||||
|
||||||||||
std::vector<int> m_per_node_hash_counts; | ||||||||||
std::mutex m_task_records_mutex; | ||||||||||
|
||||||||||
std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now(); | ||||||||||
std::chrono::seconds m_time_of_last_warning = std::chrono::seconds(0); | ||||||||||
|
||||||||||
std::unique_ptr<communicator> m_communicator; | ||||||||||
|
||||||||||
void reprot_divergence(const divergence_map& check_map, const int task_num); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
void add_new_hashes(); | ||||||||||
void clear(const int min_progress); | ||||||||||
std::pair<int, int> collect_hash_counts(); | ||||||||||
per_node_task_hashes collect_hashes(const int min_hash_count) const; | ||||||||||
divergence_map create_divergence_map(const per_node_task_hashes& task_hashes, const int task_num) const; | ||||||||||
|
||||||||||
void check_for_deadlock(); | ||||||||||
|
||||||||||
static void log_node_divergences(const divergence_map& check_map, const int task_id); | ||||||||||
static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash); | ||||||||||
void log_task_record_once(const divergence_map& check_map, const int task_num); | ||||||||||
|
||||||||||
void add_new_task(const task_record& task); | ||||||||||
task_record thread_save_get_task_record(const size_t task_num); | ||||||||||
}; | ||||||||||
}; // namespace celerity::detail::divergence_checker_detail | ||||||||||
|
||||||||||
namespace celerity::detail { | ||||||||||
class divergence_checker { | ||||||||||
friend struct runtime_testspy; | ||||||||||
|
||||||||||
public: | ||||||||||
divergence_checker(task_recorder& task_recorder, std::unique_ptr<communicator> comm, bool test_mode = false) | ||||||||||
: m_block_chain(task_recorder, std::move(comm)) { | ||||||||||
if(!test_mode) { start(); } | ||||||||||
} | ||||||||||
|
||||||||||
divergence_checker(const divergence_checker&) = delete; | ||||||||||
divergence_checker(const divergence_checker&&) = delete; | ||||||||||
|
||||||||||
divergence_checker& operator=(const divergence_checker&) = delete; | ||||||||||
divergence_checker& operator=(divergence_checker&&) = delete; | ||||||||||
|
||||||||||
~divergence_checker() { stop(); } | ||||||||||
|
||||||||||
private: | ||||||||||
std::thread m_thread; | ||||||||||
bool m_is_running = false; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or just use an atomic. |
||||||||||
divergence_checker_detail::divergence_block_chain m_block_chain; | ||||||||||
|
||||||||||
void start() { | ||||||||||
m_thread = std::thread(&divergence_checker::run, this); | ||||||||||
m_is_running = true; | ||||||||||
Comment on lines
+119
to
+120
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels like there is a race between setting
Suggested change
|
||||||||||
} | ||||||||||
|
||||||||||
void stop() { | ||||||||||
m_is_running = false; | ||||||||||
if(m_thread.joinable()) { m_thread.join(); } | ||||||||||
} | ||||||||||
|
||||||||||
void run() { | ||||||||||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
bool is_finished = false; | ||||||||||
while(!is_finished || m_is_running) { | ||||||||||
is_finished = m_block_chain.check_for_divergence(); | ||||||||||
|
||||||||||
std::this_thread::sleep_for(std::chrono::milliseconds(100)); | ||||||||||
} | ||||||||||
} | ||||||||||
}; | ||||||||||
}; // namespace celerity::detail |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#pragma once | ||
|
||
#include <memory> | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#include <mpi.h> | ||
|
||
#include "communicator.h" | ||
|
||
namespace celerity::detail { | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class mpi_communicator : public communicator { | ||
public: | ||
mpi_communicator(MPI_Comm comm) : m_comm(comm) {} | ||
|
||
private: | ||
MPI_Comm m_comm; | ||
|
||
void allgather_inplace_impl(std::byte* sendrecvbuf, const int sendrecvcount) override { | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, sendrecvbuf, sendrecvcount, MPI_BYTE, m_comm); | ||
}; | ||
|
||
void allgather_impl(const std::byte* sendbuf, const int sendcount, std::byte* recvbuf, const int recvcount) override { | ||
psalz marked this conversation as resolved.
Show resolved
Hide resolved
psalz marked this conversation as resolved.
Show resolved
Hide resolved
|
||
MPI_Allgather(sendbuf, sendcount, MPI_BYTE, recvbuf, recvcount, MPI_BYTE, m_comm); | ||
}; | ||
|
||
void barrier_impl() override { MPI_Barrier(m_comm); } | ||
|
||
size_t num_nodes_impl() override { | ||
int size = -1; | ||
MPI_Comm_size(m_comm, &size); | ||
return static_cast<size_t>(size); | ||
} | ||
|
||
node_id local_nid_impl() override { | ||
int rank = -1; | ||
MPI_Comm_rank(m_comm, &rank); | ||
return static_cast<node_id>(rank); | ||
} | ||
}; | ||
} // namespace celerity::detail |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also needs to be added to
cmake/celerity-config.cmake.in
!