diff --git a/src/vt/collective/reduce/allreduce/helpers.h b/src/vt/collective/reduce/allreduce/helpers.h index d168a590ff..a5bc8b4499 100644 --- a/src/vt/collective/reduce/allreduce/helpers.h +++ b/src/vt/collective/reduce/allreduce/helpers.h @@ -95,6 +95,11 @@ struct ShouldUseView> { }; #endif // MAGISTRATE_KOKKOS_ENABLED +inline NodeType getActualDest(NodeType vrt_node, uint32_t scatter_mask, NodeType nprocs_rem) { + auto vdest = static_cast(vrt_node ^ scatter_mask); + return (vdest < nprocs_rem) ? vdest * 2 : vdest + nprocs_rem; +} + // Helper alias for cleaner usage template inline constexpr bool ShouldUseView_v = ShouldUseView::Value; @@ -110,7 +115,7 @@ struct DataHelper { static auto createMessage( const std::vector& payload, size_t begin, size_t count, size_t id, - int32_t step = 0) { + uint32_t step = 0) { return vt::makeMessage>( payload.data() + begin, count, id, step); } @@ -156,7 +161,7 @@ struct DataHelper> { static auto createMessage( const DataT& payload, size_t begin, size_t count, size_t id, - int32_t step = 0) { + uint32_t step = 0) { return vt::makeMessage>( Kokkos::subview(payload, std::make_pair(begin, begin + count)), id, step ); diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.cc b/src/vt/collective/reduce/allreduce/rabenseifner.cc index af84f07e49..4f1a6dc9bb 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.cc +++ b/src/vt/collective/reduce/allreduce/rabenseifner.cc @@ -42,6 +42,7 @@ */ #include "vt/collective/reduce/allreduce/rabenseifner.h" +#include "vt/group/group_manager.h" #include "vt/configs/error/config_assert.h" namespace vt::collective::reduce::allreduce { @@ -53,7 +54,7 @@ Rabenseifner::Rabenseifner( nodes_(theGroup()->GetGroupNodes(group.get())), num_nodes_(nodes_.size()), this_node_(theContext()->getNode()), - num_steps_(static_cast(std::log2(num_nodes_))), + num_steps_(static_cast(std::log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { @@ -90,7 +91,7 @@ 
Rabenseifner::Rabenseifner( nodes_(theGroup()->GetGroupNodes(group.get())), num_nodes_(nodes_.size()), this_node_(theContext()->getNode()), - num_steps_(static_cast(log2(num_nodes_))), + num_steps_(static_cast(std::log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { std::string nodes_info; @@ -139,7 +140,7 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup) nodes_(theGroup()->GetGroupNodes(default_group)), num_nodes_(nodes_.size()), this_node_(theContext()->getNode()), - num_steps_(static_cast(log2(num_nodes_))), + num_steps_(static_cast(std::log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { std::string nodes_info; diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index f644c0407d..db3725667b 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -184,7 +184,7 @@ struct Rabenseifner { * \param step The current step in the scatter phase. */ template class Op> - void scatterTryReduce(size_t id, int32_t step); + void scatterTryReduce(size_t id, uint32_t step); /** * \brief Perform the scatter-reduce iteration. @@ -234,7 +234,7 @@ struct Rabenseifner { * \param step The current step in the gather phase. */ template - void gatherTryReduce(size_t id, int32_t step); + void gatherTryReduce(size_t id, uint32_t step); /** * \brief Perform the gather iteration. 
@@ -299,11 +299,11 @@ struct Rabenseifner { bool is_even_ = false; /// Num steps for each scatter/gather phase - int32_t num_steps_ = {}; + uint32_t num_steps_ = {}; /// 2^num_steps_ - int32_t nprocs_pof2_ = {}; - int32_t nprocs_rem_ = {}; + uint32_t nprocs_pof2_ = {}; + NodeType nprocs_rem_ = {}; /// For non-power-of-2 number of nodes this respresents whether current Node /// is excluded (has value of -1) from computation diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h index cb8f0b9c85..72d288121a 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h @@ -48,12 +48,6 @@ #include "vt/config.h" #include "vt/context/context.h" #include "vt/configs/error/config_assert.h" -#include "vt/group/global/group_default.h" -#include "vt/group/group_manager.h" -#include "vt/group/group_info.h" -#include "vt/configs/types/types_sentinels.h" -#include "vt/registry/auto/auto_registry.h" -#include "vt/utils/fntraits/fntraits.h" #include "vt/configs/debug/debug_print.h" #include "vt/configs/debug/debug_printconst.h" #include "vt/collective/reduce/allreduce/type.h" @@ -157,12 +151,11 @@ void Rabenseifner::initialize(size_t id, Args&&... data) { initializeState(id); } - int step = 0; + uint32_t step = 0; state.size_ = state.val_.size(); auto size = state.size_; - for (int mask = 1; mask < nprocs_pof2_; mask <<= 1) { - auto vdest = vrt_node_ ^ mask; - auto dest = (vdest < nprocs_rem_) ? 
vdest * 2 : vdest + nprocs_rem_; + for (uint32_t mask = 1; mask < nprocs_pof2_; mask <<= 1) { + auto const dest = getActualDest(vrt_node_, mask, nprocs_rem_); if (this_node_ < dest) { state.r_count_[step] = size / 2; @@ -377,7 +370,7 @@ bool Rabenseifner::scatterIsReady(size_t id) { } template class Op> -void Rabenseifner::scatterTryReduce(size_t id, int32_t step) { +void Rabenseifner::scatterTryReduce(size_t id, uint32_t step) { using DataHelperT = DataHelper::Scalar, DataT>; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, id); @@ -420,8 +413,7 @@ void Rabenseifner::scatterReduceIter(size_t id) { state.scatter_step_, state.s_index_.size(), state.s_count_.size(), id, proxy_.getProxy(), print_ptr(&state)); - auto vdest = vrt_node_ ^ state.scatter_mask_; - auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_; + auto const dest = getActualDest(vrt_node_, state.scatter_mask_, nprocs_rem_); auto const actual_partner = nodes_[dest]; vt_debug_print( @@ -536,7 +528,7 @@ bool Rabenseifner::gatherIsReady(size_t id) { } template -void Rabenseifner::gatherTryReduce(size_t id, int32_t step) { +void Rabenseifner::gatherTryReduce(size_t id, uint32_t step) { using DataHelperT = DataHelper::Scalar, DataT>; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, id); @@ -565,8 +557,7 @@ void Rabenseifner::gatherIter(size_t id) { using DataHelperT = DataHelper; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, id); - auto vdest = vrt_node_ ^ state.gather_mask_; - auto dest = (vdest < nprocs_rem_) ? 
vdest * 2 : vdest + nprocs_rem_; + auto const dest = getActualDest(vrt_node_, state.gather_mask_, nprocs_rem_); auto const actual_partner = nodes_[dest]; vt_debug_print( diff --git a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h index d04bfcfded..95c4296b69 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h @@ -65,7 +65,7 @@ struct RabenseifnerMsg : Message { } } - RabenseifnerMsg(const Scalar* in_val, size_t size, size_t id, int step = 0) + RabenseifnerMsg(const Scalar* in_val, size_t size, size_t id, uint32_t step = 0) : MessageParentType(), val_(in_val), size_(size), @@ -89,63 +89,10 @@ struct RabenseifnerMsg : Message { s | step_; } -struct NoCombine {}; - -template -struct IsTuple : std::false_type {}; -template -struct IsTuple> : std::true_type {}; - - template - static void combine(MsgT* m1, MsgT* m2) { - Op()(m1->getVal(), m2->getConstVal()); - } - - template - static void FinalHandler(ReduceTMsg* msg) { - // using MsgT = ReduceTMsg; - vt_debug_print( - terse, reduce, - "FinalHandler: reduce root: ptr={}\n", print_ptr(msg) - ); - // if (msg->isRoot()) { - // vt_debug_print( - // terse, reduce, - // "FinalHandler::ROOT: reduce root: ptr={}\n", print_ptr(msg) - // ); - // if (msg->hasValidCallback()) { - // envelopeUnlockForForwarding(msg->env); - // if (msg->isParamCallback()) { - // if constexpr (IsTuple::value) { - // msg->getParamCallback().sendTuple(std::move(msg->getVal())); - // } - // } else { - // // We need to force the type to the more specific one here - // auto cb = msg->getMsgCallback(); - // auto typed_cb = reinterpret_cast*>(&cb); - // typed_cb->sendMsg(msg); - // } - // } else if (msg->root_handler_ != uninitialized_handler) { - // auto_registry::getAutoHandler(msg->root_handler_)->dispatch(msg, nullptr); - // } - // } else { - // MsgT* fst_msg = msg; - // MsgT* cur_msg = msg->template getNext(); - // 
vt_debug_print( - // terse, reduce, - // "FinalHandler::leaf: fst ptr={}\n", print_ptr(fst_msg) - // ); - // while (cur_msg != nullptr) { - // RabenseifnerMsg::combine(fst_msg, cur_msg); - // cur_msg = cur_msg->template getNext(); - // } - // } - } - const Scalar* val_ = {}; size_t size_ = {}; size_t id_ = {}; - int32_t step_ = {}; + uint32_t step_ = {}; bool owning_ = false; }; @@ -160,7 +107,7 @@ struct RabenseifnerMsg> : Messa RabenseifnerMsg(RabenseifnerMsg const&) = default; RabenseifnerMsg(RabenseifnerMsg&&) = default; - RabenseifnerMsg(const ViewT& in_val, size_t id, int step = 0) + RabenseifnerMsg(const ViewT& in_val, size_t id, uint32_t step = 0) : MessageParentType(), val_(in_val), id_(id), @@ -177,7 +124,7 @@ struct RabenseifnerMsg> : Messa ViewT val_ = {}; size_t id_ = {}; - int32_t step_ = {}; + uint32_t step_ = {}; }; #endif // MAGISTRATE_KOKKOS_ENABLED diff --git a/src/vt/collective/reduce/allreduce/state.h b/src/vt/collective/reduce/allreduce/state.h index 975c6d7cc8..536c33ada7 100644 --- a/src/vt/collective/reduce/allreduce/state.h +++ b/src/vt/collective/reduce/allreduce/state.h @@ -64,18 +64,18 @@ struct StateBase { struct RabensiferBase : StateBase { // Scatter - int32_t scatter_mask_ = 1; - int32_t scatter_step_ = 0; - int32_t scatter_num_recv_ = 0; + uint32_t scatter_mask_ = 1; + uint32_t scatter_step_ = 0; + uint32_t scatter_num_recv_ = 0; std::vector scatter_steps_recv_ = {}; std::vector scatter_steps_reduced_ = {}; bool finished_scatter_part_ = false; // Gather - int32_t gather_step_ = 0; - int32_t gather_mask_ = 1; - int32_t gather_num_recv_ = 0; + uint32_t gather_step_ = 0; + uint32_t gather_mask_ = 1; + uint32_t gather_num_recv_ = 0; std::vector gather_steps_recv_ = {}; std::vector gather_steps_reduced_ = {};