Skip to content

Commit

Permalink
#2281: Fixed runtime issues with StateHolder generating extra ID
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 22, 2024
1 parent 455829a commit 08113bb
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 17 deletions.
12 changes: 4 additions & 8 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ void Rabenseifner::localReduce(size_t id, Args&&... data) {

vt_debug_print(
terse, allreduce,
"Rabenseifner (this={}): local_col_wait_count_={} ID={} initialized={}\n",
print_ptr(this), state.local_col_wait_count_, id, state.initialized_);
"Rabenseifner(ID = {}) localReduce (this={}): local_col_wait_count_={} initialized={}\n",
id, print_ptr(this), state.local_col_wait_count_, state.initialized_);


if (DataHelperT::empty(state.val_)) {
Expand Down Expand Up @@ -176,6 +176,8 @@ void Rabenseifner::initialize(size_t id, Args&&... data) {
}
}

state.active_ = true;

vt_debug_print(
terse, allreduce,
"Rabenseifner initialize: size_ = {} num_steps_ = {} nprocs_pof2_ = {} "
Expand Down Expand Up @@ -404,12 +406,6 @@ void Rabenseifner::scatterReduceIter(size_t id) {
using DataHelperT = DataHelper<Scalar, DataT>;
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
vt_debug_print(
terse, allreduce,
"Rabenseifner Scatter (Send step {}): s_index_.size() = {} and "
"s_count_.size() = {} ID = {} proxy_={} state = {}\n",
state.scatter_step_, state.s_index_.size(), state.s_count_.size(), id,
proxy_.getProxy(), print_ptr(&state));

auto vdest = vrt_node_ ^ state.scatter_mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;
Expand Down
2 changes: 0 additions & 2 deletions src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ RecursiveDoubling::RecursiveDoubling(

RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup)
: objgroup_proxy_(objgroup.get()),
local_num_elems_(1),
num_nodes_(theContext()->getNumNodes()),
this_node_(vt::theContext()->getNode()),
is_even_(this_node_ % 2 == 0),
Expand All @@ -94,7 +93,6 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongObjGroup objgroup)

RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group)
: group_(group.get()),
local_num_elems_(1),
nodes_(theGroup()->GetGroupNodes(group_)),
num_nodes_(nodes_.size()),
this_node_(vt::theContext()->getNode()),
Expand Down
4 changes: 3 additions & 1 deletion src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,13 @@ struct RecursiveDoubling {

vt::objgroup::proxy::Proxy<RecursiveDoubling> proxy_ = {};

private:

VirtualProxyType collection_proxy_ = u64empty;
ObjGroupProxyType objgroup_proxy_ = u64empty;
GroupType group_ = u64empty;

size_t local_num_elems_ = {};
size_t local_num_elems_ = 1;

std::vector<NodeType> nodes_ = {};
NodeType num_nodes_ = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ void RecursiveDoubling::initialize(size_t id, Args&&... data) {

state.val_ = DataT{std::forward<Args>(data)...};
state.value_assigned_ = true;
state.active_ = true;

vt_debug_print(
terse, allreduce, "RecursiveDoubling Initialize: size {} ID {}\n",
Expand Down
1 change: 1 addition & 0 deletions src/vt/collective/reduce/allreduce/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct StateBase {
int32_t step_ = 0;
bool initialized_ = false;
bool completed_ = false;
bool active_ = false;
};

struct RabensiferBase : StateBase {
Expand Down
54 changes: 51 additions & 3 deletions src/vt/collective/reduce/allreduce/state_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,64 @@ struct StateHolder {

template <typename ReducerT>
static size_t getNextID(detail::StrongVrtProxy proxy) {
return active_coll_states_[proxy.get()].size();
size_t id = 0;
auto& allreducers = active_coll_states_[proxy.get()];

if (not allreducers.empty()) {

// Last element is invalidated (allreduce completed) or not completed
// Generate new ID
if(not allreducers.back() or allreducers.back()->active_) {
id = allreducers.size();
}
// Most recent state is not active, don't generate new ID
else if(not allreducers.back()->active_){
id = allreducers.size() - 1;
}
}

return id;
}

template <typename ReducerT>
static size_t getNextID(detail::StrongObjGroup proxy) {
return active_obj_states_[proxy.get()].size();
size_t id = 0;
auto& allreducers = active_obj_states_[proxy.get()];

if (not allreducers.empty()) {

// Last element is invalidated (allreduce completed) or not completed
// Generate new ID
if(not allreducers.back() or allreducers.back()->active_) {
id = allreducers.size();
}
// Most recent state is not active, don't generate new ID
else if(not allreducers.back()->active_){
id = allreducers.size() - 1;
}
}

return id;
}

static size_t getNextID(detail::StrongGroup group) {
return active_grp_states_[group.get()].size();
size_t id = 0;
auto& allreducers = active_grp_states_[group.get()];

if (not allreducers.empty()) {

// Last element is invalidated (allreduce completed) or not completed
// Generate new ID
if(not allreducers.back() or allreducers.back()->active_) {
id = allreducers.size();
}
// Most recent state is not active, don't generate new ID
else if(not allreducers.back()->active_){
id = allreducers.size() - 1;
}
}

return id;
}

static void clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/objgroup/test_objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ TEST_F(TestObjGroupKokkos, test_proxy_allreduce_kokkos) {
}

kokkos_proxy.allreduce<
&MyObjA::verifyAllredView, PlusOp, reduce::allreduce::RabenseifnerT>(
&MyObjA::verifyAllredView<Kokkos::HostSpace>, PlusOp, reduce::allreduce::RabenseifnerT>(
view);
});

Expand Down
9 changes: 7 additions & 2 deletions tests/unit/objgroup/test_objgroup_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "test_parallel_harness.h"
#include "vt/collective/reduce/operators/default_msg.h"
#include "vt/collective/reduce/allreduce/data_handler.h"
#include "vt/utils/kokkos/exec_space.h"

#include <numeric>
#include <vector>
Expand Down Expand Up @@ -151,13 +152,17 @@ struct MyObjA {
}

#if MAGISTRATE_KOKKOS_ENABLED
void verifyAllredView(Kokkos::View<float*, Kokkos::HostSpace> view) {
template <typename MemorySpace>
void verifyAllredView(Kokkos::View<float*, MemorySpace> view) {
auto final_size = view.extent(0);
EXPECT_EQ(final_size, 256);

auto n = vt::theContext()->getNumNodes();
auto const total_sum = n * (n - 1) / 2;
Kokkos::parallel_for("InitView", view.extent(0), KOKKOS_LAMBDA(const int i) {
using ExecSpace = typename utils::kokkos::AssociatedExecSpace<MemorySpace>::type;

Kokkos::RangePolicy<ExecSpace> policy(0, view.extent(0));
Kokkos::parallel_for("InitView", policy, KOKKOS_LAMBDA(const int i) {
EXPECT_EQ(view(i), total_sum);
});

Expand Down

0 comments on commit 08113bb

Please sign in to comment.