Skip to content

Commit

Permalink
Added hmin, hmax and reduce (sum) in AVX512
Browse files Browse the repository at this point in the history
  • Loading branch information
ldh4 committed Sep 17, 2024
1 parent 8746751 commit 293f6f0
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions simd/src/Kokkos_SIMD_AVX512.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3225,6 +3225,73 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
_mm512_castsi256_si512(static_cast<__m256i>(x.impl_get_value())));
}

// Masked horizontal minimum over 8 x std::int32_t.
// The 256-bit payload of `x` is reinterpreted as a 512-bit register so the
// AVX-512 masked reduction can be applied; only lanes selected by the mask of
// `x` take part in the reduction.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t hmin(
    const_where_expression<
        simd_mask<std::int32_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::int32_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256i>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_min_epi32(active_lanes,
                                      _mm512_castsi256_si512(narrow_value));
}

// Masked horizontal maximum over 8 x std::uint32_t.
// The 256-bit payload of `x` is reinterpreted as a 512-bit register so the
// AVX-512 masked (unsigned) reduction can be applied; only lanes selected by
// the mask of `x` take part in the reduction.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t hmax(
    const_where_expression<
        simd_mask<std::uint32_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::uint32_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256i>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_max_epu32(active_lanes,
                                      _mm512_castsi256_si512(narrow_value));
}

// Masked horizontal minimum over 8 x std::uint32_t.
// The 256-bit payload of `x` is reinterpreted as a 512-bit register so the
// AVX-512 masked (unsigned) reduction can be applied; only lanes selected by
// the mask of `x` take part in the reduction.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint32_t hmin(
    const_where_expression<
        simd_mask<std::uint32_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::uint32_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256i>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_min_epu32(active_lanes,
                                      _mm512_castsi256_si512(narrow_value));
}

// Masked horizontal maximum over 8 x std::int64_t.
// Hands the mask and the full 512-bit value of `x` to the AVX-512 masked
// reduction, so only lanes whose mask bit is set are considered.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t hmax(
    const_where_expression<
        simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::int64_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto packed_value = static_cast<__m512i>(x.impl_get_value());
  return _mm512_mask_reduce_max_epi64(active_lanes, packed_value);
}

// Masked horizontal minimum over 8 x std::int64_t.
// Hands the mask and the full 512-bit value of `x` to the AVX-512 masked
// reduction, so only lanes whose mask bit is set are considered.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t hmin(
    const_where_expression<
        simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::int64_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto packed_value = static_cast<__m512i>(x.impl_get_value());
  return _mm512_mask_reduce_min_epi64(active_lanes, packed_value);
}

// Masked horizontal maximum over 8 x std::uint64_t.
// Hands the mask and the full 512-bit value of `x` to the AVX-512 masked
// (unsigned) reduction, so only lanes whose mask bit is set are considered.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t hmax(
    const_where_expression<
        simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::uint64_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto packed_value = static_cast<__m512i>(x.impl_get_value());
  return _mm512_mask_reduce_max_epu64(active_lanes, packed_value);
}

// Masked horizontal minimum over 8 x std::uint64_t.
// Hands the mask and the full 512-bit value of `x` to the AVX-512 masked
// (unsigned) reduction, so only lanes whose mask bit is set are considered.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::uint64_t hmin(
    const_where_expression<
        simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::uint64_t, simd_abi::avx512_fixed_size<8>>> const& x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto packed_value = static_cast<__m512i>(x.impl_get_value());
  return _mm512_mask_reduce_min_epu64(active_lanes, packed_value);
}

// Masked horizontal maximum over 8 x double.
// Hands the mask and the full 512-bit value of `x` to the AVX-512 masked
// reduction, so only lanes whose mask bit is set are considered.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmax(
    const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
                           simd<double, simd_abi::avx512_fixed_size<8>>> const&
        x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto packed_value = static_cast<__m512d>(x.impl_get_value());
  return _mm512_mask_reduce_max_pd(active_lanes, packed_value);
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION double hmin(
const_where_expression<simd_mask<double, simd_abi::avx512_fixed_size<8>>,
simd<double, simd_abi::avx512_fixed_size<8>>> const&
Expand All @@ -3233,6 +3300,34 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
static_cast<__m512d>(x.impl_get_value()));
}

// Masked horizontal maximum over 8 x float.
// The 256-bit payload of `x` is reinterpreted as a 512-bit register so the
// AVX-512 masked reduction can be applied; only lanes selected by the mask of
// `x` take part in the reduction.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float hmax(
    const_where_expression<simd_mask<float, simd_abi::avx512_fixed_size<8>>,
                           simd<float, simd_abi::avx512_fixed_size<8>>> const&
        x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_max_ps(active_lanes,
                                   _mm512_castps256_ps512(narrow_value));
}

// Masked horizontal minimum over 8 x float.
// The 256-bit payload of `x` is reinterpreted as a 512-bit register so the
// AVX-512 masked reduction can be applied; only lanes selected by the mask of
// `x` take part in the reduction.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float hmin(
    const_where_expression<simd_mask<float, simd_abi::avx512_fixed_size<8>>,
                           simd<float, simd_abi::avx512_fixed_size<8>>> const&
        x) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_min_ps(active_lanes,
                                   _mm512_castps256_ps512(narrow_value));
}

// Masked horizontal sum over 8 x std::int32_t.
// The trailing identity value and std::plus<> arguments are unnamed and
// unused; they only select this overload. The 256-bit payload of `x` is
// reinterpreted as a 512-bit register so the AVX-512 masked add-reduction can
// be applied to the lanes selected by the mask of `x`.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int32_t reduce(
    const_where_expression<
        simd_mask<std::int32_t, simd_abi::avx512_fixed_size<8>>,
        simd<std::int32_t, simd_abi::avx512_fixed_size<8>>> const& x,
    std::int32_t, std::plus<>) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256i>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_add_epi32(active_lanes,
                                      _mm512_castsi256_si512(narrow_value));
}

[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION std::int64_t reduce(
const_where_expression<
simd_mask<std::int64_t, simd_abi::avx512_fixed_size<8>>,
Expand All @@ -3251,6 +3346,16 @@ class where_expression<simd_mask<std::uint64_t, simd_abi::avx512_fixed_size<8>>,
static_cast<__m512d>(x.impl_get_value()));
}

// Masked horizontal sum over 8 x float.
// The trailing identity value and std::plus<> arguments are unnamed and
// unused; they only select this overload. The 256-bit payload of `x` is
// reinterpreted as a 512-bit register so the AVX-512 masked add-reduction can
// be applied to the lanes selected by the mask of `x`.
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION float reduce(
    const_where_expression<simd_mask<float, simd_abi::avx512_fixed_size<8>>,
                           simd<float, simd_abi::avx512_fixed_size<8>>> const&
        x,
    float, std::plus<>) {
  const auto active_lanes = static_cast<__mmask8>(x.impl_get_mask());
  const auto narrow_value = static_cast<__m256>(x.impl_get_value());
  // The widening cast leaves the upper 256 bits unspecified; that is harmless
  // because an 8-bit mask can never select lanes 8..15.
  return _mm512_mask_reduce_add_ps(active_lanes,
                                   _mm512_castps256_ps512(narrow_value));
}

} // namespace Experimental
} // namespace Kokkos

Expand Down

0 comments on commit 293f6f0

Please sign in to comment.