Skip to content

Commit

Permalink
Modified to use identity element based on input mask
Browse files Browse the repository at this point in the history
  • Loading branch information
ldh4 committed Oct 6, 2024
1 parent 6c7b30e commit d111add
Show file tree
Hide file tree
Showing 7 changed files with 628 additions and 557 deletions.
393 changes: 243 additions & 150 deletions simd/src/Kokkos_SIMD_AVX512.hpp

Large diffs are not rendered by default.

149 changes: 72 additions & 77 deletions simd/src/Kokkos_SIMD_Common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,19 +317,19 @@ KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION simd<T, Abi>& operator<<=(
// fallback implementations of reductions across simd_mask:

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool all_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool all_of(
simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(true);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool any_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool any_of(
simd_mask<T, Abi> const& a) {
return a != simd_mask<T, Abi>(false);
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION bool none_of(
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr bool none_of(
simd_mask<T, Abi> const& a) {
return a == simd_mask<T, Abi>(false);
}
Expand All @@ -352,93 +352,88 @@ template <typename T>
template <class T, class Abi, class BinaryOperation = std::plus<>>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, BinaryOperation binary_op = {}) {
auto v = where(true, x);
return reduce(v, binary_op);
}

template <class T, class Abi, class BinaryOperation>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
T identity_element, BinaryOperation binary_op) {
if (none_of(mask)) return identity_element;
auto v = where(mask, x);
return reduce(v, binary_op);
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
std::plus<> binary_op = {}) noexcept {
return reduce(x, mask, T(0), binary_op);
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
std::multiplies<> binary_op) noexcept {
return reduce(x, mask, T(0), binary_op);
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
std::bit_and<> binary_op) noexcept {
return reduce(x, mask, 0, binary_op);
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
std::bit_or<> binary_op) noexcept {
return reduce(x, mask, 0, binary_op);
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
std::bit_xor<> binary_op) noexcept {
return reduce(x, mask, 0, binary_op);
}
return reduce(x, simd<T, Abi>::mask_type(true), T(0), binary_op);
}

// template <class T, class Abi, class BinaryOperation>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// T identity_element, BinaryOperation binary_op) {
// if (none_of(mask)) return identity_element;
// return reduce(x, mask, identity_element, binary_op);
// }

// template <
// class T, class Abi,
// std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// std::plus<> binary_op = {}) noexcept {
// return reduce(x, mask, T(0), binary_op);
// }

// template <
// class T, class Abi,
// std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// std::multiplies<> binary_op) noexcept {
// return reduce(x, mask, T(0), binary_op);
// }

// template <
// class T, class Abi,
// std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// std::bit_and<> binary_op) noexcept {
// return reduce(x, mask, 0, binary_op);
// }

// template <
// class T, class Abi,
// std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// std::bit_or<> binary_op) noexcept {
// return reduce(x, mask, 0, binary_op);
// }

// template <
// class T, class Abi,
// std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
// const simd<T, Abi>& x, const typename simd<T, Abi>::mask_type& mask,
// std::bit_xor<> binary_op) noexcept {
// return reduce(x, mask, 0, binary_op);
// }

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min(
const simd<T, Abi>& x) noexcept {
auto v = where(true, x);
return reduce_min(v);
return reduce_min(x, simd<T, Abi>::mask_type(true));
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min(
const simd<T, Abi>& x,
const typename simd<T, Abi>::mask_type& mask) noexcept {
auto v = where(mask, x);
return reduce_min(v);
}
// template <class T, class Abi>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min(
// const simd<T, Abi>& x,
// const typename simd<T, Abi>::mask_type& mask) noexcept {
// return reduce_min(x, mask);
// }

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max(
const simd<T, Abi>& x) noexcept {
auto v = where(true, x);
return reduce_max(v);
return reduce_max(x, simd<T, Abi>::mask_type(true));
}

template <class T, class Abi>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max(
const simd<T, Abi>& x,
const typename simd<T, Abi>::mask_type& mask) noexcept {
auto v = where(mask, x);
return reduce_max(v);
}
// template <class T, class Abi>
// [[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max(
// const simd<T, Abi>& x,
// const typename simd<T, Abi>::mask_type& mask) noexcept {
// return reduce_max(x, mask);
// }

} // namespace Experimental
} // namespace Kokkos
Expand Down
38 changes: 13 additions & 25 deletions simd/src/Kokkos_SIMD_Common_Math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,9 @@ hmax(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
template <
typename T, typename Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce_min(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::min();
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_min(
simd<T, Abi> const& v, typename simd<T, Abi>::mask_type const& m) {
auto result = Kokkos::reduction_identity<T>::min();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::min(result, v[i]);
}
Expand All @@ -79,11 +77,9 @@ reduce_min(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce_max(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = Kokkos::reduction_identity<T>::max();
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce_max(
simd<T, Abi> const& v, typename simd<T, Abi>::mask_type const& m) {
auto result = Kokkos::reduction_identity<T>::max();
for (std::size_t i = 0; i < v.size(); ++i) {
if (m[i]) result = Kokkos::max(result, v[i]);
}
Expand All @@ -93,27 +89,19 @@ reduce_max(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x) {
template <
class T, class Abi, class BinaryOperation = std::plus<>,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x,
BinaryOperation op = {}) {
auto const& v = x.impl_get_value();
auto const& m = x.impl_get_mask();
auto result = v[0];
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION constexpr T reduce(
simd<T, Abi> const& v, typename simd<T, Abi>::mask_type const& m,
T identity, BinaryOperation op = {}) {
if (none_of(m)) {
return identity;
}
auto result = v[0];
for (std::size_t i = 1; i < v.size(); ++i) {
if (m[i]) result = op(result, v[i]);
}
return result;
}

template <
class T, class Abi,
std::enable_if_t<!std::is_same_v<Abi, simd_abi::scalar>, bool> = false>
[[nodiscard]] KOKKOS_IMPL_HOST_FORCEINLINE_FUNCTION T
reduce(const_where_expression<simd_mask<T, Abi>, simd<T, Abi>> const& x, T,
std::plus<>) {
return reduce(x, std::plus<>());
}

} // namespace Experimental

template <class T, class Abi>
Expand Down
Loading

0 comments on commit d111add

Please sign in to comment.