From 74b9c0697981d95dccb27eb48ea05ef845dffa0a Mon Sep 17 00:00:00 2001 From: Jacob Domagala Date: Tue, 17 Sep 2024 17:24:12 +0200 Subject: [PATCH] #2281: Add unit tests for new Collection allreduce --- .../reduce/allreduce/rabenseifner.cc | 17 +- .../reduce/allreduce/rabenseifner.h | 4 - .../reduce/allreduce/rabenseifner.impl.h | 4 - .../reduce/allreduce/rabenseifner_msg.h | 53 ------ .../reduce/allreduce/recursive_doubling.cc | 2 +- .../reduce/allreduce/recursive_doubling.h | 11 +- .../allreduce/recursive_doubling.impl.h | 13 +- .../reduce/allreduce/recursive_doubling_msg.h | 12 +- src/vt/collective/reduce/allreduce/state.h | 4 +- src/vt/group/group_manager.cc | 2 +- src/vt/vrt/collection/manager.impl.h | 3 - .../collection/test_allreduce_collection.cc | 171 ++++++++++++++++++ 12 files changed, 200 insertions(+), 96 deletions(-) create mode 100644 tests/unit/collection/test_allreduce_collection.cc diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.cc b/src/vt/collective/reduce/allreduce/rabenseifner.cc index af84f07e49..b4680d76cb 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.cc +++ b/src/vt/collective/reduce/allreduce/rabenseifner.cc @@ -83,8 +83,7 @@ Rabenseifner::Rabenseifner( print_ptr(this), proxy.get(), proxy_.getProxy(), local_num_elems_); } -Rabenseifner::Rabenseifner( - detail::StrongGroup group) +Rabenseifner::Rabenseifner(detail::StrongGroup group) : group_(group.get()), local_num_elems_(1), nodes_(theGroup()->GetGroupNodes(group.get())), @@ -136,22 +135,20 @@ Rabenseifner::Rabenseifner( Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup) : objgroup_proxy_(objgroup.get()), local_num_elems_(1), - nodes_(theGroup()->GetGroupNodes(default_group)), - num_nodes_(nodes_.size()), + num_nodes_(theContext()->getNumNodes()), this_node_(theContext()->getNode()), num_steps_(static_cast(log2(num_nodes_))), nprocs_pof2_(1 << num_steps_), nprocs_rem_(num_nodes_ - nprocs_pof2_) { - std::string nodes_info; - for (auto& node : nodes_) { - nodes_info += fmt::format("{} ", node); + nodes_.resize(num_nodes_); + for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) { + nodes_[i] = i; } vt_debug_print( terse, allreduce, - "Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} " - "Nodes:[{}]\n", - true, true, num_nodes_, nodes_info); + "Rabenseifner: is_default_group={} is_part_of_allreduce={} num_nodes_={} \n", + true, true, num_nodes_); // We collectively create this Reducer, so it's possible that not all Nodes are part of it is_even_ = this_node_ % 2 == 0; diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index f644c0407d..1735ff7e43 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -71,10 +71,6 @@ struct ObjgroupAllreduceT {}; * * This class performs an allreduce operation using Rabenseifner's method. The algorithm consists * of several phases: adjustment for power-of-two processes, scatter-reduce, and gather-allgather. - * - * \tparam DataT Type of the data being reduced. - * \tparam Op Reduction operation (e.g., sum, max, min). - * \tparam finalHandler Callback handler for the final result. */ struct Rabenseifner { diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h index cb8f0b9c85..1f765b423b 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h @@ -94,10 +94,6 @@ void Rabenseifner::localReduce(size_t id, Args&&... data) { state.local_col_wait_count_++; auto const is_ready = state.local_col_wait_count_ == local_num_elems_; - // vt_debug_print( - // terse, allreduce, "Rabenseifner (this={}): local_col_wait_count_={} ID={} is_ready={}\n", - // print_ptr(this), state.local_col_wait_count_, id, is_ready - // ); if (is_ready) { // Execute early in case we're the only node diff --git a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h index d04bfcfded..99f58498b5 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner_msg.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner_msg.h @@ -89,59 +89,6 @@ struct RabenseifnerMsg : Message { s | step_; } -struct NoCombine {}; - -template -struct IsTuple : std::false_type {}; -template -struct IsTuple> : std::true_type {}; - - template - static void combine(MsgT* m1, MsgT* m2) { - Op()(m1->getVal(), m2->getConstVal()); - } - - template - static void FinalHandler(ReduceTMsg* msg) { - // using MsgT = ReduceTMsg; - vt_debug_print( - terse, reduce, - "FinalHandler: reduce root: ptr={}\n", print_ptr(msg) - ); - // if (msg->isRoot()) { - // vt_debug_print( - // terse, reduce, - // "FinalHandler::ROOT: reduce root: ptr={}\n", print_ptr(msg) - // ); - // if (msg->hasValidCallback()) { - // envelopeUnlockForForwarding(msg->env); - // if (msg->isParamCallback()) { - // if constexpr (IsTuple::value) { - // msg->getParamCallback().sendTuple(std::move(msg->getVal())); - // } - // } else { - // // We need to force the type to the more specific one here - // auto cb = msg->getMsgCallback(); - // auto typed_cb = reinterpret_cast*>(&cb); - // typed_cb->sendMsg(msg); - // } - // } else if (msg->root_handler_ != uninitialized_handler) { - // auto_registry::getAutoHandler(msg->root_handler_)->dispatch(msg, nullptr); - // } - // } else { - // MsgT* fst_msg = msg; - // MsgT* cur_msg = msg->template getNext(); - // vt_debug_print( - // terse, reduce, - // "FinalHandler::leaf: fst ptr={}\n", print_ptr(fst_msg) - // ); - // while (cur_msg != nullptr) { - // RabenseifnerMsg::combine(fst_msg, cur_msg); - // cur_msg = cur_msg->template getNext(); - // } - // } - } - const Scalar* val_ = {}; size_t size_ = {}; size_t id_ = {}; diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.cc b/src/vt/collective/reduce/allreduce/recursive_doubling.cc index 14e621d336..9bb7f358ed 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.cc +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.cc @@ -2,7 +2,7 @@ //@HEADER // ***************************************************************************** // -// recursive_doubling.impl.h +// recursive_doubling.cc // DARMA/vt => Virtual Transport // // Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.h b/src/vt/collective/reduce/allreduce/recursive_doubling.h index 9ac0b51648..6fefe12029 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.h @@ -68,11 +68,6 @@ namespace vt::collective::reduce::allreduce { * This class provides an implementation of the Recursive Doubling algorithm for the * allreduce operation. It is parameterized by the data type to be reduced, the reduction * operation, the object type, and the final handler. - * - * \tparam DataT The data type to be reduced. - * \tparam Op The reduction operation type. - * \tparam ObjT The object type. - * \tparam finalHandler The final handler. */ struct RecursiveDoubling { @@ -139,7 +134,7 @@ struct RecursiveDoubling { * \param msg Pointer to the message. */ template class Op> - void adjustForPowerOfTwoHandler(AllreduceDblRawMsg* msg); + void adjustForPowerOfTwoHandler(RecursiveDoublingMsg* msg); /** * \brief Check if the allreduce operation is done. @@ -193,7 +188,7 @@ struct RecursiveDoubling { * \param msg Pointer to the message. */ template class Op> - void reduceIterHandler(AllreduceDblRawMsg* msg); + void reduceIterHandler(RecursiveDoublingMsg* msg); /** * \brief Send data to excluded nodes for finalization. @@ -207,7 +202,7 @@ struct RecursiveDoubling { * \param msg Pointer to the message. */ template - void sendToExcludedNodesHandler(AllreduceDblRawMsg* msg); + void sendToExcludedNodesHandler(RecursiveDoublingMsg* msg); /** * \brief Perform the final part of the allreduce operation. diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h index acba945cdc..3b30ed2d58 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h @@ -80,7 +80,12 @@ void RecursiveDoubling::localReduce(size_t id, Args&&... data) { auto const is_ready = state.local_col_wait_count_ == local_num_elems_; if (is_ready) { - allreduce(id); + // Execute early in case we're the only node + if (num_nodes_ < 2) { + executeFinalHan(id); + } else { + allreduce(id); + } } } @@ -151,7 +156,7 @@ void RecursiveDoubling::adjustForPowerOfTwo(size_t id) { template class Op> void RecursiveDoubling::adjustForPowerOfTwoHandler( - AllreduceDblRawMsg* msg) { + RecursiveDoublingMsg* msg) { using DataType = DataHandler; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, msg->id_); @@ -261,7 +266,7 @@ void RecursiveDoubling::tryReduce(size_t id, int32_t step) { } template class Op> -void RecursiveDoubling::reduceIterHandler(AllreduceDblRawMsg* msg) { +void RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg* msg) { using DataType = DataHandler; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, msg->id_); @@ -319,7 +324,7 @@ void RecursiveDoubling::sendToExcludedNodes(size_t id) { template void RecursiveDoubling::sendToExcludedNodesHandler( - AllreduceDblRawMsg* msg) { + RecursiveDoublingMsg* msg) { executeFinalHan(msg->id_); } diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling_msg.h b/src/vt/collective/reduce/allreduce/recursive_doubling_msg.h index 8005db9e3b..a1b9dbc319 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling_msg.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling_msg.h @@ -49,20 +49,20 @@ namespace vt::collective::reduce::allreduce { template -struct AllreduceDblRawMsg : Message { +struct RecursiveDoublingMsg : Message { using MessageParentType = vt::Message; vt_msg_serialize_required(); - AllreduceDblRawMsg() = default; - AllreduceDblRawMsg(AllreduceDblRawMsg const&) = default; - AllreduceDblRawMsg(AllreduceDblRawMsg&&) = default; - ~AllreduceDblRawMsg() { + RecursiveDoublingMsg() = default; + RecursiveDoublingMsg(RecursiveDoublingMsg const&) = default; + RecursiveDoublingMsg(RecursiveDoublingMsg&&) = default; + ~RecursiveDoublingMsg() { if (owning_) { delete val_; } } - AllreduceDblRawMsg(DataT const& in_val, size_t id, int step = 0) + RecursiveDoublingMsg(DataT const& in_val, size_t id, int step = 0) : MessageParentType(), val_(&in_val), id_(id), diff --git a/src/vt/collective/reduce/allreduce/state.h b/src/vt/collective/reduce/allreduce/state.h index 975c6d7cc8..8d68d004d5 100644 --- a/src/vt/collective/reduce/allreduce/state.h +++ b/src/vt/collective/reduce/allreduce/state.h @@ -89,11 +89,11 @@ template struct RecursiveDoublingState : StateBase { DataT val_ = {}; bool value_assigned_ = false; - MsgSharedPtr> adjust_message_ = nullptr; + MsgSharedPtr> adjust_message_ = nullptr; std::vector steps_recv_ = {}; std::vector steps_reduced_ = {}; - std::vector>> messages_ = {}; + std::vector>> messages_ = {}; vt::pipe::callback::cbunion::CallbackTyped final_handler_ = {}; }; diff --git a/src/vt/group/group_manager.cc b/src/vt/group/group_manager.cc index bd0ade9a3e..2b40562fe0 100644 --- a/src/vt/group/group_manager.cc +++ b/src/vt/group/group_manager.cc @@ -144,7 +144,7 @@ bool GroupManager::inGroup(GroupType const group) { std::vector GroupManager::GetGroupNodes(GroupType const group_id) const { - if (group_id == default_group) { + if (isGroupDefault(group_id)) { std::vector nodes(theContext()->getNumNodes()); for (NodeType i = 0; i < theContext()->getNumNodes(); ++i) { nodes[i] = i; diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index d22d9f4f36..34a11d45a5 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -41,9 +41,6 @@ //@HEADER */ -#include "vt/collective/reduce/allreduce/recursive_doubling.h" -#include "vt/collective/reduce/allreduce/type.h" -#include #if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H #define INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H diff --git a/tests/unit/collection/test_allreduce_collection.cc b/tests/unit/collection/test_allreduce_collection.cc new file mode 100644 index 0000000000..399b609436 --- /dev/null +++ b/tests/unit/collection/test_allreduce_collection.cc @@ -0,0 +1,171 @@ +/* +//@HEADER +// ***************************************************************************** +// +// test_allreduce_collection.cc +// DARMA/vt => Virtual Transport +// +// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + + +#include "test_parallel_harness.h" +#include "vt/collective/reduce/allreduce/type.h" +#include "vt/collective/reduce/operators/functors/max_op.h" +#include "vt/scheduler/scheduler.h" +#include "vt/vrt/collection/manager.h" +#include "vt/topos/index/index.h" +#include +#include +#include + +namespace vt { namespace tests { namespace unit { + +struct RecursiveDoublingColl : vt::Collection { + template + void sumAllreduceHan(std::vector result) { + ++counter_; + + ASSERT_EQ(result.size(), size); + + auto const num_nodes = theContext()->getNumNodes(); + auto const num_elems = num_nodes * num_elms_per_node; + auto const expected_val = ((num_elems - 1) * num_elems) / 2; + + auto verify_result = + std::all_of(result.begin(), result.end(), [=](auto const& val) { + return val == expected_val; + }); + + ASSERT_TRUE(verify_result); + } + + template + void maxAllreduceHan(std::vector result) { + ++counter_; + + ASSERT_EQ(result.size(), size); + + auto const num_nodes = theContext()->getNumNodes(); + auto const expected_val = (num_nodes * num_elms_per_node) - 1; + + auto verify_result = std::all_of(result.begin(), result.end(), [=](auto const& val) { + return val == expected_val; + }); + + ASSERT_TRUE(verify_result); + } + + template + void executePlusAllreduce() { + using namespace collective::reduce::allreduce; + auto proxy = this->getCollectionProxy(); + + std::vector payload(size, getIndex().x()); + proxy.allreduce< + ReducerT, &RecursiveDoublingColl::sumAllreduceHan, + collective::PlusOp + >(payload); + } + + template + void executeMaxAllreduce() { + using namespace collective::reduce::allreduce; + auto proxy = this->getCollectionProxy(); + + std::vector payload(size, getIndex().x()); + proxy.allreduce< + ReducerT, &RecursiveDoublingColl::maxAllreduceHan, + collective::MaxOp + >(payload); + } + + int32_t counter_ = 0; +}; + +struct TestAllreduceCollection : TestParallelHarness {}; + +TEST_F(TestAllreduceCollection, test_allreduce_recursive_doubling) { + using namespace vt::collective::reduce::allreduce; + + auto const my_node = theContext()->getNode(); + auto const num_nodes = theContext()->getNumNodes(); + + constexpr auto num_elms_per_node = 3; + auto range = vt::Index1D(int32_t{num_nodes * num_elms_per_node}); + auto proxy = vt::makeCollection("test_collection_allreduce") + .bounds(range) + .bulkInsert() + .wait(); + + constexpr size_t size = 100; + auto const elm = my_node * num_elms_per_node; + auto const& counter = proxy[elm].tryGetLocalPtr()->counter_; + + vt::runInEpochCollective([=] { + proxy.broadcastCollective< + &RecursiveDoublingColl::executePlusAllreduce + >(); + }); + + ASSERT_EQ(counter, 1); + + vt::runInEpochCollective([=] { + proxy.broadcastCollective< + &RecursiveDoublingColl::executeMaxAllreduce + >(); + }); + + ASSERT_EQ(counter, 2); + + vt::runInEpochCollective([=] { + proxy.broadcastCollective< + &RecursiveDoublingColl::executePlusAllreduce + >(); + }); + + ASSERT_EQ(counter, 3); + + vt::runInEpochCollective([=] { + proxy.broadcastCollective< + &RecursiveDoublingColl::executeMaxAllreduce + >(); + }); + + ASSERT_EQ(counter, 4); +} + +}}} // end namespace vt::tests::unit