From 063dfe1572089f223079829905adc81da6c7011b Mon Sep 17 00:00:00 2001 From: AuroraPerego Date: Mon, 11 Dec 2023 18:53:51 +0100 Subject: [PATCH] take SYCL implementation from DPCT --- include/alpaka/warp/WarpGenericSycl.hpp | 64 +++++++++++++------------ 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/include/alpaka/warp/WarpGenericSycl.hpp b/include/alpaka/warp/WarpGenericSycl.hpp index ebb88979fa08..425d97a25859 100644 --- a/include/alpaka/warp/WarpGenericSycl.hpp +++ b/include/alpaka/warp/WarpGenericSycl.hpp @@ -1,5 +1,11 @@ /* Copyright 2023 Jan Stephan, Luca Ferragina, Andrea Bocci, Aurora Perego * SPDX-License-Identifier: MPL-2.0 + * + * The implementations of Shfl::shfl(), ShflUp::shfl_up(), ShflDown::shfl_down() and ShflXor::shfl_xor() are derived + * from Intel DPCT. + * Copyright (C) Intel Corporation. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * See https://llvm.org/LICENSE.txt for license information. */ #pragma once @@ -123,12 +129,9 @@ namespace alpaka::warp::trait The first starts at sub-group index 0 and the second at sub-group index 16. For srcLane = 4 the first subdivision will access the value at sub-group index 4 and the second at sub-group index 20. */ auto const actual_group = warp.m_item_warp.get_sub_group(); - auto const actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = static_cast(srcLane + actual_group_id * width); - auto const src = sycl::id<1>{actual_src_id}; - - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const start_index = actual_group.get_local_linear_id() / w * w; + return sycl::select_from_group(actual_group, value, start_index + static_cast(srcLane) % w); } }; @@ -142,15 +145,16 @@ namespace alpaka::warp::trait std::uint32_t offset, /* must be the same for all work-items in the group */ std::int32_t width) { - std::int32_t offset_int = static_cast(offset); auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = actual_item_id - offset_int; - auto const src = actual_src_id >= actual_group_id * width - ? sycl::id<1>{static_cast(actual_src_id)} - : sycl::id<1>{static_cast(actual_item_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const start_index = id / w * w; + T result = sycl::shift_group_right(actual_group, value, offset); + if((id - start_index) < offset) + { + result = value; + } + return result; } }; @@ -164,15 +168,16 @@ namespace alpaka::warp::trait std::uint32_t offset, std::int32_t width) { - std::int32_t offset_int = static_cast(offset); auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_group_id = actual_item_id / width; - auto const actual_src_id = actual_item_id + offset_int; - auto const src = actual_src_id < (actual_group_id + 1) * width - ? sycl::id<1>{static_cast(actual_src_id)} - : sycl::id<1>{static_cast(actual_item_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const end_index = (id / w + 1) * w; + T result = sycl::shift_group_left(actual_group, value, offset); + if((id + offset) >= end_index) + { + result = value; + } + return result; } }; @@ -180,17 +185,14 @@ namespace alpaka::warp::trait struct ShflXor> { template - static auto shfl_xor( - warp::WarpGenericSycl const& warp, - T value, - std::int32_t mask, - std::int32_t /*width*/) + static auto shfl_xor(warp::WarpGenericSycl const& warp, T value, std::int32_t mask, std::int32_t width) { auto const actual_group = warp.m_item_warp.get_sub_group(); - auto actual_item_id = static_cast(actual_group.get_local_linear_id()); - auto const actual_src_id = actual_item_id ^ mask; - auto const src = sycl::id<1>{static_cast(actual_src_id)}; - return sycl::select_from_group(actual_group, value, src); + std::uint32_t const w = static_cast(width); + std::uint32_t const id = actual_group.get_local_linear_id(); + std::uint32_t const start_index = id / w * w; + std::uint32_t const target_offset = (id % w) ^ static_cast(mask); + return sycl::select_from_group(actual_group, value, target_offset < w ? start_index + target_offset : id); } }; } // namespace alpaka::warp::trait