From 08113bb0783389e1daa6876d673997fd802794a6 Mon Sep 17 00:00:00 2001 From: Jacob Domagala Date: Sun, 22 Sep 2024 02:45:10 +0200 Subject: [PATCH] #2281: Fixed runtime issues with StateHolder generating extra ID --- .../reduce/allreduce/rabenseifner.impl.h | 12 ++--- .../reduce/allreduce/recursive_doubling.cc | 2 - .../reduce/allreduce/recursive_doubling.h | 4 +- .../allreduce/recursive_doubling.impl.h | 1 + src/vt/collective/reduce/allreduce/state.h | 1 + .../reduce/allreduce/state_holder.h | 54 +++++++++++++++++-- tests/unit/objgroup/test_objgroup.cc | 2 +- tests/unit/objgroup/test_objgroup_common.h | 9 +++- 8 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h index f8194b5965..66e07fc87d 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h @@ -79,8 +79,8 @@ void Rabenseifner::localReduce(size_t id, Args&&... data) { vt_debug_print( terse, allreduce, - "Rabenseifner (this={}): local_col_wait_count_={} ID={} initialized={}\n", - print_ptr(this), state.local_col_wait_count_, id, state.initialized_); + "Rabenseifner(ID = {}) localReduce (this={}): local_col_wait_count_={} initialized={}\n", + id, print_ptr(this), state.local_col_wait_count_, state.initialized_); if (DataHelperT::empty(state.val_)) { @@ -176,6 +176,8 @@ void Rabenseifner::initialize(size_t id, Args&&... data) { } } + state.active_ = true; + vt_debug_print( terse, allreduce, "Rabenseifner initialize: size_ = {} num_steps_ = {} nprocs_pof2_ = {} " @@ -404,12 +406,6 @@ void Rabenseifner::scatterReduceIter(size_t id) { using DataHelperT = DataHelper; auto& state = getState( collection_proxy_, objgroup_proxy_, group_, id); - vt_debug_print( - terse, allreduce, - "Rabenseifner Scatter (Send step {}): s_index_.size() = {} and " - "s_count_.size() = {} ID = {} proxy_={} state = {}\n", - state.scatter_step_, state.s_index_.size(), state.s_count_.size(), id, - proxy_.getProxy(), print_ptr(&state)); auto vdest = vrt_node_ ^ state.scatter_mask_; auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_; diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.cc b/src/vt/collective/reduce/allreduce/recursive_doubling.cc index ef9e787b31..2a7b2778da 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.cc +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.cc @@ -76,7 +76,6 @@ RecursiveDoubling::RecursiveDoubling( RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup) : objgroup_proxy_(objgroup.get()), - local_num_elems_(1), num_nodes_(theContext()->getNumNodes()), this_node_(vt::theContext()->getNode()), is_even_(this_node_ % 2 == 0), @@ -94,7 +93,6 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup) RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group) : group_(group.get()), - local_num_elems_(1), nodes_(theGroup()->GetGroupNodes(group_)), num_nodes_(nodes_.size()), this_node_(vt::theContext()->getNode()), diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.h b/src/vt/collective/reduce/allreduce/recursive_doubling.h index f5fb8591ad..c0690e2602 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.h @@ -245,11 +245,13 @@ struct RecursiveDoubling { vt::objgroup::proxy::Proxy proxy_ = {}; +private: + VirtualProxyType collection_proxy_ = u64empty; ObjGroupProxyType objgroup_proxy_ = u64empty; GroupType group_ = u64empty; - size_t local_num_elems_ = {}; + size_t local_num_elems_ = 1; std::vector nodes_ = {}; NodeType num_nodes_ = {}; diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h index a198926cf8..89e4439591 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.impl.h @@ -101,6 +101,7 @@ void RecursiveDoubling::initialize(size_t id, Args&&... data) { state.val_ = DataT{std::forward(data)...}; state.value_assigned_ = true; + state.active_ = true; vt_debug_print( terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n", diff --git a/src/vt/collective/reduce/allreduce/state.h b/src/vt/collective/reduce/allreduce/state.h index 8d68d004d5..7f0376fc04 100644 --- a/src/vt/collective/reduce/allreduce/state.h +++ b/src/vt/collective/reduce/allreduce/state.h @@ -60,6 +60,7 @@ struct StateBase { int32_t step_ = 0; bool initialized_ = false; bool completed_ = false; + bool active_ = false; }; struct RabensiferBase : StateBase { diff --git a/src/vt/collective/reduce/allreduce/state_holder.h b/src/vt/collective/reduce/allreduce/state_holder.h index 68478480cf..161829935c 100644 --- a/src/vt/collective/reduce/allreduce/state_holder.h +++ b/src/vt/collective/reduce/allreduce/state_holder.h @@ -82,16 +82,64 @@ struct StateHolder { template static size_t getNextID(detail::StrongVrtProxy proxy) { - return active_coll_states_[proxy.get()].size(); + size_t id = 0; + auto& allreducers = active_coll_states_[proxy.get()]; + + if (not allreducers.empty()) { + + // Last element is invalidated (allreduce completed) or not completed + // Generate new ID + if(not allreducers.back() or allreducers.back()->active_) { + id = allreducers.size(); + } + // Most recent state is not active, don't generate new ID + else if(not allreducers.back()->active_){ + id = allreducers.size() - 1; + } + } + + return id; } template static size_t getNextID(detail::StrongObjGroup proxy) { - return active_obj_states_[proxy.get()].size(); + size_t id = 0; + auto& allreducers = active_obj_states_[proxy.get()]; + + if (not allreducers.empty()) { + + // Last element is invalidated (allreduce completed) or not completed + // Generate new ID + if(not allreducers.back() or allreducers.back()->active_) { + id = allreducers.size(); + } + // Most recent state is not active, don't generate new ID + else if(not allreducers.back()->active_){ + id = allreducers.size() - 1; + } + } + + return id; } static size_t getNextID(detail::StrongGroup group) { - return active_grp_states_[group.get()].size(); + size_t id = 0; + auto& allreducers = active_grp_states_[group.get()]; + + if (not allreducers.empty()) { + + // Last element is invalidated (allreduce completed) or not completed + // Generate new ID + if(not allreducers.back() or allreducers.back()->active_) { + id = allreducers.size(); + } + // Most recent state is not active, don't generate new ID + else if(not allreducers.back()->active_){ + id = allreducers.size() - 1; + } + } + + return id; } static void clearSingle(detail::StrongVrtProxy proxy, size_t idx) { diff --git a/tests/unit/objgroup/test_objgroup.cc b/tests/unit/objgroup/test_objgroup.cc index b9cb30d14b..b95e443794 100644 --- a/tests/unit/objgroup/test_objgroup.cc +++ b/tests/unit/objgroup/test_objgroup.cc @@ -365,7 +365,7 @@ TEST_F(TestObjGroupKokkos, test_proxy_allreduce_kokkos) { } kokkos_proxy.allreduce< - &MyObjA::verifyAllredView, PlusOp, reduce::allreduce::RabenseifnerT>( + &MyObjA::verifyAllredView, PlusOp, reduce::allreduce::RabenseifnerT>( view); }); diff --git a/tests/unit/objgroup/test_objgroup_common.h b/tests/unit/objgroup/test_objgroup_common.h index 80921b9111..09cead4be4 100644 --- a/tests/unit/objgroup/test_objgroup_common.h +++ b/tests/unit/objgroup/test_objgroup_common.h @@ -47,6 +47,7 @@ #include "test_parallel_harness.h" #include "vt/collective/reduce/operators/default_msg.h" #include "vt/collective/reduce/allreduce/data_handler.h" +#include "vt/utils/kokkos/exec_space.h" #include #include @@ -151,13 +152,17 @@ struct MyObjA { } #if MAGISTRATE_KOKKOS_ENABLED - void verifyAllredView(Kokkos::View view) { + template + void verifyAllredView(Kokkos::View view) { auto final_size = view.extent(0); EXPECT_EQ(final_size, 256); auto n = vt::theContext()->getNumNodes(); auto const total_sum = n * (n - 1) / 2; - Kokkos::parallel_for("InitView", view.extent(0), KOKKOS_LAMBDA(const int i) { + using ExecSpace = typename utils::kokkos::AssociatedExecSpace::type; + + Kokkos::RangePolicy policy(0, view.extent(0)); + Kokkos::parallel_for("InitView", policy, KOKKOS_LAMBDA(const int i) { EXPECT_EQ(view(i), total_sum); });