Skip to content

Commit

Permalink
#2281: Initial work for AllreduceHolder
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 20, 2024
1 parent 852e53c commit ea24fc1
Show file tree
Hide file tree
Showing 10 changed files with 351 additions and 82 deletions.
88 changes: 88 additions & 0 deletions src/vt/collective/reduce/allreduce/allreduce_holder.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "allreduce_holder.h"
#include "vt/objgroup/manager.h"

namespace vt::collective::reduce::allreduce {

objgroup::proxy::Proxy<Rabenseifner> AllreduceHolder::addRabensifnerAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems) {
auto const coll_proxy = strong_proxy.get();

auto obj_proxy = theObjGroup()->makeCollective<Rabenseifner>(
"rabenseifer_allreducer", strong_proxy, strong_group, num_elems);

col_reducers_[coll_proxy].first = obj_proxy.getProxy();

fmt::print(
"Adding new Rabenseifner reducer for collection={:x}\n", coll_proxy);

return obj_proxy;
}

objgroup::proxy::Proxy<RecursiveDoubling>
AllreduceHolder::addRecursiveDoublingAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems) {
auto const coll_proxy = strong_proxy.get();
auto obj_proxy = theObjGroup()->makeCollective<RecursiveDoubling>(
"recursive_doubling_allreducer", strong_proxy, strong_group, num_elems);

col_reducers_[coll_proxy].second = obj_proxy.getProxy();
fmt::print(
"Adding new RecursiveDoubling reducer for collection={:x}\n", coll_proxy);

return obj_proxy;
}

objgroup::proxy::Proxy<Rabenseifner>
AllreduceHolder::addRabensifnerAllreducer(detail::StrongGroup strong_group) {
auto const group = strong_group.get();

auto obj_proxy = theObjGroup()->makeCollective<Rabenseifner>(
"rabenseifer_allreducer", strong_group);

group_reducers_[group].first = obj_proxy.getProxy();

fmt::print(
"Adding new Rabenseifner reducer for group={:x} Size={}\n", group,
group_reducers_.size());

return obj_proxy;
}

objgroup::proxy::Proxy<RecursiveDoubling>
AllreduceHolder::addRecursiveDoublingAllreducer(
detail::StrongGroup strong_group) {
auto const group = strong_group.get();

auto obj_proxy = theObjGroup()->makeCollective<RecursiveDoubling>(
"recursive_doubling_allreducer", strong_group);

fmt::print("Adding new RecursiveDoubling reducer for group={:x}\n", group);

group_reducers_[group].second = obj_proxy.getProxy();

return obj_proxy;
}

void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) {
auto const key = strong_proxy.get();

auto it = col_reducers_.find(key);

if (it != col_reducers_.end()) {
col_reducers_.erase(key);
}
}

void AllreduceHolder::remove(detail::StrongGroup strong_group) {
auto const key = strong_group.get();

auto it = group_reducers_.find(key);

if (it != group_reducers_.end()) {
group_reducers_.erase(key);
}
}

} // namespace vt::collective::reduce::allreduce
151 changes: 151 additions & 0 deletions src/vt/collective/reduce/allreduce/allreduce_holder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
//@HEADER
// *****************************************************************************
//
// allreduce_holder.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H

#include "vt/configs/types/types_type.h"
#include "vt/collective/reduce/allreduce/type.h"
#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/recursive_doubling.h"
#include "vt/configs/types/types_sentinels.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"

#include <type_traits>
#include <unordered_map>

namespace vt::collective::reduce::allreduce {

struct AllreduceHolder {
using RabenseifnerProxy = ObjGroupProxyType;
using RecursiveDoublingProxy = ObjGroupProxyType;

template <typename ReducerT>
static auto getAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems) {
auto const coll_proxy = strong_proxy.get();

if (col_reducers_.find(coll_proxy) == col_reducers_.end()) {
col_reducers_[coll_proxy] = {u64empty, u64empty};
}

if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
auto untyped_proxy = col_reducers_.at(coll_proxy).first;
if (untyped_proxy == u64empty) {
return addRabensifnerAllreducer(strong_proxy, strong_group, num_elems);
} else {
return static_cast<vt::objgroup::proxy::Proxy<Rabenseifner>>(
untyped_proxy);
}
} else {
auto untyped_proxy = col_reducers_.at(coll_proxy).second;
if (untyped_proxy == u64empty) {
return addRecursiveDoublingAllreducer(
strong_proxy, strong_group, num_elems);
} else {
return static_cast<vt::objgroup::proxy::Proxy<RecursiveDoubling>>(
untyped_proxy);
}
}
}

template <typename ReducerT>
static auto getAllreducer(detail::StrongGroup strong_group) {
auto const group = strong_group.get();

if (auto it = group_reducers_.find(group); it == group_reducers_.end()) {
group_reducers_[group] = {u64empty, u64empty};
}

if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
auto untyped_proxy = group_reducers_.at(group).first;
if (untyped_proxy == u64empty) {
return addRabensifnerAllreducer(strong_group);
} else {
return static_cast<vt::objgroup::proxy::Proxy<Rabenseifner>>(
untyped_proxy);
}
} else {
auto untyped_proxy = group_reducers_.at(group).second;
if (untyped_proxy == u64empty) {
return addRecursiveDoublingAllreducer(strong_group);
} else {
return static_cast<vt::objgroup::proxy::Proxy<RecursiveDoubling>>(
untyped_proxy);
}
}
}

static objgroup::proxy::Proxy<Rabenseifner> addRabensifnerAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems);

static objgroup::proxy::Proxy<RecursiveDoubling>
addRecursiveDoublingAllreducer(
detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group,
size_t num_elems);

static objgroup::proxy::Proxy<Rabenseifner>
addRabensifnerAllreducer(detail::StrongGroup strong_group);
static objgroup::proxy::Proxy<RecursiveDoubling>
addRecursiveDoublingAllreducer(detail::StrongGroup strong_group);

static void remove(detail::StrongVrtProxy strong_proxy);
static void remove(detail::StrongGroup strong_group);

static inline std::unordered_map<
VirtualProxyType, std::pair<RabenseifnerProxy, RecursiveDoublingProxy>>
col_reducers_ = {};
static inline std::unordered_map<
GroupType, std::pair<RabenseifnerProxy, RecursiveDoublingProxy>>
group_reducers_ = {};
static inline std::unordered_map<
ObjGroupProxyType, std::pair<RabenseifnerProxy, RecursiveDoublingProxy>>
objgroup_reducers_ = {};
};

} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H*/
4 changes: 3 additions & 1 deletion src/vt/collective/reduce/allreduce/rabenseifner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@

#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/configs/error/config_assert.h"
#include "vt/group/group_manager.h"

namespace vt::collective::reduce::allreduce {

Expand Down Expand Up @@ -166,7 +167,8 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup)

Rabenseifner::~Rabenseifner() {
if (collection_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});
// StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});

} else if (objgroup_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_});
} else {
Expand Down
2 changes: 0 additions & 2 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
#include "vt/context/context.h"
#include "vt/configs/error/config_assert.h"
#include "vt/group/global/group_default.h"
#include "vt/group/group_manager.h"
#include "vt/group/group_info.h"
#include "vt/configs/types/types_sentinels.h"
#include "vt/registry/auto/auto_registry.h"
#include "vt/utils/fntraits/fntraits.h"
Expand Down
2 changes: 1 addition & 1 deletion src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group)

RecursiveDoubling::~RecursiveDoubling() {
if (collection_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});
// StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_});
} else if (objgroup_proxy_ != u64empty) {
StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_});
} else {
Expand Down
13 changes: 9 additions & 4 deletions src/vt/group/group_manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@
#include "vt/messaging/active.h"
#include "vt/activefn/activefn.h"
#include "vt/group/group_info.h"
#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/allreduce_holder.h"
#include "vt/objgroup/manager.h"
#include "vt/pipe/pipe_manager.impl.h"
#include "vt/collective/reduce/allreduce/type.h"

namespace vt { namespace group {

Expand Down Expand Up @@ -167,10 +168,10 @@ void GroupManager::allreduce(GroupType group, Args&&... args) {

using DataT = std::tuple_element_t<0, typename FuncTraits<decltype(f)>::TupleType>;

using Reducer = Rabenseifner;
// using Reducer = Rabenseifner;
auto const strong_group = collective::reduce::detail::StrongGroup{group};
// TODO; Save the proxy so it can be deleted afterwards
auto proxy = theObjGroup()->makeCollective<Reducer>("reducer", strong_group);
auto proxy =
AllreduceHolder::getAllreducer<RabenseifnerT>(strong_group);

if (iter->second->is_in_group) {
auto const this_node = theContext()->getNode();
Expand All @@ -180,6 +181,10 @@ void GroupManager::allreduce(GroupType group, Args&&... args) {
ptr->template setFinalHandler<DataT>(theCB()->makeSend<f>(this_node), id);
ptr->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}

addCleanupAction([strong_group] {
AllreduceHolder::remove(strong_group);
});
}

}} /* end namespace vt::group */
Expand Down
1 change: 1 addition & 0 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
#include "vt/collective/reduce/allreduce/helpers.h"
#include "vt/collective/reduce/scoping/strong_types.h"
#include "vt/collective/reduce/allreduce/state_holder.h"
#include "vt/pipe/pipe_manager.h"

#include <utility>
#include <array>
Expand Down
9 changes: 9 additions & 0 deletions src/vt/vrt/collection/holders/typeless_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

#include "vt/vrt/collection/holders/typeless_holder.h"
#include "vt/scheduler/scheduler.h"
#include "vt/collective/reduce/allreduce/state_holder.h"
#include "vt/collective/reduce/allreduce/allreduce_holder.h"

namespace vt { namespace vrt { namespace collection {

Expand Down Expand Up @@ -73,6 +75,13 @@ void TypelessHolder::destroyCollection(VirtualProxyType const proxy) {
labels_.erase(iter);
}
}

vt::collective::reduce::allreduce::StateHolder::clearAll(
vt::collective::reduce::detail::StrongVrtProxy{proxy});

vt::collective::reduce::allreduce::AllreduceHolder::remove(
vt::collective::reduce::detail::StrongVrtProxy{proxy}
);
}

void TypelessHolder::invokeAllGroupConstructors() {
Expand Down
Loading

0 comments on commit ea24fc1

Please sign in to comment.