From 293f6f07791046016745bfa7f89b395246d96d78 Mon Sep 17 00:00:00 2001 From: Dong Hun Lee Date: Thu, 29 Feb 2024 20:11:13 -0700 Subject: [PATCH] Added hmin, hmax and reduce (sum) in AVX512 --- simd/src/Kokkos_SIMD_AVX512.hpp | 105 ++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/simd/src/Kokkos_SIMD_AVX512.hpp b/simd/src/Kokkos_SIMD_AVX512.hpp index 5c456126de1..82cd978c981 100644 --- a/simd/src/Kokkos_SIMD_AVX512.hpp +++ b/simd/src/Kokkos_SIMD_AVX512.hpp @@ -3225,6 +3225,73 @@ class where_expression>, _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); } +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t hmin( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_min_epi32( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t hmax( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_max_epu32( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t hmin( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_min_epu32( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t hmax( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_max_epi64(static_cast<__mmask8>(x.impl_get_mask()), + static_cast<__m512i>(x.impl_get_value())); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t hmin( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_min_epi64(static_cast<__mmask8>(x.impl_get_mask()), + static_cast<__m512i>(x.impl_get_value())); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t hmax( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_max_epu64(static_cast<__mmask8>(x.impl_get_mask()), + static_cast<__m512i>(x.impl_get_value())); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t hmin( + const_where_expression< + simd_mask>, + simd>> const& x) { + return _mm512_mask_reduce_min_epu64(static_cast<__mmask8>(x.impl_get_mask()), + static_cast<__m512i>(x.impl_get_value())); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmax( + const_where_expression>, + simd>> const& + x) { + return _mm512_mask_reduce_max_pd(static_cast<__mmask8>(x.impl_get_mask()), + static_cast<__m512d>(x.impl_get_value())); +} + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmin( const_where_expression>, simd>> const& @@ -3233,6 +3300,34 @@ class where_expression>, static_cast<__m512d>(x.impl_get_value())); } +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float hmax( + const_where_expression>, + simd>> const& + x) { + return _mm512_mask_reduce_max_ps( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float hmin( + const_where_expression>, + simd>> const& + x) { + return _mm512_mask_reduce_min_ps( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); +} + +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce( + const_where_expression< + simd_mask>, + simd>> const& x, + std::int32_t, std::plus<>) { + return _mm512_mask_reduce_add_epi32( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value()))); +} + [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce( const_where_expression< simd_mask>, @@ -3251,6 +3346,16 @@ class where_expression>, static_cast<__m512d>(x.impl_get_value())); } +[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce( + const_where_expression>, + simd>> const& + x, + float, std::plus<>) { + return _mm512_mask_reduce_add_ps( + static_cast<__mmask8>(x.impl_get_mask()), + _mm512_castps256_ps512(static_cast<__m256>(x.impl_get_value()))); +} + } // namespace Experimental } // namespace Kokkos