Skip to content

Commit

Permalink
#2281: Add unit tests for new Collection allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 17, 2024
1 parent 1568c4a commit 74b9c06
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 96 deletions.
17 changes: 7 additions & 10 deletions src/vt/collective/reduce/allreduce/rabenseifner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ Rabenseifner::Rabenseifner(
print_ptr(this), proxy.get(), proxy_.getProxy(), local_num_elems_);
}

Rabenseifner::Rabenseifner(
detail::StrongGroup group)
Rabenseifner::Rabenseifner(detail::StrongGroup group)
: group_(group.get()),
local_num_elems_(1),
nodes_(theGroup()->GetGroupNodes(group.get())),
Expand Down Expand Up @@ -136,22 +135,20 @@ Rabenseifner::Rabenseifner(
Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup)
: objgroup_proxy_(objgroup.get()),
local_num_elems_(1),
nodes_(theGroup()->GetGroupNodes(default_group)),
num_nodes_(nodes_.size()),
num_nodes_(theContext()->getNumNodes()),
this_node_(theContext()->getNode()),
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_) {
std::string nodes_info;
for (auto& node : nodes_) {
nodes_info += fmt::format("{} ", node);
nodes_.resize(num_nodes_);
for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) {
nodes_[i] = i;
}

vt_debug_print(
terse, allreduce,
"Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} "
"Nodes:[{}]\n",
true, true, num_nodes_, nodes_info);
"Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n",
true, true, num_nodes_);

// We collectively create this Reducer, so it's possible that not all Nodes are part of it
is_even_ = this_node_ % 2 == 0;
Expand Down
4 changes: 0 additions & 4 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,6 @@ struct ObjgroupAllreduceT {};
*
* This class performs an allreduce operation using Rabenseifner's method. The algorithm consists
* of several phases: adjustment for power-of-two processes, scatter-reduce, and gather-allgather.
*
* \tparam DataT Type of the data being reduced.
* \tparam Op Reduction operation (e.g., sum, max, min).
* \tparam finalHandler Callback handler for the final result.
*/

struct Rabenseifner {
Expand Down
4 changes: 0 additions & 4 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ void Rabenseifner::localReduce(size_t id, Args&&... data) {

state.local_col_wait_count_++;
auto const is_ready = state.local_col_wait_count_ == local_num_elems_;
// vt_debug_print(
// terse, allreduce, "Rabenseifner (this={}): local_col_wait_count_={} ID={} is_ready={}\n",
// print_ptr(this), state.local_col_wait_count_, id, is_ready
// );

if (is_ready) {
// Execute early in case we're the only node
Expand Down
53 changes: 0 additions & 53 deletions src/vt/collective/reduce/allreduce/rabenseifner_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,59 +89,6 @@ struct RabenseifnerMsg : Message {
s | step_;
}

/// Empty tag type with no members. Its users are not visible in this excerpt.
struct NoCombine {};

/**
 * \brief Compile-time predicate that is true only for \c std::tuple types.
 *
 * The primary template yields \c false for every type.
 */
template <typename T>
struct IsTuple : std::integral_constant<bool, false> {};

/// Partial specialization matching any \c std::tuple<...>; yields \c true.
template <typename... Elems>
struct IsTuple<std::tuple<Elems...>> : std::integral_constant<bool, true> {};

/**
 * \brief Apply \c Op to the values carried by two messages.
 *
 * Constructs an \c Op with no arguments and invokes it with
 * \c m1->getVal() as the first argument and \c m2->getConstVal() as the second.
 *
 * \tparam MsgT message type providing \c getVal() and \c getConstVal()
 * \tparam Op functor type invoked on the two values
 * \tparam ActOp not used by this function
 *
 * \param m1 message whose \c getVal() is passed as the first operand
 * \param m2 message whose \c getConstVal() is passed as the second operand
 */
template <typename MsgT, typename Op, typename ActOp>
static void combine(MsgT* m1, MsgT* m2) {
  auto reducer = Op();
  reducer(m1->getVal(), m2->getConstVal());
}

/**
 * \brief Final handler for a reduce message; it only logs the message pointer.
 *
 * The body emits a single debug print and does nothing else. The message is
 * not combined, forwarded, or dispatched here.
 *
 * \tparam Tuple payload type of the \c ReduceTMsg
 * \tparam Op not used by this function
 * \tparam ActOp not used by this function
 *
 * \param[in] msg the reduce message whose pointer is printed
 */
template <typename Tuple, typename Op, typename ActOp>
static void FinalHandler(ReduceTMsg<Tuple>* msg) {
  vt_debug_print(
    terse, reduce,
    "FinalHandler: reduce root: ptr={}\n", print_ptr(msg)
  );
}

const Scalar* val_ = {};
size_t size_ = {};
size_t id_ = {};
Expand Down
2 changes: 1 addition & 1 deletion src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//@HEADER
// *****************************************************************************
//
// recursive_doubling.impl.h
// recursive_doubling.cc
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
Expand Down
11 changes: 3 additions & 8 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ namespace vt::collective::reduce::allreduce {
* This class provides an implementation of the Recursive Doubling algorithm for the
* allreduce operation. The data type to be reduced and the reduction operation are
* supplied as template parameters of its member functions.
*
* \tparam DataT The data type to be reduced.
* \tparam Op The reduction operation type.
* \tparam ObjT The object type.
* \tparam finalHandler The final handler.
*/

struct RecursiveDoubling {
Expand Down Expand Up @@ -139,7 +134,7 @@ struct RecursiveDoubling {
* \param msg Pointer to the message.
*/
template <typename DataT, template <typename Arg> class Op>
void adjustForPowerOfTwoHandler(AllreduceDblRawMsg<DataT>* msg);
void adjustForPowerOfTwoHandler(RecursiveDoublingMsg<DataT>* msg);

/**
* \brief Check if the allreduce operation is done.
Expand Down Expand Up @@ -193,7 +188,7 @@ struct RecursiveDoubling {
* \param msg Pointer to the message.
*/
template <typename DataT, template <typename Arg> class Op>
void reduceIterHandler(AllreduceDblRawMsg<DataT>* msg);
void reduceIterHandler(RecursiveDoublingMsg<DataT>* msg);

/**
* \brief Send data to excluded nodes for finalization.
Expand All @@ -207,7 +202,7 @@ struct RecursiveDoubling {
* \param msg Pointer to the message.
*/
template <typename DataT>
void sendToExcludedNodesHandler(AllreduceDblRawMsg<DataT>* msg);
void sendToExcludedNodesHandler(RecursiveDoublingMsg<DataT>* msg);

/**
* \brief Perform the final part of the allreduce operation.
Expand Down
13 changes: 9 additions & 4 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,12 @@ void RecursiveDoubling::localReduce(size_t id, Args&&... data) {
auto const is_ready = state.local_col_wait_count_ == local_num_elems_;

if (is_ready) {
allreduce<DataT, Op>(id);
// Execute early in case we're the only node
if (num_nodes_ < 2) {
executeFinalHan<DataT>(id);
} else {
allreduce<DataT, Op>(id);
}
}
}

Expand Down Expand Up @@ -151,7 +156,7 @@ void RecursiveDoubling::adjustForPowerOfTwo(size_t id) {

template <typename DataT, template <typename Arg> class Op>
void RecursiveDoubling::adjustForPowerOfTwoHandler(
AllreduceDblRawMsg<DataT>* msg) {
RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(
collection_proxy_, objgroup_proxy_, group_, msg->id_);
Expand Down Expand Up @@ -261,7 +266,7 @@ void RecursiveDoubling::tryReduce(size_t id, int32_t step) {
}

template <typename DataT, template <typename Arg> class Op>
void RecursiveDoubling::reduceIterHandler(AllreduceDblRawMsg<DataT>* msg) {
void RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(
collection_proxy_, objgroup_proxy_, group_, msg->id_);
Expand Down Expand Up @@ -319,7 +324,7 @@ void RecursiveDoubling::sendToExcludedNodes(size_t id) {

template <typename DataT>
void RecursiveDoubling::sendToExcludedNodesHandler(
AllreduceDblRawMsg<DataT>* msg) {
RecursiveDoublingMsg<DataT>* msg) {
executeFinalHan<DataT>(msg->id_);
}

Expand Down
12 changes: 6 additions & 6 deletions src/vt/collective/reduce/allreduce/recursive_doubling_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,20 @@
namespace vt::collective::reduce::allreduce {

template <typename DataT>
struct AllreduceDblRawMsg : Message {
struct RecursiveDoublingMsg : Message {
using MessageParentType = vt::Message;
vt_msg_serialize_required();

AllreduceDblRawMsg() = default;
AllreduceDblRawMsg(AllreduceDblRawMsg const&) = default;
AllreduceDblRawMsg(AllreduceDblRawMsg&&) = default;
~AllreduceDblRawMsg() {
RecursiveDoublingMsg() = default;
RecursiveDoublingMsg(RecursiveDoublingMsg const&) = default;
RecursiveDoublingMsg(RecursiveDoublingMsg&&) = default;
~RecursiveDoublingMsg() {
if (owning_) {
delete val_;
}
}

AllreduceDblRawMsg(DataT const& in_val, size_t id, int step = 0)
RecursiveDoublingMsg(DataT const& in_val, size_t id, int step = 0)
: MessageParentType(),
val_(&in_val),
id_(id),
Expand Down
4 changes: 2 additions & 2 deletions src/vt/collective/reduce/allreduce/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ template <typename DataT>
struct RecursiveDoublingState : StateBase {
DataT val_ = {};
bool value_assigned_ = false;
MsgSharedPtr<AllreduceDblRawMsg<DataT>> adjust_message_ = nullptr;
MsgSharedPtr<RecursiveDoublingMsg<DataT>> adjust_message_ = nullptr;

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceDblRawMsg<DataT>>> messages_ = {};
std::vector<MsgSharedPtr<RecursiveDoublingMsg<DataT>>> messages_ = {};
vt::pipe::callback::cbunion::CallbackTyped<DataT> final_handler_ = {};
};

Expand Down
2 changes: 1 addition & 1 deletion src/vt/group/group_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ bool GroupManager::inGroup(GroupType const group) {

std::vector<NodeType>
GroupManager::GetGroupNodes(GroupType const group_id) const {
if (group_id == default_group) {
if (isGroupDefault(group_id)) {
std::vector<NodeType> nodes(theContext()->getNumNodes());
for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) {
nodes[i] = i;
Expand Down
3 changes: 0 additions & 3 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@
//@HEADER
*/

#include "vt/collective/reduce/allreduce/recursive_doubling.h"
#include "vt/collective/reduce/allreduce/type.h"
#include <type_traits>
#if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H
#define INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H

Expand Down
Loading

0 comments on commit 74b9c06

Please sign in to comment.