diff --git a/src/vt/collective/reduce/allreduce/allreduce_holder.cc b/src/vt/collective/reduce/allreduce/allreduce_holder.cc new file mode 100644 index 0000000000..3da7b688ef --- /dev/null +++ b/src/vt/collective/reduce/allreduce/allreduce_holder.cc @@ -0,0 +1,88 @@ +#include "allreduce_holder.h" +#include "vt/objgroup/manager.h" + +namespace vt::collective::reduce::allreduce { + +objgroup::proxy::Proxy AllreduceHolder::addRabensifnerAllreducer( + detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, + size_t num_elems) { + auto const coll_proxy = strong_proxy.get(); + + auto obj_proxy = theObjGroup()->makeCollective( + "rabenseifer_allreducer", strong_proxy, strong_group, num_elems); + + col_reducers_[coll_proxy].first = obj_proxy.getProxy(); + + fmt::print( + "Adding new Rabenseifner reducer for collection={:x}\n", coll_proxy); + + return obj_proxy; +} + +objgroup::proxy::Proxy +AllreduceHolder::addRecursiveDoublingAllreducer( + detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, + size_t num_elems) { + auto const coll_proxy = strong_proxy.get(); + auto obj_proxy = theObjGroup()->makeCollective( + "recursive_doubling_allreducer", strong_proxy, strong_group, num_elems); + + col_reducers_[coll_proxy].second = obj_proxy.getProxy(); + fmt::print( + "Adding new RecursiveDoubling reducer for collection={:x}\n", coll_proxy); + + return obj_proxy; +} + +objgroup::proxy::Proxy +AllreduceHolder::addRabensifnerAllreducer(detail::StrongGroup strong_group) { + auto const group = strong_group.get(); + + auto obj_proxy = theObjGroup()->makeCollective( + "rabenseifer_allreducer", strong_group); + + group_reducers_[group].first = obj_proxy.getProxy(); + + fmt::print( + "Adding new Rabenseifner reducer for group={:x} Size={}\n", group, + group_reducers_.size()); + + return obj_proxy; +} + +objgroup::proxy::Proxy +AllreduceHolder::addRecursiveDoublingAllreducer( + detail::StrongGroup strong_group) { + auto const group = strong_group.get(); + + auto obj_proxy = theObjGroup()->makeCollective( + "recursive_doubling_allreducer", strong_group); + + fmt::print("Adding new RecursiveDoubling reducer for group={:x}\n", group); + + group_reducers_[group].second = obj_proxy.getProxy(); + + return obj_proxy; +} + +void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) { + auto const key = strong_proxy.get(); + + auto it = col_reducers_.find(key); + + if (it != col_reducers_.end()) { + col_reducers_.erase(key); + } +} + +void AllreduceHolder::remove(detail::StrongGroup strong_group) { + auto const key = strong_group.get(); + + auto it = group_reducers_.find(key); + + if (it != group_reducers_.end()) { + group_reducers_.erase(key); + } +} + +} // namespace vt::collective::reduce::allreduce diff --git a/src/vt/collective/reduce/allreduce/allreduce_holder.h b/src/vt/collective/reduce/allreduce/allreduce_holder.h new file mode 100644 index 0000000000..b60a232a0c --- /dev/null +++ b/src/vt/collective/reduce/allreduce/allreduce_holder.h @@ -0,0 +1,151 @@ +/* +//@HEADER +// ***************************************************************************** +// +// allreduce_holder.h +// DARMA/vt => Virtual Transport +// +// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC +// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S. +// Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// * Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from this +// software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact darma@sandia.gov +// +// ***************************************************************************** +//@HEADER +*/ + +#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H +#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H + +#include "vt/configs/types/types_type.h" +#include "vt/collective/reduce/allreduce/type.h" +#include "vt/collective/reduce/scoping/strong_types.h" +#include "vt/collective/reduce/allreduce/rabenseifner.h" +#include "vt/collective/reduce/allreduce/recursive_doubling.h" +#include "vt/configs/types/types_sentinels.h" +#include "vt/objgroup/proxy/proxy_objgroup.h" + +#include +#include + +namespace vt::collective::reduce::allreduce { + +struct AllreduceHolder { + using RabenseifnerProxy = ObjGroupProxyType; + using RecursiveDoublingProxy = ObjGroupProxyType; + + template + static auto getAllreducer( + detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, + size_t num_elems) { + auto const coll_proxy = strong_proxy.get(); + + if (col_reducers_.find(coll_proxy) == col_reducers_.end()) { + col_reducers_[coll_proxy] = {u64empty, u64empty}; + } + + if constexpr (std::is_same_v) { + auto untyped_proxy = col_reducers_.at(coll_proxy).first; + if (untyped_proxy == u64empty) { + return addRabensifnerAllreducer(strong_proxy, strong_group, num_elems); + } else { + return static_cast>( + untyped_proxy); + } + } else { + auto untyped_proxy = col_reducers_.at(coll_proxy).second; + if (untyped_proxy == u64empty) { + return addRecursiveDoublingAllreducer( + strong_proxy, strong_group, num_elems); + } else { + return static_cast>( + untyped_proxy); + } + } + } + + template + static auto getAllreducer(detail::StrongGroup strong_group) { + auto const group = strong_group.get(); + + if (auto it = group_reducers_.find(group); it == group_reducers_.end()) { + group_reducers_[group] = {u64empty, u64empty}; + } + + if constexpr (std::is_same_v) { + auto untyped_proxy = group_reducers_.at(group).first; + if (untyped_proxy == u64empty) { + return addRabensifnerAllreducer(strong_group); + } else { + return static_cast>( + untyped_proxy); + } + } else { + auto untyped_proxy = group_reducers_.at(group).second; + if (untyped_proxy == u64empty) { + return addRecursiveDoublingAllreducer(strong_group); + } else { + return static_cast>( + untyped_proxy); + } + } + } + + static objgroup::proxy::Proxy addRabensifnerAllreducer( + detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, + size_t num_elems); + + static objgroup::proxy::Proxy + addRecursiveDoublingAllreducer( + detail::StrongVrtProxy strong_proxy, detail::StrongGroup strong_group, + size_t num_elems); + + static objgroup::proxy::Proxy + addRabensifnerAllreducer(detail::StrongGroup strong_group); + static objgroup::proxy::Proxy + addRecursiveDoublingAllreducer(detail::StrongGroup strong_group); + + static void remove(detail::StrongVrtProxy strong_proxy); + static void remove(detail::StrongGroup strong_group); + + static inline std::unordered_map< + VirtualProxyType, std::pair> + col_reducers_ = {}; + static inline std::unordered_map< + GroupType, std::pair> + group_reducers_ = {}; + static inline std::unordered_map< + ObjGroupProxyType, std::pair> + objgroup_reducers_ = {}; +}; + +} // namespace vt::collective::reduce::allreduce + +#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_HOLDER_H*/ diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.cc b/src/vt/collective/reduce/allreduce/rabenseifner.cc index b4680d76cb..9ab4443e31 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.cc +++ b/src/vt/collective/reduce/allreduce/rabenseifner.cc @@ -43,6 +43,7 @@ #include "vt/collective/reduce/allreduce/rabenseifner.h" #include "vt/configs/error/config_assert.h" +#include "vt/group/group_manager.h" namespace vt::collective::reduce::allreduce { @@ -166,7 +167,8 @@ Rabenseifner::Rabenseifner(detail::StrongObjGroup objgroup) Rabenseifner::~Rabenseifner() { if (collection_proxy_ != u64empty) { - StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); + // StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); + } else if (objgroup_proxy_ != u64empty) { StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_}); } else { diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h index dd22cce1ed..f8194b5965 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h @@ -49,8 +49,6 @@ #include "vt/context/context.h" #include "vt/configs/error/config_assert.h" #include "vt/group/global/group_default.h" -#include "vt/group/group_manager.h" -#include "vt/group/group_info.h" #include "vt/configs/types/types_sentinels.h" #include "vt/registry/auto/auto_registry.h" #include "vt/utils/fntraits/fntraits.h" diff --git a/src/vt/collective/reduce/allreduce/recursive_doubling.cc b/src/vt/collective/reduce/allreduce/recursive_doubling.cc index 9bb7f358ed..ae6fcd5029 100644 --- a/src/vt/collective/reduce/allreduce/recursive_doubling.cc +++ b/src/vt/collective/reduce/allreduce/recursive_doubling.cc @@ -129,7 +129,7 @@ RecursiveDoubling::RecursiveDoubling(detail::StrongGroup group) RecursiveDoubling::~RecursiveDoubling() { if (collection_proxy_ != u64empty) { - StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); + // StateHolder::clearAll(detail::StrongVrtProxy{collection_proxy_}); } else if (objgroup_proxy_ != u64empty) { StateHolder::clearAll(detail::StrongObjGroup{objgroup_proxy_}); } else { diff --git a/src/vt/group/group_manager.impl.h b/src/vt/group/group_manager.impl.h index e67e743b97..78dfe91f5b 100644 --- a/src/vt/group/group_manager.impl.h +++ b/src/vt/group/group_manager.impl.h @@ -55,9 +55,10 @@ #include "vt/messaging/active.h" #include "vt/activefn/activefn.h" #include "vt/group/group_info.h" -#include "vt/collective/reduce/allreduce/rabenseifner.h" +#include "vt/collective/reduce/allreduce/allreduce_holder.h" #include "vt/objgroup/manager.h" #include "vt/pipe/pipe_manager.impl.h" +#include "vt/collective/reduce/allreduce/type.h" namespace vt { namespace group { @@ -167,10 +168,10 @@ void GroupManager::allreduce(GroupType group, Args&&... args) { using DataT = std::tuple_element_t<0, typename FuncTraits::TupleType>; - using Reducer = Rabenseifner; + // using Reducer = Rabenseifner; auto const strong_group = collective::reduce::detail::StrongGroup{group}; - // TODO; Save the proxy so it can be deleted afterwards - auto proxy = theObjGroup()->makeCollective("reducer", strong_group); + auto proxy = + AllreduceHolder::getAllreducer(strong_group); if (iter->second->is_in_group) { auto const this_node = theContext()->getNode(); @@ -180,6 +181,10 @@ void GroupManager::allreduce(GroupType group, Args&&... args) { ptr->template setFinalHandler(theCB()->makeSend(this_node), id); ptr->template localReduce(id, std::forward(args)...); } + + addCleanupAction([strong_group] { + AllreduceHolder::remove(strong_group); + }); } }} /* end namespace vt::group */ diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h index b9dd281935..6833f06095 100644 --- a/src/vt/objgroup/manager.impl.h +++ b/src/vt/objgroup/manager.impl.h @@ -64,6 +64,7 @@ #include "vt/collective/reduce/allreduce/helpers.h" #include "vt/collective/reduce/scoping/strong_types.h" #include "vt/collective/reduce/allreduce/state_holder.h" +#include "vt/pipe/pipe_manager.h" #include #include diff --git a/src/vt/vrt/collection/holders/typeless_holder.cc b/src/vt/vrt/collection/holders/typeless_holder.cc index eb3f9f7565..6d483e25e9 100644 --- a/src/vt/vrt/collection/holders/typeless_holder.cc +++ b/src/vt/vrt/collection/holders/typeless_holder.cc @@ -43,6 +43,8 @@ #include "vt/vrt/collection/holders/typeless_holder.h" #include "vt/scheduler/scheduler.h" +#include "vt/collective/reduce/allreduce/state_holder.h" +#include "vt/collective/reduce/allreduce/allreduce_holder.h" namespace vt { namespace vrt { namespace collection { @@ -73,6 +75,13 @@ void TypelessHolder::destroyCollection(VirtualProxyType const proxy) { labels_.erase(iter); } } + + vt::collective::reduce::allreduce::StateHolder::clearAll( + vt::collective::reduce::detail::StrongVrtProxy{proxy}); + + vt::collective::reduce::allreduce::AllreduceHolder::remove( + vt::collective::reduce::detail::StrongVrtProxy{proxy} + ); } void TypelessHolder::invokeAllGroupConstructors() { diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index da69d64ef4..1f5dc87c88 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -79,6 +79,7 @@ #include "vt/scheduler/scheduler.h" #include "vt/phase/phase_manager.h" #include "vt/runnable/make_runnable.h" +#include "vt/collective/reduce/allreduce/allreduce_holder.h" #include #include @@ -918,80 +919,94 @@ messaging::PendingSend CollectionManager::reduceLocal( auto cb = vt::theCB()->makeCallbackBcastCollectiveProxy(proxy); - if constexpr (std::is_same_v) { - using Reducer = collective::reduce::allreduce::Rabenseifner; - if (auto reducer = rabenseifner_reducers_.find(col_proxy); - reducer == rabenseifner_reducers_.end()) { - if (use_group) { - // theGroup()->allreduce(group, ); - } else { - vt_debug_print( - terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id); - auto obj_proxy = theObjGroup()->makeCollective( - "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy}, - collective::reduce::detail::StrongGroup{group}, num_elms); - - rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy(); - auto* obj = obj_proxy[theContext()->getNode()].get(); - obj->proxy_ = obj_proxy; - - obj->template setFinalHandler(cb, id); - obj->template localReduce(id, std::forward(args)...); - } - } else { - if (use_group) { - // theGroup()->allreduce(group, ); - } else { - vt_debug_print( - terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id); - auto obj_proxy = - reducer->second; // rabenseifner_reducers_.at(col_proxy); - auto typed_proxy = - static_cast>(obj_proxy); - auto* obj = typed_proxy[theContext()->getNode()].get(); - - obj->template setFinalHandler(cb, id); - obj->template localReduce(id, std::forward(args)...); - } - } + if (use_group) { + // theGroup()->allreduce(group, ); } else { - using Reducer = collective::reduce::allreduce::RecursiveDoubling; - if (auto reducer = recursive_doubling_reducers_.find(col_proxy); - reducer == recursive_doubling_reducers_.end()) { - if (use_group) { - // theGroup()->allreduce(group, ); - } else { - vt_debug_print( - terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id); - auto obj_proxy = theObjGroup()->makeCollective( - "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy}, - collective::reduce::detail::StrongGroup{group}, num_elms); - - recursive_doubling_reducers_[col_proxy] = obj_proxy.getProxy(); - auto* obj = obj_proxy[theContext()->getNode()].get(); - obj->proxy_ = obj_proxy; - - obj->template setFinalHandler(cb, id); - obj->template localReduce(id, std::forward(args)...); - } - } else { - if (use_group) { - // theGroup()->allreduce(group, ); - } else { - vt_debug_print( - terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id); - auto obj_proxy = - reducer->second; // rabenseifner_reducers_.at(col_proxy); - auto typed_proxy = - static_cast>(obj_proxy); - auto* obj = typed_proxy[theContext()->getNode()].get(); - - obj->template setFinalHandler(cb, id); - obj->template localReduce(id, std::forward(args)...); - } - } + auto obj_proxy = AllreduceHolder::getAllreducer( + collective::reduce::detail::StrongVrtProxy{col_proxy}, + collective::reduce::detail::StrongGroup{group}, num_elms); + + auto* obj = obj_proxy[theContext()->getNode()].get(); + obj->proxy_ = obj_proxy; + + obj->template setFinalHandler(cb, id); + obj->template localReduce(id, std::forward(args)...); } + // if constexpr (std::is_same_v) { + // using Reducer = collective::reduce::allreduce::Rabenseifner; + // if (auto reducer = rabenseifner_reducers_.find(col_proxy); + // reducer == rabenseifner_reducers_.end()) { + // if (use_group) { + // // theGroup()->allreduce(group, ); + // } else { + // vt_debug_print( + // terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id); + // auto obj_proxy = theObjGroup()->makeCollective( + // "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy}, + // collective::reduce::detail::StrongGroup{group}, num_elms); + + // rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy(); + // auto* obj = obj_proxy[theContext()->getNode()].get(); + // obj->proxy_ = obj_proxy; + + // obj->template setFinalHandler(cb, id); + // obj->template localReduce(id, std::forward(args)...); + // } + // } else { + // if (use_group) { + // // theGroup()->allreduce(group, ); + // } else { + // vt_debug_print( + // terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id); + // auto obj_proxy = + // reducer->second; // rabenseifner_reducers_.at(col_proxy); + // auto typed_proxy = + // static_cast>(obj_proxy); + // auto* obj = typed_proxy[theContext()->getNode()].get(); + + // obj->template setFinalHandler(cb, id); + // obj->template localReduce(id, std::forward(args)...); + // } + // } + // } else { + // using Reducer = collective::reduce::allreduce::RecursiveDoubling; + // if (auto reducer = recursive_doubling_reducers_.find(col_proxy); + // reducer == recursive_doubling_reducers_.end()) { + // if (use_group) { + // // theGroup()->allreduce(group, ); + // } else { + // vt_debug_print( + // terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id); + // auto obj_proxy = theObjGroup()->makeCollective( + // "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy}, + // collective::reduce::detail::StrongGroup{group}, num_elms); + + // recursive_doubling_reducers_[col_proxy] = obj_proxy.getProxy(); + // auto* obj = obj_proxy[theContext()->getNode()].get(); + // obj->proxy_ = obj_proxy; + + // obj->template setFinalHandler(cb, id); + // obj->template localReduce(id, std::forward(args)...); + // } + // } else { + // if (use_group) { + // // theGroup()->allreduce(group, ); + // } else { + // vt_debug_print( + // terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id); + // auto obj_proxy = + // reducer->second; // rabenseifner_reducers_.at(col_proxy); + // auto typed_proxy = + // static_cast>(obj_proxy); + // auto* obj = typed_proxy[theContext()->getNode()].get(); + + // obj->template setFinalHandler(cb, id); + // obj->template localReduce(id, std::forward(args)...); + // } + // } + // } + return messaging::PendingSend{nullptr}; } diff --git a/tests/unit/objgroup/test_objgroup.cc b/tests/unit/objgroup/test_objgroup.cc index 3d5592992e..2b14c375b6 100644 --- a/tests/unit/objgroup/test_objgroup.cc +++ b/tests/unit/objgroup/test_objgroup.cc @@ -360,9 +360,9 @@ TEST_F(TestObjGroupKokkos, test_proxy_allreduce_kokkos) { runInEpochCollective([&] { Kokkos::View view("view", 256); - Kokkos::parallel_for( - "InitView", view.extent(0), - KOKKOS_LAMBDA(const int i) { view(i) = static_cast(my_node); }); + for (uint32_t i = 0; i < view.extent(0); i++) { + view(i) = static_cast(my_node); + } kokkos_proxy.allreduce< &MyObjA::verifyAllredView, PlusOp, reduce::allreduce::RabenseifnerT>(