From 7b92bcd8005a5d85cc6009ffd9f1ef935b77cbe9 Mon Sep 17 00:00:00 2001 From: AuroraPerego Date: Mon, 11 Dec 2023 18:53:51 +0100 Subject: [PATCH] take SYCL implementation from DPCT --- include/alpaka/warp/WarpGenericSycl.hpp | 58 ++++++++++++------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/include/alpaka/warp/WarpGenericSycl.hpp b/include/alpaka/warp/WarpGenericSycl.hpp index ebb88979fa08..c7b8821c5abc 100644 --- a/include/alpaka/warp/WarpGenericSycl.hpp +++ b/include/alpaka/warp/WarpGenericSycl.hpp @@ -123,12 +123,9 @@ namespace alpaka::warp::trait The first starts at sub-group index 0 and the second at sub-group index 16. For srcLane = 4 the first subdivision will access the value at sub-group index 4 and the second at sub-group index 20. */ auto const actual_group = warp.m_item_warp.get_sub_group(); - auto const actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = static_cast(srcLane + actual_group_id * width); - auto const src = sycl::id<1>{actual_src_id}; - - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const start_index = actual_group.get_local_linear_id() / w * w; + return sycl::select_from_group(actual_group, value, start_index + static_cast(srcLane) % w); } }; @@ -142,15 +139,16 @@ namespace alpaka::warp::trait std::uint32_t offset, /* must be the same for all work-items in the group */ std::int32_t width) { - std::int32_t offset_int = static_cast(offset); auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = actual_item_id - offset_int; - auto const src = actual_src_id >= actual_group_id * width - ? sycl::id<1>{static_cast(actual_src_id)} - : sycl::id<1>{static_cast(actual_item_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const start_index = id / w * w; + T result = sycl::shift_group_right(actual_group, value, offset); + if((id - start_index) < offset) + { + result = value; + } + return result; } }; @@ -164,15 +162,16 @@ namespace alpaka::warp::trait std::uint32_t offset, std::int32_t width) { - std::int32_t offset_int = static_cast(offset); auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = actual_item_id + offset_int; - auto const src = actual_src_id < (actual_group_id + 1) * width - ? sycl::id<1>{static_cast(actual_src_id)} - : sycl::id<1>{static_cast(actual_item_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const end_index = (id / w + 1) * w; + T result = sycl::shift_group_left(actual_group, value, offset); + if((id + offset) >= end_index) + { + result = value; + } + return result; } }; @@ -180,17 +179,14 @@ namespace alpaka::warp::trait struct ShflXor> { template - static auto shfl_xor( - warp::WarpGenericSycl const& warp, - T value, - std::int32_t mask, - std::int32_t /*width*/) + static auto shfl_xor(warp::WarpGenericSycl const& warp, T value, std::int32_t mask, std::int32_t width) { auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_src_id = actual_item_id ^ mask; - auto const src = sycl::id<1>{static_cast(actual_src_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const start_index = id / w * w; + std::uint32_t const target_offset = (id % w) ^ static_cast(mask); + return sycl::select_from_group(actual_group, value, target_offset < w ? start_index + target_offset : id); } }; } // namespace alpaka::warp::trait