diff --git a/include/experimental/__p2630_bits/submdspan_extents.hpp b/include/experimental/__p2630_bits/submdspan_extents.hpp index c3b2f78f..779a9316 100644 --- a/include/experimental/__p2630_bits/submdspan_extents.hpp +++ b/include/experimental/__p2630_bits/submdspan_extents.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include #include "strided_slice.hpp" namespace MDSPAN_IMPL_STANDARD_NAMESPACE { @@ -52,6 +53,31 @@ template struct is_strided_slice< strided_slice> : std::true_type {}; +// Helper for identifying valid pair like things +template struct index_pair_like : std::false_type {}; + +template +struct index_pair_like, IndexType> { + static constexpr bool value = std::is_convertible_v && + std::is_convertible_v; +}; + +template +struct index_pair_like, IndexType> { + static constexpr bool value = std::is_convertible_v && + std::is_convertible_v; +}; + +template +struct index_pair_like, IndexType> { + static constexpr bool value = std::is_convertible_v; +}; + +template +struct index_pair_like, IndexType> { + static constexpr bool value = std::is_convertible_v; +}; + // first_of(slice): getting begin of slice specifier range MDSPAN_TEMPLATE_REQUIRES( class Integral, @@ -70,13 +96,19 @@ first_of(const ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t &) { MDSPAN_TEMPLATE_REQUIRES( class Slice, - /* requires */(std::is_convertible_v>) + /* requires */(index_pair_like::value) ) MDSPAN_INLINE_FUNCTION constexpr auto first_of(const Slice &i) { return std::get<0>(i); } +template +MDSPAN_INLINE_FUNCTION +constexpr auto first_of(const std::complex &i) { + return i.real(); +} + template MDSPAN_INLINE_FUNCTION constexpr OffsetType @@ -100,7 +132,7 @@ constexpr Integral MDSPAN_TEMPLATE_REQUIRES( size_t k, class Extents, class Slice, - /* requires */(std::is_convertible_v>) + /* requires */(index_pair_like::value) ) MDSPAN_INLINE_FUNCTION constexpr auto last_of(std::integral_constant, const Extents &, @@ -108,6 +140,12 @@ constexpr auto last_of(std::integral_constant, const Extents &, return std::get<1>(i); } +template +MDSPAN_INLINE_FUNCTION +constexpr auto last_of(std::integral_constant, const Extents &, const std::complex &i) { + return i.imag(); +} + // Suppress spurious warning with NVCC about no return statement. // This is a known issue in NVCC and NVC++ // Depending on the CUDA and GCC version we need both the builtin diff --git a/include/experimental/__p2630_bits/submdspan_mapping.hpp b/include/experimental/__p2630_bits/submdspan_mapping.hpp index e1390fde..69762e44 100644 --- a/include/experimental/__p2630_bits/submdspan_mapping.hpp +++ b/include/experimental/__p2630_bits/submdspan_mapping.hpp @@ -98,8 +98,7 @@ template struct is_range_slice { constexpr static bool value = std::is_same_v || - std::is_convertible_v>; + index_pair_like::value; }; template diff --git a/include/experimental/__p2642_bits/layout_padded.hpp b/include/experimental/__p2642_bits/layout_padded.hpp index ef10e0ed..1502489a 100644 --- a/include/experimental/__p2642_bits/layout_padded.hpp +++ b/include/experimental/__p2642_bits/layout_padded.hpp @@ -95,6 +95,7 @@ struct padded_extent { using static_array_type = typename static_array_type_for_padded_extent< padding_value, _Extents, _ExtentToPadIdx, _Extents::rank()>::type; + MDSPAN_INLINE_FUNCTION static constexpr auto static_value() { return static_array_type::static_value(0); } MDSPAN_INLINE_FUNCTION @@ -203,7 +204,7 @@ class layout_left_padded::mapping { } public: -#if !MDSPAN_HAS_CXX_20 +#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__) MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() : mapping(extents_type{}) @@ -347,7 +348,7 @@ class layout_left_padded::mapping { MDSPAN_INLINE_FUNCTION constexpr mapping(const _Mapping &other_mapping) noexcept : padded_stride(padded_stride_type::init_padding( - other_mapping.extents(), + static_cast(other_mapping.extents()), other_mapping.extents().extent(extent_to_pad_idx))), exts(other_mapping.extents()) {} @@ -566,7 +567,7 @@ class layout_right_padded::mapping { } public: -#if !MDSPAN_HAS_CXX_20 +#if !MDSPAN_HAS_CXX_20 || defined(__NVCC__) MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mapping() : mapping(extents_type{}) @@ -707,7 +708,7 @@ class layout_right_padded::mapping { MDSPAN_INLINE_FUNCTION constexpr mapping(const _Mapping &other_mapping) noexcept : padded_stride(padded_stride_type::init_padding( - other_mapping.extents(), + static_cast(other_mapping.extents()), other_mapping.extents().extent(extent_to_pad_idx))), exts(other_mapping.extents()) {} diff --git a/tests/test_layout_padded_left.cpp b/tests/test_layout_padded_left.cpp index 2e7fe546..88a36a00 100644 --- a/tests/test_layout_padded_left.cpp +++ b/tests/test_layout_padded_left.cpp @@ -296,6 +296,8 @@ TEST(LayoutLeftTests, construction) ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).padded_stride.value(0)), 0); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).padded_stride.value(0)), 4); + ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>(Kokkos::extents(4))).extents()), (Kokkos::extents())); + ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>(Kokkos::extents(4))).padded_stride.value(0)), 4); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).padded_stride.value(0)), 4); @@ -311,6 +313,8 @@ TEST(LayoutLeftTests, construction) ASSERT_EQ(KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).extents(), Kokkos::extents()); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).padded_stride.value(0)), 0); + ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded<4>::mapping>(Kokkos::dextents(3))).extents()), (Kokkos::extents())); + ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded<4>::mapping>(Kokkos::dextents(3))).padded_stride.value(0)), 0); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded::mapping>({}, 4)).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_left_padded<4>::mapping>(KokkosEx::layout_right_padded::mapping>({}, 4)).padded_stride.value(0)), 0); diff --git a/tests/test_layout_padded_right.cpp b/tests/test_layout_padded_right.cpp index 2a600e28..d5e8300c 100644 --- a/tests/test_layout_padded_right.cpp +++ b/tests/test_layout_padded_right.cpp @@ -304,6 +304,8 @@ TEST(LayoutrightTests, construction) ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).padded_stride.value(0)), 0); ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).padded_stride.value(0)), 8); + ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>(Kokkos::extents(7))).extents()), (Kokkos::extents())); + ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>(Kokkos::extents(7))).padded_stride.value(0)), 8); ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_right_padded::mapping>(KokkosEx::layout_right_padded<4>::mapping>()).padded_stride.value(0)), 8); @@ -311,6 +313,8 @@ TEST(LayoutrightTests, construction) ASSERT_EQ(KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).extents(), Kokkos::extents()); ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>()).padded_stride.value(0)), 0); + ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>(Kokkos::dextents(3))).extents()), (Kokkos::extents())); + ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded<4>::mapping>(Kokkos::dextents(3))).padded_stride.value(0)), 0); ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded::mapping>({}, 4)).extents()), (Kokkos::extents())); ASSERT_EQ((KokkosEx::layout_right_padded<4>::mapping>(KokkosEx::layout_left_padded::mapping>({}, 4)).padded_stride.value(0)), 0); diff --git a/tests/test_submdspan.cpp b/tests/test_submdspan.cpp index a0f4e043..184326d8 100644 --- a/tests/test_submdspan.cpp +++ b/tests/test_submdspan.cpp @@ -140,6 +140,7 @@ using submdspan_test_types = // layout_right to layout_right Check Extents Preservation , std::tuple, args_t<10>, Kokkos::extents, Kokkos::full_extent_t> , std::tuple, args_t<10>, Kokkos::extents, std::pair> + , std::tuple, args_t<10>, Kokkos::extents, std::complex> , std::tuple, args_t<10>, Kokkos::extents, int> , std::tuple, args_t<10,20>, Kokkos::extents, Kokkos::full_extent_t, Kokkos::full_extent_t> , std::tuple, args_t<10,20>, Kokkos::extents, std::pair, Kokkos::full_extent_t> @@ -274,6 +275,10 @@ struct TestSubMDSpan< return std::pair(1,3); } MDSPAN_INLINE_FUNCTION + static auto create_slice_arg(std::complex) { + return std::complex{1.,3.}; + } + MDSPAN_INLINE_FUNCTION static auto create_slice_arg(Kokkos::strided_slice) { return Kokkos::strided_slice{1,3,2}; } @@ -300,6 +305,12 @@ struct TestSubMDSpan< } template MDSPAN_INLINE_FUNCTION + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, std::complex p, SliceArgs ... slices) { + using idx_t = typename SubMDSpan::index_type; + return (sub_mds.extent(sub_idx)==static_cast(p.imag()-p.real())) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); + } + template + MDSPAN_INLINE_FUNCTION static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, Kokkos::strided_slice p, SliceArgs ... slices) { using idx_t = typename SubMDSpan::index_type;