diff --git a/core/src/Cuda/Kokkos_Cuda_Team.hpp b/core/src/Cuda/Kokkos_Cuda_Team.hpp
index c2b5f1fa789..ea6430ee013 100644
--- a/core/src/Cuda/Kokkos_Cuda_Team.hpp
+++ b/core/src/Cuda/Kokkos_Cuda_Team.hpp
@@ -189,19 +189,25 @@ class CudaTeamMember {
     team_reduce(reducer, reducer.reference());
   }
 
-  template <typename ReducerType>
-  KOKKOS_INLINE_FUNCTION std::enable_if_t<is_reducer<ReducerType>::value>
-  team_reduce(ReducerType const& reducer,
-              typename ReducerType::value_type& value) const noexcept {
+  template <typename ReducerType, typename ValueType>
+  KOKKOS_INLINE_FUNCTION void team_reduce(ReducerType const& reducer,
+                                          ValueType& value) const noexcept {
     (void)reducer;
     (void)value;
+
     KOKKOS_IF_ON_DEVICE(
-        (typename Impl::FunctorAnalysis<
+        (using functor_analysis_type = typename Impl::FunctorAnalysis<
             Impl::FunctorPatternInterface::REDUCE, TeamPolicy<Cuda>,
-             ReducerType, typename ReducerType::value_type>::Reducer
-             wrapped_reducer(reducer);
+            ReducerType, ValueType>;
+
+        constexpr bool is_reducer_functor =
+            functor_analysis_type::has_join_member_function &&
+            functor_analysis_type::has_init_member_function &&
+            !is_reducer_v<ReducerType>;
+        typename functor_analysis_type::Reducer wrapped_reducer(reducer);
+
         cuda_intra_block_reduction(value, wrapped_reducer, blockDim.y);
-        reducer.reference() = value;))
+        if constexpr (!is_reducer_functor) { reducer.reference() = value; }))
   }
 
   //--------------------------------------------------------------------------
@@ -265,39 +271,48 @@ class CudaTeamMember {
     vector_reduce(reducer, reducer.reference());
   }
 
-  template <typename ReducerType>
-  KOKKOS_INLINE_FUNCTION static std::enable_if_t<is_reducer<ReducerType>::value>
-  vector_reduce(ReducerType const& reducer,
-                typename ReducerType::value_type& value) {
+  template <typename ReducerType, typename ValueType>
+  KOKKOS_INLINE_FUNCTION static void vector_reduce(ReducerType const& reducer,
+                                                   ValueType& value) {
     (void)reducer;
     (void)value;
-    KOKKOS_IF_ON_DEVICE(
-        (if (blockDim.x == 1) return;
-
-         // Intra vector lane shuffle reduction:
-         typename ReducerType::value_type tmp(value);
-         typename ReducerType::value_type tmp2 = tmp;
-
-         unsigned mask =
-             blockDim.x == 32
-                 ? 0xffffffff
-                 : ((1 << blockDim.x) - 1)
-                       << ((threadIdx.y % (32 / blockDim.x)) * blockDim.x);
-
-         for (int i = blockDim.x; (i >>= 1);) {
-           Impl::in_place_shfl_down(tmp2, tmp, i, blockDim.x, mask);
-           if ((int)threadIdx.x < i) {
-             reducer.join(tmp, tmp2);
-           }
-         }
-
-         // Broadcast from root lane to all other lanes.
-         // Cannot use "butterfly" algorithm to avoid the broadcast
-         // because floating point summation is not associative
-         // and thus different threads could have different results.
-
-         Impl::in_place_shfl(tmp2, tmp, 0, blockDim.x, mask);
-         value = tmp2; reducer.reference() = tmp2;))
+    KOKKOS_IF_ON_DEVICE((
+        if (blockDim.x == 1) return;
+
+        // Intra vector lane shuffle reduction:
+        typename ReducerType::value_type tmp(value);
+        typename ReducerType::value_type tmp2 = tmp;
+
+        using functor_analysis_type = typename Impl::FunctorAnalysis<
+            Impl::FunctorPatternInterface::REDUCE, TeamPolicy<Cuda>,
+            ReducerType, ValueType>;
+
+        constexpr bool is_reducer_functor =
+            functor_analysis_type::has_join_member_function &&
+            functor_analysis_type::has_init_member_function &&
+            !is_reducer_v<ReducerType>;
+
+        unsigned mask =
+            blockDim.x == 32
+                ? 0xffffffff
+                : ((1 << blockDim.x) - 1)
+                      << ((threadIdx.y % (32 / blockDim.x)) * blockDim.x);
+
+        for (int i = blockDim.x; (i >>= 1);) {
+          Impl::in_place_shfl_down(tmp2, tmp, i, blockDim.x, mask);
+          if ((int)threadIdx.x < i) {
+            reducer.join(tmp, tmp2);
+          }
+        }
+
+        // Broadcast from root lane to all other lanes.
+        // Cannot use "butterfly" algorithm to avoid the broadcast
+        // because floating point summation is not associative
+        // and thus different threads could have different results.
+
+        Impl::in_place_shfl(tmp2, tmp, 0, blockDim.x, mask);
+        value = tmp2;
+        if constexpr (!is_reducer_functor) { reducer.reference() = tmp2; }))
   }
 
   //----------------------------------------
@@ -518,16 +533,42 @@ parallel_reduce(const Impl::TeamThreadRangeBoundariesStruct<
 
   (void)loop_boundaries;
   (void)closure;
   (void)result;
-  KOKKOS_IF_ON_DEVICE(
-      (ValueType val; Kokkos::Sum<ValueType> reducer(val);
-       reducer.init(reducer.reference());
-
-       for (iType i = loop_boundaries.start + threadIdx.y;
-            i < loop_boundaries.end; i += blockDim.y) { closure(i, val); }
-
-       loop_boundaries.member.team_reduce(reducer, val);
-       result = reducer.reference();))
+  using functor_analysis_type = typename Impl::FunctorAnalysis<
+      Impl::FunctorPatternInterface::REDUCE,
+      TeamPolicy<Cuda>, Closure,
+      ValueType>;
+
+  constexpr bool is_reducer_closure =
+      functor_analysis_type::has_join_member_function &&
+      functor_analysis_type::has_init_member_function;
+
+  using ReducerSelector =
+      typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
+                                  Kokkos::Sum<ValueType>>::type;
+
+  KOKKOS_IF_ON_DEVICE((
+      auto run_closure =
+          [&](ValueType& value) {
+            for (iType i = loop_boundaries.start + threadIdx.y;
+                 i < loop_boundaries.end; i += blockDim.y) {
+              closure(i, value);
+            }
+          };
+      ValueType val;
+
+      if constexpr (is_reducer_closure) {
+        closure.init(val);
+        run_closure(val);
+        loop_boundaries.member.team_reduce(closure, val);
+        result = val;
+      } else {
+        ReducerSelector reducer(val);
+        reducer.init(reducer.reference());
+        run_closure(val);
+        loop_boundaries.member.team_reduce(reducer);
+        result = reducer.reference();
+      }))
 }
 
 template <typename iType, class Closure, class ReducerType>
@@ -573,18 +614,45 @@ parallel_reduce(const Impl::TeamVectorRangeBoundariesStruct<
   (void)loop_boundaries;
   (void)closure;
   (void)result;
-  KOKKOS_IF_ON_DEVICE((ValueType val; Kokkos::Sum<ValueType> reducer(val);
-                       reducer.init(reducer.reference());
-
-                       for (iType i = loop_boundaries.start +
-                                      threadIdx.y * blockDim.x +
-                                      threadIdx.x;
-                            i < loop_boundaries.end;
-                            i += blockDim.y * blockDim.x) { closure(i, val); }
-
-                       loop_boundaries.member.vector_reduce(reducer);
-                       loop_boundaries.member.team_reduce(reducer);
-                       result = reducer.reference();))
+  using functor_analysis_type = typename Impl::FunctorAnalysis<
+      Impl::FunctorPatternInterface::REDUCE,
+      TeamPolicy<Cuda>, Closure,
+      ValueType>;
+
+  constexpr bool is_reducer_closure =
+      functor_analysis_type::has_join_member_function &&
+      functor_analysis_type::has_init_member_function;
+
+  using ReducerSelector =
+      typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
+                                  Kokkos::Sum<ValueType>>::type;
+
+  KOKKOS_IF_ON_DEVICE((
+      auto run_closure =
+          [&](ValueType& value) {
+            for (iType i = loop_boundaries.start + threadIdx.y * blockDim.x +
+                           threadIdx.x;
+                 i < loop_boundaries.end; i += blockDim.y * blockDim.x) {
+              closure(i, value);
+            }
+          };
+
+      ValueType val;
+
+      if constexpr (is_reducer_closure) {
+        closure.init(val);
+        run_closure(val);
+        loop_boundaries.member.vector_reduce(closure, val);
+        loop_boundaries.member.team_reduce(closure, val);
+        result = val;
+      } else {
+        ReducerSelector reducer(val);
+        reducer.init(reducer.reference());
+        run_closure(val);
+        loop_boundaries.member.vector_reduce(reducer);
+        loop_boundaries.member.team_reduce(reducer);
+        result = reducer.reference();
+      }))
 }
 
 //----------------------------------------------------------------------------
@@ -667,15 +735,42 @@ parallel_reduce(Impl::ThreadVectorRangeBoundariesStruct<
   (void)loop_boundaries;
   (void)closure;
   (void)result;
-  KOKKOS_IF_ON_DEVICE(
-      (result = ValueType();
-
-       for (iType i = loop_boundaries.start + threadIdx.x;
-            i < loop_boundaries.end; i += blockDim.x) { closure(i, result); }
-
-       Impl::CudaTeamMember::vector_reduce(Kokkos::Sum<ValueType>(result));
-
-       ))
+  using functor_analysis_type = typename Impl::FunctorAnalysis<
+      Impl::FunctorPatternInterface::REDUCE,
+      TeamPolicy<Cuda>, Closure,
+      ValueType>;
+
+  constexpr bool is_reducer_closure =
+      functor_analysis_type::has_join_member_function &&
+      functor_analysis_type::has_init_member_function;
+
+  using ReducerSelector =
+      typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
+                                  Kokkos::Sum<ValueType>>::type;
+
+  KOKKOS_IF_ON_DEVICE((
+      auto run_closure =
+          [&](ValueType& value) {
+            for (iType i = loop_boundaries.start + threadIdx.x;
+                 i < loop_boundaries.end; i += blockDim.x) {
+              closure(i, value);
+            }
+          };
+
+      ValueType val;
+
+      if constexpr (is_reducer_closure) {
+        closure.init(val);
+        run_closure(val);
+        Impl::CudaTeamMember::vector_reduce(closure, val);
+        result = val;
+      } else {
+        ReducerSelector reducer(val);
+        reducer.init(val);
+        run_closure(val);
+        Impl::CudaTeamMember::vector_reduce(reducer);
+        result = reducer.reference();
+      }))
 }
 
 //----------------------------------------------------------------------------