diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index 308ef2eb..16dee840 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -26,6 +26,7 @@ jobs: # To get new URL, look here: # https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#inpage-nav-6-undefined compiler_url: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/ebf5d9aa-17a7-46a4-b5df-ace004227c0e/l_dpcpp-cpp-compiler_p_2023.2.1.8_offline.sh + cxx_flags_extra: "-DMDSPAN_USE_BRACKET_OPERATOR=0" - enable_benchmark: ON - stdcxx: 14 enable_benchmark: OFF diff --git a/tests/foo_customizations.hpp b/tests/foo_customizations.hpp index 381b1ad9..b6872d54 100644 --- a/tests/foo_customizations.hpp +++ b/tests/foo_customizations.hpp @@ -193,7 +193,7 @@ class layout_foo::mapping { template MDSPAN_INLINE_FUNCTION constexpr index_type operator()(Indx0 idx0, Indx1 idx1) const noexcept { - return static_cast(idx0 * __extents.extent(0) + idx1); + return static_cast(idx0 * __extents.extent(1) + idx1); } MDSPAN_INLINE_FUNCTION static constexpr bool is_always_unique() noexcept { return true; } diff --git a/tests/test_submdspan.cpp b/tests/test_submdspan.cpp index 7ab0ceba..28d0732d 100644 --- a/tests/test_submdspan.cpp +++ b/tests/test_submdspan.cpp @@ -224,38 +224,58 @@ struct TestSubMDSpan< return Kokkos::full_extent; } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, int, SliceArgs ... slices) { - return match_expected_extents(++src_idx, sub_idx, src_ext, sub_ext, slices...); + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, int, SliceArgs ... slices) { + return check_submdspan_match(++src_idx, sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, std::pair p, SliceArgs ... slices) { - using idx_t = typename SubExtents::index_type; - return (sub_ext.extent(sub_idx)==static_cast(p.second-p.first)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...); + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, std::pair p, SliceArgs ... slices) { + using idx_t = typename SubMDSpan::index_type; + return (sub_mds.extent(sub_idx)==static_cast(p.second-p.first)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, Kokkos::strided_slice p, SliceArgs ... slices) { - using idx_t = typename SubExtents::index_type; - return (sub_ext.extent(sub_idx)==static_cast((p.extent+p.stride-1)/p.stride)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...); + using idx_t = typename SubMDSpan::index_type; + return (sub_mds.extent(sub_idx)==static_cast((p.extent+p.stride-1)/p.stride)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, Kokkos::strided_slice,std::integral_constant>, SliceArgs ... slices) { - return (sub_ext.extent(sub_idx)==0) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...); + return (sub_mds.extent(sub_idx)==0) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int src_idx, int sub_idx, SrcExtents src_ext, SubExtents sub_ext, Kokkos::full_extent_t, SliceArgs ... slices) { - return (sub_ext.extent(sub_idx)==src_ext.extent(src_idx)) && match_expected_extents(++src_idx, ++sub_idx, src_ext, sub_ext, slices...); + static bool check_submdspan_match(int src_idx, int sub_idx, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence, Kokkos::full_extent_t, SliceArgs ... slices) { + return (sub_mds.extent(sub_idx)==src_mds.extent(src_idx)) && check_submdspan_match(++src_idx, ++sub_idx, src_mds, sub_mds, std::index_sequence(), std::index_sequence(), slices...); } - template + template MDSPAN_INLINE_FUNCTION - static bool match_expected_extents(int, int, SrcExtents, SubExtents) { return true; } + static bool check_submdspan_match(int, int, SrcMDSpan src_mds, SubMDSpan sub_mds, std::index_sequence, std::index_sequence) { +#if MDSPAN_USE_BRACKET_OPERATOR + if constexpr (SrcMDSpan::rank() == 0) { + return (&src_mds[]==&sub_mds[]); + } else if constexpr (SubMDSpan::rank() == 0) { + return (&src_mds[SrcIdx...]==&sub_mds[]); + } else { + if(sub_mds.size() == 0) return true; + return (&src_mds[SrcIdx...]==&sub_mds[SubIdx...]); + } +#else + if constexpr (SrcMDSpan::rank() == 0) { + return (&src_mds()==&sub_mds()); + } else if constexpr (SubMDSpan::rank() == 0) { + return (&src_mds(SrcIdx...)==&sub_mds()); + } else { + if(sub_mds.size() == 0) return true; + return (&src_mds(SrcIdx...)==&sub_mds(SubIdx...)); + } +#endif + } static void run() { typename mds_org_t::mapping_type map(typename mds_org_t::extents_type(ConstrArgs...)); @@ -265,7 +285,7 @@ struct TestSubMDSpan< dispatch([=] _MDSPAN_HOST_DEVICE () { auto sub = Kokkos::submdspan(src, create_slice_arg(SubArgs())...); - bool match = match_expected_extents(0, 0, src.extents(), sub.extents(), create_slice_arg(SubArgs())...); + bool match = check_submdspan_match(0, 0, src, sub, std::index_sequence<>(), std::index_sequence<>(), create_slice_arg(SubArgs())...); result[0] = match?1:0; }); EXPECT_EQ(result[0], 1);