From ab5744887b39dafe0cd8f903885624ff56366fc8 Mon Sep 17 00:00:00 2001 From: Stephan Lachnit Date: Sat, 2 Dec 2023 10:55:26 +0100 Subject: [PATCH 1/2] Add element access via at() to std::mdspan Signed-off-by: Stephan Lachnit --- include/experimental/__p0009_bits/mdspan.hpp | 67 +++++++++++++++++++- tests/CMakeLists.txt | 1 + tests/test_mdspan_at.cpp | 34 ++++++++++ 3 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 tests/test_mdspan_at.cpp diff --git a/include/experimental/__p0009_bits/mdspan.hpp b/include/experimental/__p0009_bits/mdspan.hpp index 23114aa5..69f4fff6 100644 --- a/include/experimental/__p0009_bits/mdspan.hpp +++ b/include/experimental/__p0009_bits/mdspan.hpp @@ -22,6 +22,9 @@ #include "trait_backports.hpp" #include "compressed_pair.hpp" +#include +#include + namespace MDSPAN_IMPL_STANDARD_NAMESPACE { template < class ElementType, @@ -219,6 +222,68 @@ class mdspan //-------------------------------------------------------------------------------- // [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element + MDSPAN_TEMPLATE_REQUIRES( + class... SizeTypes, + /* requires */ ( + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_convertible, SizeTypes, index_type) /* && ... */) && + _MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, SizeTypes) /* && ... */) && + (rank() == sizeof...(SizeTypes)) + ) + ) + constexpr reference at(SizeTypes... indices) const + { + size_t r = 0; + for (const auto& index : {indices...}) { + if (index >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + ++r; + } + return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast(std::move(indices))...)); + } + + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(const std::array& indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (indices[r] >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + + #ifdef __cpp_lib_span + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + constexpr reference at(std::span indices) const + { + for (size_t r = 0; r < indices.size(); ++r) { + if (indices[r] >= __mapping_ref().extents().extent(r)) { + throw std::out_of_range( + "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); + } + } + return __impl::template __callop(*this, indices); + } + #endif // __cpp_lib_span + #if MDSPAN_USE_BRACKET_OPERATOR MDSPAN_TEMPLATE_REQUIRES( class... SizeTypes, @@ -243,7 +308,7 @@ class mdspan ) ) MDSPAN_FORCE_INLINE_FUNCTION - constexpr reference operator[](const std::array< SizeType, rank()>& indices) const + constexpr reference operator[](const std::array& indices) const { return __impl::template __callop(*this, indices); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 14d61b2f..76799fb3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -88,6 +88,7 @@ mdspan_add_test(test_layout_preconditions ENABLE_PRECONDITIONS) mdspan_add_test(test_dims) mdspan_add_test(test_extents) +mdspan_add_test(test_mdspan_at) mdspan_add_test(test_mdspan_ctors) mdspan_add_test(test_mdspan_swap) mdspan_add_test(test_mdspan_conversion) diff --git a/tests/test_mdspan_at.cpp b/tests/test_mdspan_at.cpp new file mode 100644 index 00000000..55496812 --- /dev/null +++ b/tests/test_mdspan_at.cpp @@ -0,0 +1,34 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include +#include + +#include + + +TEST(TestMdspanAt, test_mdspan_at) { + std::array a{}; + Kokkos::mdspan> s(a.data()); + + s.at(0, 0) = 3.14; + s.at(std::array{1, 2}) = 2.72; + ASSERT_EQ(s.at(0, 0), 3.14); + ASSERT_EQ(s.at(std::array{1, 2}), 2.72); + + EXPECT_THROW(s.at(2, 3), std::out_of_range); + EXPECT_THROW(s.at(std::array{3, 1}), std::out_of_range); +} From 4cff9ced33403f8efb1362c4574793cfe8b657e4 Mon Sep 17 00:00:00 2001 From: Stephan Lachnit Date: Wed, 21 Aug 2024 18:03:45 +0200 Subject: [PATCH 2/2] Check negative indicies in at() Signed-off-by: Stephan Lachnit --- include/experimental/__p0009_bits/mdspan.hpp | 24 +++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/include/experimental/__p0009_bits/mdspan.hpp b/include/experimental/__p0009_bits/mdspan.hpp index 69f4fff6..0bac391c 100644 --- a/include/experimental/__p0009_bits/mdspan.hpp +++ b/include/experimental/__p0009_bits/mdspan.hpp @@ -24,6 +24,7 @@ #include #include +#include namespace MDSPAN_IMPL_STANDARD_NAMESPACE { template < @@ -234,7 +235,7 @@ class mdspan { size_t r = 0; for (const auto& index : {indices...}) { - if (index >= __mapping_ref().extents().extent(r)) { + if (__is_index_oor(index, __mapping_ref().extents().extent(r))) { throw std::out_of_range( "mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); @@ -254,7 +255,7 @@ class mdspan constexpr reference at(const std::array& indices) const { for (size_t r = 0; r < indices.size(); ++r) { - if (indices[r] >= __mapping_ref().extents().extent(r)) { + if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) { throw std::out_of_range( "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); @@ -274,7 +275,7 @@ class mdspan constexpr reference at(std::span indices) const { for (size_t r = 0; r < indices.size(); ++r) { - if (indices[r] >= __mapping_ref().extents().extent(r)) { + if (__is_index_oor(indices[r], __mapping_ref().extents().extent(r))) { throw std::out_of_range( "mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) + " for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}"); @@ -441,6 +442,23 @@ class mdspan MDSPAN_FORCE_INLINE_FUNCTION constexpr mapping_type const& __mapping_ref() const noexcept { return __members.__second().__first(); } MDSPAN_FORCE_INLINE_FUNCTION _MDSPAN_CONSTEXPR_14 accessor_type& __accessor_ref() noexcept { return __members.__second().__second(); } MDSPAN_FORCE_INLINE_FUNCTION constexpr accessor_type const& __accessor_ref() const noexcept { return __members.__second().__second(); } + + MDSPAN_TEMPLATE_REQUIRES( + class SizeType, + /* requires */ ( + _MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) && + _MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&) + ) + ) + MDSPAN_FORCE_INLINE_FUNCTION constexpr bool __is_index_oor(SizeType index, index_type extent) const noexcept { + // Check for negative indices + if constexpr(std::is_signed_v) { + if(index < 0) { + return true; + } + } + return static_cast(index) >= extent; + } template friend class mdspan;