Skip to content

Commit

Permalink
Add element access via at() to std::mdspan
Browse files Browse the repository at this point in the history
Signed-off-by: Stephan Lachnit <[email protected]>
  • Loading branch information
stephanlachnit committed Aug 22, 2024
1 parent 98a12b0 commit ab57448
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 1 deletion.
67 changes: 66 additions & 1 deletion include/experimental/__p0009_bits/mdspan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
#include "trait_backports.hpp"
#include "compressed_pair.hpp"

#include <stdexcept>
#include <string>

namespace MDSPAN_IMPL_STANDARD_NAMESPACE {
template <
class ElementType,
Expand Down Expand Up @@ -219,6 +222,68 @@ class mdspan
//--------------------------------------------------------------------------------
// [mdspan.basic.mapping], mdspan mapping domain multidimensional index to access codomain element

MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
/* requires */ (
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_convertible, SizeTypes, index_type) /* && ... */) &&
_MDSPAN_FOLD_AND(_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, SizeTypes) /* && ... */) &&
(rank() == sizeof...(SizeTypes))
)
)
constexpr reference at(SizeTypes... indices) const
{
size_t r = 0;
for (const auto& index : {indices...}) {
if (index >= __mapping_ref().extents().extent(r)) {
throw std::out_of_range(
"mdspan::at(...," + std::to_string(index) + ",...) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
++r;
}
return __accessor_ref().access(__ptr_ref(), __mapping_ref()(static_cast<index_type>(std::move(indices))...));
}

MDSPAN_TEMPLATE_REQUIRES(
class SizeType,
/* requires */ (
_MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&)
)
)
constexpr reference at(const std::array<SizeType, rank()>& indices) const
{
for (size_t r = 0; r < indices.size(); ++r) {
if (indices[r] >= __mapping_ref().extents().extent(r)) {
throw std::out_of_range(
"mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
}
return __impl::template __callop<reference>(*this, indices);
}

#ifdef __cpp_lib_span
MDSPAN_TEMPLATE_REQUIRES(
class SizeType,
/* requires */ (
_MDSPAN_TRAIT(std::is_convertible, const SizeType&, index_type) &&
_MDSPAN_TRAIT(std::is_nothrow_constructible, index_type, const SizeType&)
)
)
constexpr reference at(std::span<SizeType, rank()> indices) const
{
for (size_t r = 0; r < indices.size(); ++r) {
if (indices[r] >= __mapping_ref().extents().extent(r)) {
throw std::out_of_range(
"mdspan::at({...," + std::to_string(indices[r]) + ",...}) out-of-range at rank index " + std::to_string(r) +
" for mdspan with extent {...," + std::to_string(__mapping_ref().extents().extent(r)) + ",...}");
}
}
return __impl::template __callop<reference>(*this, indices);
}
#endif // __cpp_lib_span

#if MDSPAN_USE_BRACKET_OPERATOR
MDSPAN_TEMPLATE_REQUIRES(
class... SizeTypes,
Expand All @@ -243,7 +308,7 @@ class mdspan
)
)
MDSPAN_FORCE_INLINE_FUNCTION
constexpr reference operator[](const std::array< SizeType, rank()>& indices) const
constexpr reference operator[](const std::array<SizeType, rank()>& indices) const
{
return __impl::template __callop<reference>(*this, indices);
}
Expand Down
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ mdspan_add_test(test_layout_preconditions ENABLE_PRECONDITIONS)

mdspan_add_test(test_dims)
mdspan_add_test(test_extents)
mdspan_add_test(test_mdspan_at)
mdspan_add_test(test_mdspan_ctors)
mdspan_add_test(test_mdspan_swap)
mdspan_add_test(test_mdspan_conversion)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_mdspan_at.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#include <array>
#include <mdspan/mdspan.hpp>

#include <gtest/gtest.h>


TEST(TestMdspanAt, test_mdspan_at) {
std::array<double, 6> a{};
Kokkos::mdspan<double, Kokkos::extents<size_t, 2, 3>> s(a.data());

s.at(0, 0) = 3.14;
s.at(std::array<int, 2>{1, 2}) = 2.72;
ASSERT_EQ(s.at(0, 0), 3.14);
ASSERT_EQ(s.at(std::array<int, 2>{1, 2}), 2.72);

EXPECT_THROW(s.at(2, 3), std::out_of_range);
EXPECT_THROW(s.at(std::array<int, 2>{3, 1}), std::out_of_range);
}

0 comments on commit ab57448

Please sign in to comment.