diff --git a/simd/src/Kokkos_SIMD_AVX512.hpp b/simd/src/Kokkos_SIMD_AVX512.hpp index 89c297bb0fe..fb1b417822e 100644 --- a/simd/src/Kokkos_SIMD_AVX512.hpp +++ b/simd/src/Kokkos_SIMD_AVX512.hpp @@ -3237,19 +3237,25 @@ class where_expression>, hmax(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epi32( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce_max( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epi32( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); + static_cast<__mmask8>(m), + _mm512_castsi256_si512(static_cast<__m256i>(v))); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3258,35 +3264,47 @@ class where_expression>, hmin(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epi32( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce_min( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epi32( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); + static_cast<__mmask8>(m), + _mm512_castsi256_si512(static_cast<__m256i>(v))); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_max_epi32(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce_max(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_epi32(static_cast<__mmask16>(m), + static_cast<__m512i>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_min_epi32(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce_min(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_epi32(static_cast<__mmask16>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3295,19 +3313,25 @@ class where_expression>, hmax(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epu32( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +reduce_max(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epu32( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); + static_cast<__mmask8>(m), + _mm512_castsi256_si512(static_cast<__m256i>(v))); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3316,35 +3340,47 @@ class where_expression>, hmin(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epu32( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +reduce_min(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epu32( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); + static_cast<__mmask8>(m), + _mm512_castsi256_si512(static_cast<__m256i>(v))); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_max_epu32(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +reduce_max(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_epu32(static_cast<__mmask16>(m), + static_cast<__m512i>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_min_epu32(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +reduce_min(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_epu32(static_cast<__mmask16>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3353,17 +3389,23 @@ class where_expression>, hmax(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epi64(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512i>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_max_epi64(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int64_t +reduce_max( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_epi64(static_cast<__mmask8>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3372,17 +3414,23 @@ class where_expression>, hmin(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epi64(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512i>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_min_epi64(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int64_t +reduce_min( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_epi64(static_cast<__mmask8>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3391,17 +3439,23 @@ class where_expression>, hmax(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_epu64(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512i>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_max_epu64(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint64_t +reduce_max(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_epu64(static_cast<__mmask8>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3410,17 +3464,23 @@ class where_expression>, hmin(const_where_expression< simd_mask>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_epu64(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512i>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return _mm512_mask_reduce_min_epu64(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint64_t +reduce_min(simd> const& v, + simd_mask> const& + m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_epu64(static_cast<__mmask8>(m), + static_cast<__m512i>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3428,17 +3488,22 @@ class where_expression>, hmax(const_where_expression>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_pd(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512d>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce_max( - const_where_expression>, - simd>> const& - x) { - return _mm512_mask_reduce_max_pd(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512d>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double reduce_max( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_pd(static_cast<__mmask8>(m), + static_cast<__m512d>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3446,17 +3511,22 @@ hmax(const_where_expression>, hmin(const_where_expression>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(x.impl_get_mask()), static_cast<__m512d>(x.impl_get_value())); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce_min( - const_where_expression>, - simd>> const& - x) { - return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512d>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double reduce_min( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_pd(static_cast<__mmask8>(m), + static_cast<__m512d>(v)); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3464,19 +3534,23 @@ hmin(const_where_expression>, hmax(const_where_expression>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_ps( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_max( - const_where_expression>, - simd>> const& - x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce_max( + simd> const& v, + simd_mask> m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } return _mm512_mask_reduce_max_ps( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); + static_cast<__mmask8>(m), _mm512_castps256_ps512(static_cast<__m256>(v))); } #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 @@ -3484,91 +3558,110 @@ hmax(const_where_expression>, hmin(const_where_expression>, simd>> const& x) { + if (none_of(x.impl_get_mask())) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_ps( static_cast<__mmask8>(x.impl_get_mask()), _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); } #endif -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_min( - const_where_expression>, - simd>> const& - x) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce_min( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } return _mm512_mask_reduce_min_ps( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); + static_cast<__mmask8>(m), _mm512_castps256_ps512(static_cast<__m256>(v))); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_max( - const_where_expression>, - simd>> const& - x) { - return _mm512_mask_reduce_max_ps(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce_max( + simd> const& v, + simd_mask> m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::max(); + } + return _mm512_mask_reduce_max_ps(static_cast<__mmask16>(m), + static_cast<__m512>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_min( - const_where_expression>, - simd>> const& - x) { - return _mm512_mask_reduce_min_ps(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce_min( + simd> const& v, + simd_mask> const& m) noexcept { + if (none_of(m)) { + return Kokkos::reduction_identity::min(); + } + return _mm512_mask_reduce_min_ps(static_cast<__mmask16>(m), + static_cast<__m512>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int32_t, std::plus<>) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce(simd> const& v, + simd_mask> const& m, + std::int32_t identity, std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } return _mm512_mask_reduce_add_epi32( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); + static_cast<__mmask8>(m), + _mm512_castsi256_si512(static_cast<__m256i>(v))); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int32_t, std::plus<>) { - return _mm512_mask_reduce_add_epi32(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +reduce(simd> const& v, + simd_mask> const& m, + std::int32_t identity, std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } + return _mm512_mask_reduce_add_epi32(static_cast<__mmask16>(m), + static_cast<__m512i>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int64_t, std::plus<>) { - return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512i>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int64_t +reduce(simd> const& v, + simd_mask> const& m, + std::int64_t identity, std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } + return _mm512_mask_reduce_add_epi64(static_cast<__mmask8>(m), + static_cast<__m512i>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce( - const_where_expression>, - simd>> const& - x, - double, std::plus<>) { - return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(x.impl_get_mask()), - static_cast<__m512d>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double reduce( + simd> const& v, + simd_mask> const& m, double identity, + std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } + return _mm512_mask_reduce_add_pd(static_cast<__mmask8>(m), + static_cast<__m512d>(v)); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce( - const_where_expression>, - simd>> const& - x, - float, std::plus<>) { +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce( + simd> const& v, + simd_mask> const& m, float identity, + std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } return _mm512_mask_reduce_add_ps( - static_cast<__mmask8>(x.impl_get_mask()), - _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); + static_cast<__mmask8>(m), _mm512_castps256_ps512(static_cast<__m256>(v))); } -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce( - const_where_expression>, - simd>> const& - x, - float, std::plus<>) { - return _mm512_mask_reduce_add_ps(static_cast<__mmask16>(x.impl_get_mask()), - static_cast<__m512>(x.impl_get_value())); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce( + simd> const& v, + simd_mask> const& m, float identity, + std::plus<>) noexcept { + if (none_of(m)) { + return identity; + } + return _mm512_mask_reduce_add_ps(static_cast<__mmask16>(m), + static_cast<__m512>(v)); } } // namespace Experimental diff --git a/simd/src/Kokkos_SIMD_Common.hpp b/simd/src/Kokkos_SIMD_Common.hpp index 70ba0b00d65..cc055142628 100644 --- a/simd/src/Kokkos_SIMD_Common.hpp +++ b/simd/src/Kokkos_SIMD_Common.hpp @@ -317,19 +317,19 @@ KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd& operator<<=( // fallback implementations of reductions across simd_mask: template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool all_of( +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool all_of( simd_mask const& a) { return a == simd_mask(true); } template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool any_of( +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool any_of( simd_mask const& a) { return a != simd_mask(false); } template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool none_of( +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool none_of( simd_mask const& a) { return a == simd_mask(false); } @@ -352,93 +352,88 @@ template template > [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( const simd& x, BinaryOperation binary_op = {}) { - auto v = where(true, x); - return reduce(v, binary_op); -} - -template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - T identity_element, BinaryOperation binary_op) { - if (none_of(mask)) return identity_element; - auto v = where(mask, x); - return reduce(v, binary_op); -} - -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - std::plus<> binary_op = {}) noexcept { - return reduce(x, mask, T(0), binary_op); -} - -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - std::multiplies<> binary_op) noexcept { - return reduce(x, mask, T(0), binary_op); -} - -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - std::bit_and<> binary_op) noexcept { - return reduce(x, mask, 0, binary_op); -} - -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - std::bit_or<> binary_op) noexcept { - return reduce(x, mask, 0, binary_op); -} - -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( - const simd& x, const typename simd::mask_type& mask, - std::bit_xor<> binary_op) noexcept { - return reduce(x, mask, 0, binary_op); -} + return reduce(x, simd::mask_type(true), T(0), binary_op); +} + +// template +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// T identity_element, BinaryOperation binary_op) { +// if (none_of(mask)) return identity_element; +// return reduce(x, mask, identity_element, binary_op); +// } + +// template < +// class T, class Abi, +// std::enable_if_t, bool> = false> +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// std::plus<> binary_op = {}) noexcept { +// return reduce(x, mask, T(0), binary_op); +// } + +// template < +// class T, class Abi, +// std::enable_if_t, bool> = false> +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// std::multiplies<> binary_op) noexcept { +// return reduce(x, mask, T(0), binary_op); +// } + +// template < +// class T, class Abi, +// std::enable_if_t, bool> = false> +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// std::bit_and<> binary_op) noexcept { +// return reduce(x, mask, 0, binary_op); +// } + +// template < +// class T, class Abi, +// std::enable_if_t, bool> = false> +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// std::bit_or<> binary_op) noexcept { +// return reduce(x, mask, 0, binary_op); +// } + +// template < +// class T, class Abi, +// std::enable_if_t, bool> = false> +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( +// const simd& x, const typename simd::mask_type& mask, +// std::bit_xor<> binary_op) noexcept { +// return reduce(x, mask, 0, binary_op); +// } template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min( const simd& x) noexcept { - auto v = where(true, x); - return reduce_min(v); + return reduce_min(x, simd::mask_type(true)); } -template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min( - const simd& x, - const typename simd::mask_type& mask) noexcept { - auto v = where(mask, x); - return reduce_min(v); -} +// template +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min( +// const simd& x, +// const typename simd::mask_type& mask) noexcept { +// return reduce_min(x, mask); +// } template [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max( const simd& x) noexcept { auto v = where(true, x); - return reduce_max(v); + return reduce_max(x, simd::mask_type(true)); } -template -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max( - const simd& x, - const typename simd::mask_type& mask) noexcept { - auto v = where(mask, x); - return reduce_max(v); -} +// template +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max( +// const simd& x, +// const typename simd::mask_type& mask) noexcept { +// return reduce_max(x, mask); +// } } // namespace Experimental } // namespace Kokkos diff --git a/simd/src/Kokkos_SIMD_Common_Math.hpp b/simd/src/Kokkos_SIMD_Common_Math.hpp index 44e0531053c..c8851892205 100644 --- a/simd/src/Kokkos_SIMD_Common_Math.hpp +++ b/simd/src/Kokkos_SIMD_Common_Math.hpp @@ -65,11 +65,9 @@ hmax(const_where_expression, simd> const& x) { template < typename T, typename Abi, std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T -reduce_min(const_where_expression, simd> const& x) { - auto const& v = x.impl_get_value(); - auto const& m = x.impl_get_mask(); - auto result = Kokkos::reduction_identity::min(); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min( + simd const& v, typename simd::mask_type const& m) { + auto result = Kokkos::reduction_identity::min(); for (std::size_t i = 0; i < v.size(); ++i) { if (m[i]) result = Kokkos::min(result, v[i]); } @@ -79,11 +77,9 @@ reduce_min(const_where_expression, simd> const& x) { template < class T, class Abi, std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T -reduce_max(const_where_expression, simd> const& x) { - auto const& v = x.impl_get_value(); - auto const& m = x.impl_get_mask(); - auto result = Kokkos::reduction_identity::max(); +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max( + simd const& v, typename simd::mask_type const& m) { + auto result = Kokkos::reduction_identity::max(); for (std::size_t i = 0; i < v.size(); ++i) { if (m[i]) result = Kokkos::max(result, v[i]); } @@ -93,27 +89,19 @@ reduce_max(const_where_expression, simd> const& x) { template < class T, class Abi, class BinaryOperation = std::plus<>, std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T -reduce(const_where_expression, simd> const& x, - BinaryOperation op = {}) { - auto const& v = x.impl_get_value(); - auto const& m = x.impl_get_mask(); - auto result = v[0]; +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce( + simd const& v, typename simd::mask_type const& m, + T identity, BinaryOperation op = {}) { + if (none_of(m)) { + return identity; + } + auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { if (m[i]) result = op(result, v[i]); } return result; } -template < - class T, class Abi, - std::enable_if_t, bool> = false> -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T -reduce(const_where_expression, simd> const& x, T, - std::plus<>) { - return reduce(x, std::plus<>()); -} - } // namespace Experimental template diff --git a/simd/src/Kokkos_SIMD_NEON.hpp b/simd/src/Kokkos_SIMD_NEON.hpp index f35bd04bb8d..51c81c5d1bf 100644 --- a/simd/src/Kokkos_SIMD_NEON.hpp +++ b/simd/src/Kokkos_SIMD_NEON.hpp @@ -2722,172 +2722,195 @@ class where_expression>, } }; -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vminv_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vmaxv_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int32_t, std::plus<>) { - return vaddv_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vminvq_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vmaxvq_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int32_t, std::plus<>) { - return vaddvq_s32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vminv_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vmaxv_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::uint32_t, std::plus<>) { - return vaddv_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_min( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vminvq_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce_max( - const_where_expression< - simd_mask>, - simd>> const& x) { - return vmaxvq_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::uint32_t, std::plus<>) { - return vaddvq_u32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::int64_t, std::plus<>) { - return vaddvq_s64(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t reduce( - const_where_expression< - simd_mask>, - simd>> const& x, - std::uint64_t, std::plus<>) { - return vaddvq_u64(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce_min( - const_where_expression>, - simd>> const& - x) { - return vminvq_f64(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce_max( - const_where_expression>, - simd>> const& - x) { - return vmaxvq_f64(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double reduce( - const_where_expression>, - simd>> const& x, - double, std::plus<>) { - return vaddvq_f64(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_min( - const_where_expression>, - simd>> const& - x) { - return vminv_f32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_max( - const_where_expression>, - simd>> const& - x) { - return vmaxv_f32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce( - const_where_expression>, - simd>> const& x, - float, std::plus<>) { - return vaddv_f32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_min( - const_where_expression>, - simd>> const& - x) { - return vminvq_f32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce_max( - const_where_expression>, - simd>> const& - x) { - return vmaxvq_f32(static_cast(x.impl_get_value())); -} - -[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce( - const_where_expression>, - simd>> const& x, - float, std::plus<>) { - return vaddvq_f32(static_cast(x.impl_get_value())); -} +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce_min( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vminv_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce_max( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vmaxv_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::int32_t identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddv_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce_min( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vminvq_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce_max( +// simd> const& v, +// simd_mask> m) noexcept { +// return vmaxvq_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int32_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::int32_t identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddvq_s32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce_min( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vminv_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce_max( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vmaxv_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::uint32_t identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddv_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce_min( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vminvq_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce_max( +// simd> const& v, +// simd_mask> const& m) +// noexcept { +// return vmaxvq_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint32_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::uint32_t identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddvq_u32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::int64_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::int64_t identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddvq_s64(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr std::uint64_t +// reduce( +// simd> const& v, +// simd_mask> const& m, +// std::uint64_t identity, std::plus<>) noexcept { +// return vaddvq_u64(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double +// reduce_min( +// simd < double, simd_abi::neon_fixed_size < 2 >> const& v, +// simd_mask> const& m) noexcept { +// return vminvq_f64(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double +// reduce_max( +// simd> const& +// v, +// simd_mask> const& m) noexcept { +// return vmaxvq_f64(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr double reduce( +// simd> const& +// v, +// simd_mask> const& m, +// double identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddvq_f64(static_cast(x.impl_get_value())); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float +// reduce_min( +// simd> const& +// v, +// simd_mask> const& m) noexcept { +// return vminv_f32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float +// reduce_max( +// simd> const& +// v, +// simd_mask> const& m) noexcept { +// return vmaxv_f32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce( +// simd> const& +// v, +// simd_mask> const& m, +// float identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddv_f32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float +// reduce_min( +// simd> const& +// v, +// simd_mask> const& m) noexcept { +// return vminvq_f32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float +// reduce_max( +// simd> const& +// const& v +// simd_mask> const& m) noexcept { +// return vmaxvq_f32(static_cast(v)); +// } + +// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr float reduce( +// simd> const& +// v, +// simd_mask> const& m, +// float identity, std::plus<>) noexcept { +// if (none_of(m)) return identity; +// return vaddvq_f32(static_cast(v)); +// } } // namespace Experimental } // namespace Kokkos diff --git a/simd/src/Kokkos_SIMD_Scalar.hpp b/simd/src/Kokkos_SIMD_Scalar.hpp index 7ac23d1a19b..780cda14a89 100644 --- a/simd/src/Kokkos_SIMD_Scalar.hpp +++ b/simd/src/Kokkos_SIMD_Scalar.hpp @@ -348,51 +348,55 @@ KOKKOS_FORCEINLINE_FUNCTION simd condition( } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce(Experimental::simd const& x, - BinaryOperation binary_op) { - auto v = where(true, x); - return reduce(v, T(0), binary_op); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + Experimental::simd const& x, + Experimental::simd_mask const& mask, + T identity, BinaryOperation) noexcept { + if (!mask) return identity; + return x[0]; } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce(Experimental::simd const& x, - Experimental::simd_mask const& mask, - BinaryOperation binary_op) { - if (!mask) return T(0); - auto v = where(mask, x); - return reduce(v, T(0), binary_op); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce( + Experimental::simd const& x, + BinaryOperation binary_op) noexcept { + return reduce( + x, Experimental::simd::mask_type(true), + T(0), binary_op); } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce_min(Experimental::simd const& x) { - auto v = where(true, x); - return reduce_min(v); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_min( + Experimental::simd const& x, + Experimental::simd_mask const& + mask) noexcept { + if (!mask) return Kokkos::reduction_identity::min(); + return x[0]; } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T reduce_min( - Experimental::simd const& x, - Experimental::simd_mask const& mask) { - auto v = where(mask, x); - return reduce_min(v); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_min( + Experimental::simd const& x) noexcept { + return reduce_min( + x, + Experimental::simd::mask_type(true)); } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce_max(Experimental::simd const& x) { - auto v = where(true, x); - return reduce_max(v); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_max( + Experimental::simd const& x, + Experimental::simd_mask const& + mask) noexcept { + if (!mask) return Kokkos::reduction_identity::max(); + return x[0]; } template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T reduce_max( - Experimental::simd const& x, - Experimental::simd_mask const& mask) { - auto v = where(mask, x); - return reduce_max(v); +[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION constexpr T reduce_max( + Experimental::simd const& x) noexcept { + return reduce_max( + x, + Experimental::simd::mask_type(true)); } template @@ -514,24 +518,6 @@ template return a == simd_mask(false); } -template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce(const_where_expression, - simd> const& x, - T identity_element, BinaryOperation) { - return static_cast(x.impl_get_mask()) - ? static_cast(x.impl_get_value()) - : identity_element; -} - -template -[[nodiscard]] KOKKOS_FORCEINLINE_FUNCTION T -reduce(const_where_expression, - simd> const& x, - BinaryOperation op) { - return reduce(x, T(0), op); -} - #ifdef KOKKOS_ENABLE_DEPRECATED_CODE_4 template [[nodiscard]] KOKKOS_DEPRECATED KOKKOS_FORCEINLINE_FUNCTION T diff --git a/simd/unit_tests/include/SIMDTesting_Ops.hpp b/simd/unit_tests/include/SIMDTesting_Ops.hpp index e107c4b053c..04409e201ba 100644 --- a/simd/unit_tests/include/SIMDTesting_Ops.hpp +++ b/simd/unit_tests/include/SIMDTesting_Ops.hpp @@ -331,150 +331,136 @@ class log_op { } }; -template > -class reduce_where_expr { +class reduce_min { public: - template - auto on_host(T const& a, MaskType mask) const { - auto w = Kokkos::Experimental::where(mask, a); - return Kokkos::Experimental::reduce(w, BinaryOperation()); + template + auto on_host(T const& a, U, MaskType mask) const { + return Kokkos::Experimental::reduce_min(a, mask); } - template - auto on_host_serial(T const& a, MaskType mask) const { + template + auto on_host_serial(T const& a, U, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) + return Kokkos::reduction_identity::min(); auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if (m[i]) result = BinaryOperation()(result, v[i]); + if (m[i]) result = Kokkos::min(result, v[i]); } return result; } - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { - auto w = Kokkos::Experimental::where(mask, a); - return Kokkos::Experimental::reduce(w, BinaryOperation()); + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, U, MaskType mask) const { + return Kokkos::Experimental::reduce_min(a, mask); } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, U, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) + return Kokkos::reduction_identity::min(); auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if constexpr (std::is_same_v>) { - if (m[i]) result = result + v[i]; - } else if constexpr (std::is_same_v>) { - if (m[i]) result = result * v[i]; - } else if constexpr (std::is_same_v>) { - if (m[i]) result = result & v[i]; - } else if constexpr (std::is_same_v>) { - if (m[i]) result = result | v[i]; - } else if constexpr (std::is_same_v>) { - if (m[i]) result = result ^ v[i]; - } else { - Kokkos::abort("Unsupported reduce operation"); - } + if (m[i]) result = Kokkos::min(result, v[i]); } return result; } }; -class reduce_min { +class reduce_max { public: - template - auto on_host(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce_min(a, mask); + template + auto on_host(T const& a, U, MaskType mask) const { + return Kokkos::Experimental::reduce_max(a, mask); } - template - auto on_host_serial(T const& a, MaskType mask) const { + template + auto on_host_serial(T const& a, U, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) + return Kokkos::reduction_identity::max(); auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if (m[i]) result = Kokkos::min(result, v[i]); + if (m[i]) result = Kokkos::max(result, v[i]); } return result; } - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce_min(a, mask); + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, U, MaskType mask) const { + return Kokkos::Experimental::reduce_max(a, mask); } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, U, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) + return Kokkos::reduction_identity::max(); auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if (m[i]) result = Kokkos::min(result, v[i]); + if (m[i]) result = Kokkos::max(result, v[i]); } return result; } }; -class reduce_max { +template > +class reduce { public: - template - auto on_host(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce_max(a, mask); + template + auto on_host(T const& a, U const& identity, MaskType mask) const { + return Kokkos::Experimental::reduce(a, mask, identity, BinaryOperation()); } - template - auto on_host_serial(T const& a, MaskType mask) const { + template + auto on_host_serial(T const& a, U const& identity, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) return identity; auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if (m[i]) result = Kokkos::max(result, v[i]); + if (m[i]) result = BinaryOperation()(result, v[i]); } return result; } - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce_max(a, mask); + template + KOKKOS_INLINE_FUNCTION auto on_device(T const& a, U const& identity, + MaskType mask) const { + return Kokkos::Experimental::reduce(a, mask, identity, BinaryOperation()); } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, + template + KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, U const& identity, MaskType mask) const { + if (Kokkos::Experimental::none_of(mask)) return identity; auto w = Kokkos::Experimental::where(mask, a); auto const& v = w.impl_get_value(); auto const& m = w.impl_get_mask(); auto result = v[0]; for (std::size_t i = 1; i < v.size(); ++i) { - if (m[i]) result = Kokkos::max(result, v[i]); + if constexpr (std::is_same_v>) { + if (m[i]) result = result + v[i]; + } else if constexpr (std::is_same_v>) { + if (m[i]) result = result * v[i]; + } else if constexpr (std::is_same_v>) { + if (m[i]) result = result & v[i]; + } else if constexpr (std::is_same_v>) { + if (m[i]) result = result | v[i]; + } else if constexpr (std::is_same_v>) { + if (m[i]) result = result ^ v[i]; + } else { + Kokkos::abort("Unsupported reduce operation"); + } } return result; } }; -template > -class reduce { - public: - template - auto on_host(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce(a, mask, BinaryOperation()); - } - template - auto on_host_serial(T const& a, MaskType mask) const { - return reduce_where_expr().on_host_serial(a, mask); - } - - template - KOKKOS_INLINE_FUNCTION auto on_device(T const& a, MaskType mask) const { - return Kokkos::Experimental::reduce(a, mask, BinaryOperation()); - } - template - KOKKOS_INLINE_FUNCTION auto on_device_serial(T const& a, - MaskType mask) const { - return reduce_where_expr().on_device_serial(a, mask); - } -}; - #endif diff --git a/simd/unit_tests/include/TestSIMD_Reductions.hpp b/simd/unit_tests/include/TestSIMD_Reductions.hpp index baf1e7c88cc..e52fa7f8422 100644 --- a/simd/unit_tests/include/TestSIMD_Reductions.hpp +++ b/simd/unit_tests/include/TestSIMD_Reductions.hpp @@ -36,11 +36,16 @@ inline void host_check_reduction_one_loader(ReductionOp reduce_op, if (!loaded_arg) continue; mask_type mask(false); + T identity = 12; + auto expected = reduce_op.on_host_serial(arg, identity, mask); + auto computed = reduce_op.on_host(arg, identity, mask); + gtest_checker().equality(expected, computed); + for (std::size_t j = 0; j < n; ++j) { mask[j] = true; } - auto expected = reduce_op.on_host_serial(arg, mask); - auto computed = reduce_op.on_host(arg, mask); + expected = reduce_op.on_host_serial(arg, identity, mask); + computed = reduce_op.on_host(arg, identity, mask); gtest_checker().equality(expected, computed); } @@ -57,11 +62,6 @@ inline void host_check_reduction_all_loaders(ReductionOp reduce_op, template inline void host_check_all_reductions(const DataType (&args)[n]) { - host_check_reduction_all_loaders(reduce_where_expr>(), n, - args); - host_check_reduction_all_loaders(reduce_where_expr>(), - n, args); - host_check_reduction_all_loaders(reduce_min(), n, args); host_check_reduction_all_loaders(reduce_max(), n, args); host_check_reduction_all_loaders(reduce>(), n, args); @@ -114,11 +114,16 @@ KOKKOS_INLINE_FUNCTION void device_check_reduction_one_loader( if (!loaded_arg) continue; mask_type mask(false); + T identity = 12; + auto expected = reduce_op.on_device_serial(arg, identity, mask); + auto computed = reduce_op.on_device(arg, identity, mask); + kokkos_checker().equality(expected, computed); + for (std::size_t j = 0; j < n; ++j) { mask[j] = true; } - auto expected = reduce_op.on_device_serial(arg, mask); - auto computed = reduce_op.on_device(arg, mask); + expected = reduce_op.on_device_serial(arg, identity, mask); + computed = reduce_op.on_device(arg, identity, mask); kokkos_checker().equality(expected, computed); } @@ -136,11 +141,6 @@ KOKKOS_INLINE_FUNCTION void device_check_reduction_all_loaders( template KOKKOS_INLINE_FUNCTION void device_check_all_reductions( const DataType (&args)[n]) { - device_check_reduction_all_loaders(reduce_where_expr>(), n, - args); - device_check_reduction_all_loaders( - reduce_where_expr>(), n, args); - device_check_reduction_all_loaders(reduce_min(), n, args); device_check_reduction_all_loaders(reduce_max(), n, args); device_check_reduction_all_loaders(reduce>(), n, args);