diff --git a/simd/src/Kokkos_SIMD_Common.hpp b/simd/src/Kokkos_SIMD_Common.hpp index e27413da74a..1489b6ce710 100644 --- a/simd/src/Kokkos_SIMD_Common.hpp +++ b/simd/src/Kokkos_SIMD_Common.hpp @@ -323,6 +323,89 @@ template return Kokkos::round(x); } +// fallback implementations of simd reductions: + +template > +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, BinaryOperation binary_op = {}) { + auto v = where(true, x); + return reduce(v, binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + T identity_element, BinaryOperation binary_op) { + if (none_of(mask)) return identity_element; + auto v = where(mask, x); + return reduce(v, binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + std::plus<> binary_op = {}) noexcept { + return reduce(x, mask, T(0), binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + std::multiplies<> binary_op) noexcept { + return reduce(x, mask, T(0), binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + std::bit_and<> binary_op) noexcept { + return reduce(x, mask, 0, binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + std::bit_or<> binary_op) noexcept { + return reduce(x, mask, 0, binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + const simd& x, const typename simd::mask_type& mask, + std::bit_xor<> binary_op) noexcept { + return reduce(x, mask, 0, binary_op); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_min( + const simd& x) noexcept { + auto v = where(true, x); + return hmin(v); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_min( + const simd& x, + const typename simd::mask_type& mask) noexcept { + auto v = where(mask, x); + return hmin(v); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_max( + const simd& x) noexcept { + auto v = where(true, x); + return hmax(v); +} + +template +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_max( + const simd& x, + const typename simd::mask_type& mask) noexcept { + auto v = where(mask, x); + return hmax(v); +} + } // namespace Experimental } // namespace Kokkos diff --git a/simd/src/Kokkos_SIMD_Common_Math.hpp b/simd/src/Kokkos_SIMD_Common_Math.hpp index 8c6a9559604..9e5991e6912 100644 --- a/simd/src/Kokkos_SIMD_Common_Math.hpp +++ b/simd/src/Kokkos_SIMD_Common_Math.hpp @@ -56,19 +56,26 @@ hmax(const_where_expression, simd> const& x) { return result; } -template +template > [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T -reduce(const_where_expression, simd> const& x, T, - std::plus<>) { +reduce(const_where_expression, simd> const& x, + BinaryOperation op = {}) { auto const& v = x.impl_get_value(); auto const& m = x.impl_get_mask(); - auto result = Kokkos::reduction_identity::sum(); - for (std::size_t i = 0; i < v.size(); ++i) { - if (m[i]) result += v[i]; + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { + if (m[i]) result = op(result, v[i]); } return result; } +template +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T +reduce(const_where_expression, simd> const& x, T, + std::plus<>) { + return reduce(x, std::plus<>()); +} + } // namespace Experimental template diff --git a/simd/unit_tests/include/SIMDTesting_Ops.hpp b/simd/unit_tests/include/SIMDTesting_Ops.hpp index 4b3f9968108..8de86beaed3 100644 --- a/simd/unit_tests/include/SIMDTesting_Ops.hpp +++ b/simd/unit_tests/include/SIMDTesting_Ops.hpp @@ -331,35 +331,38 @@ class log_op { class hmin { public: - template - auto on_host(T const& a) const { - return Kokkos::Experimental::hmin(a); - } - template - auto on_host_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::min(); - for (std::size_t i = 0; i < v.size(); ++i) { + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::hmin(w); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { if (m[i]) result = Kokkos::min(result, v[i]); } return result; } - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a) const { - return Kokkos::Experimental::hmin(a); - } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::min(); - for (std::size_t i = 0; i < v.size(); ++i) { + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::hmin(w); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { if (m[i]) result = Kokkos::min(result, v[i]); } return result; @@ -368,77 +371,147 @@ class hmin { class hmax { public: - template - auto on_host(T const& a) const { - return Kokkos::Experimental::hmax(a); + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::hmax(w); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { + if (m[i]) result = Kokkos::max(result, v[i]); + } + return result; } - template - auto on_host_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::max(); - for (std::size_t i = 0; i < v.size(); ++i) { + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::hmax(w); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask = true) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { if (m[i]) result = Kokkos::max(result, v[i]); } return result; } +}; - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a) const { - return Kokkos::Experimental::hmax(a); +template > +class reduce_where_expr { + public: + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::reduce(w, BinaryOperation()); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, MaskType mask) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { + if (m[i]) result = BinaryOperation()(result, v[i]); + } + return result; } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::max(); - for (std::size_t i = 0; i < v.size(); ++i) { - if (m[i]) result = Kokkos::max(result, v[i]); + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { + auto w = Kokkos::Experimental::where(mask, a); + return Kokkos::Experimental::reduce(w, BinaryOperation()); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask) const { + auto w = Kokkos::Experimental::where(mask, a); + auto const& v = w.impl_get_value(); + auto const& m = w.impl_get_mask(); + auto result = v[0]; + for (std::size_t i = 1; i < v.size(); ++i) { + if (m[i]) result = BinaryOperation()(result, v[i]); } return result; } }; -class reduce { +class reduce_min { public: - template - auto on_host(T const& a) const { - using DataType = typename T::value_type::value_type; - return Kokkos::Experimental::reduce(a, DataType(0), std::plus<>()); + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce_min(a, mask); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, MaskType mask) const { + return hmin().on_host_serial(a, mask); } - template - auto on_host_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::sum(); - for (std::size_t i = 0; i < v.size(); ++i) { - if (m[i]) result += v[i]; - } - return result; + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce_min(a, mask); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask) const { + return hmin().on_device_serial(a, mask); } +}; - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a) const { - using DataType = typename T::value_type::value_type; - return Kokkos::Experimental::reduce(a, DataType(0), std::plus<>()); +class reduce_max { + public: + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce_max(a, mask); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, MaskType mask) const { + return hmax().on_host_serial(a, mask); } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a) const { - using DataType = typename T::value_type::value_type; - auto const& v = a.impl_get_value(); - auto const& m = a.impl_get_mask(); - auto result = Kokkos::reduction_identity::sum(); - for (std::size_t i = 0; i < v.size(); ++i) { - if (m[i]) result += v[i]; - } - return result; + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce_max(a, mask); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask) const { + return hmax().on_device_serial(a, mask); + } +}; + +template > +class reduce { + public: + template + KOKKOS_INLINE_FUNCTION auto on_host(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce(a, mask, BinaryOperation()); + } + template + KOKKOS_INLINE_FUNCTION auto on_host_serial(T const& a, MaskType mask) const { + return reduce_where_expr().on_host_serial(a, mask); + } + + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { + return Kokkos::Experimental::reduce(a, mask, BinaryOperation()); + } + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + MaskType mask) const { + return reduce_where_expr().on_device_serial(a, mask); } }; diff --git a/simd/unit_tests/include/TestSIMD_Reductions.hpp b/simd/unit_tests/include/TestSIMD_Reductions.hpp index b1aef98c2a8..dde4735e823 100644 --- a/simd/unit_tests/include/TestSIMD_Reductions.hpp +++ b/simd/unit_tests/include/TestSIMD_Reductions.hpp @@ -39,9 +39,8 @@ inline void host_check_reduction_one_loader(ReductionOp reduce_op, for (std::size_t j = 0; j < n; ++j) { mask[j] = true; } - auto value = where(mask, arg); - auto expected = reduce_op.on_host_serial(value); - auto computed = reduce_op.on_host(value); + auto expected = reduce_op.on_host_serial(arg, mask); + auto computed = reduce_op.on_host(arg, mask); gtest_checker().equality(expected, computed); } @@ -60,7 +59,15 @@ template inline void host_check_all_reductions(const DataType (&args)[n]) { host_check_reduction_all_loaders(hmin(), n, args); host_check_reduction_all_loaders(hmax(), n, args); - host_check_reduction_all_loaders(reduce(), n, args); + host_check_reduction_all_loaders(reduce_where_expr>(), n, + args); + host_check_reduction_all_loaders(reduce_where_expr>(), + n, args); + + host_check_reduction_all_loaders(reduce_min(), n, args); + host_check_reduction_all_loaders(reduce_max(), n, args); + host_check_reduction_all_loaders(reduce>(), n, args); + host_check_reduction_all_loaders(reduce>(), n, args); } template @@ -108,9 +115,8 @@ KOKKOS_INLINE_FUNCTION void device_check_reduction_one_loader( for (std::size_t j = 0; j < n; ++j) { mask[j] = true; } - auto value = where(mask, arg); - auto expected = reduce_op.on_device_serial(value); - auto computed = reduce_op.on_device(value); + auto expected = reduce_op.on_device_serial(arg, mask); + auto computed = reduce_op.on_device(arg, mask); kokkos_checker().equality(expected, computed); } @@ -130,7 +136,15 @@ KOKKOS_INLINE_FUNCTION void device_check_all_reductions( const DataType (&args)[n]) { device_check_reduction_all_loaders(hmin(), n, args); device_check_reduction_all_loaders(hmax(), n, args); - device_check_reduction_all_loaders(reduce(), n, args); + device_check_reduction_all_loaders(reduce_where_expr>(), n, + args); + device_check_reduction_all_loaders( + reduce_where_expr>(), n, args); + + device_check_reduction_all_loaders(reduce_min(), n, args); + device_check_reduction_all_loaders(reduce_max(), n, args); + device_check_reduction_all_loaders(reduce>(), n, args); + device_check_reduction_all_loaders(reduce>(), n, args); } template