Skip to content

Commit

Permalink
[SYCL] Widen (u)int8/16 to (u)int32 and half to float in group_broadc…
Browse files Browse the repository at this point in the history
…ast (#5110)

CPU device does not yet support the (u)int8/16 and half versions.

- Add FIXMEs.
- Bitcast half to int16_t (and then widen to int32_t) to keep the precision.
- Refactor the widening code into a separate helper.
- Add tests for all 3 group_broadcast overloads.
  • Loading branch information
dnmokhov authored Dec 19, 2021
1 parent e83cb19 commit 1f3f9b9
Show file tree
Hide file tree
Showing 2 changed files with 569 additions and 438 deletions.
37 changes: 16 additions & 21 deletions sycl/include/CL/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ template <typename Group> bool GroupAny(bool pred) {
}

// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
// FIXME: Do not special-case for half once all backends support all data types.
template <typename T>
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value>;
using is_native_broadcast = bool_constant<detail::is_arithmetic<T>::value &&
!std::is_same<T, half>::value>;

template <typename T, typename IdT = size_t>
using EnableIfNativeBroadcast = detail::enable_if_t<
Expand Down Expand Up @@ -121,6 +123,13 @@ template <typename T, typename IdT = size_t>
using EnableIfGenericBroadcast = detail::enable_if_t<
is_generic_broadcast<T>::value && std::is_integral<IdT>::value, T>;

// FIXME: Disable widening once all backends support all data types.
template <typename T>
using WidenOpenCLTypeTo32_t = conditional_t<
std::is_same<T, cl_char>() || std::is_same<T, cl_short>(), cl_int,
conditional_t<std::is_same<T, cl_uchar>() || std::is_same<T, cl_ushort>(),
cl_uint, T>>;

// Broadcast with scalar local index
// Work-group supports any integral type
// Sub-group currently supports only uint32_t
Expand All @@ -133,21 +142,17 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using OCLT = detail::ConvertToOpenCLType_t<T>;
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename Group, typename T, typename IdT>
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast(T x, IdT local_id) {
using GroupIdT = typename GroupId<Group>::type;
GroupIdT GroupLocalId = static_cast<GroupIdT>(local_id);
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, typename IdT>
Expand All @@ -173,31 +178,21 @@ EnableIfNativeBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
}
using IdT = vec<size_t, Dimensions>;
using OCLT = detail::ConvertToOpenCLType_t<T>;
using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
IdT VecId;
for (int i = 0; i < Dimensions; ++i) {
VecId[i] = local_id[Dimensions - i - 1];
}
OCLT OCLX = detail::convertDataToType<T, OCLT>(x);
WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
return __spirv_GroupBroadcast(group_scope<Group>::value, OCLX, OCLId);
}
template <typename Group, typename T, int Dimensions>
EnableIfBitcastBroadcast<T> GroupBroadcast(T x, id<Dimensions> local_id) {
if (Dimensions == 1) {
return GroupBroadcast<Group>(x, local_id[0]);
}
using IdT = vec<size_t, Dimensions>;
using BroadcastT = ConvertToNativeBroadcastType_t<T>;
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
IdT VecId;
for (int i = 0; i < Dimensions; ++i) {
VecId[i] = local_id[Dimensions - i - 1];
}
auto BroadcastX = bit_cast<BroadcastT>(x);
OCLIdT OCLId = detail::convertDataToType<IdT, OCLIdT>(VecId);
BroadcastT Result =
__spirv_GroupBroadcast(group_scope<Group>::value, BroadcastX, OCLId);
BroadcastT Result = GroupBroadcast<Group>(BroadcastX, local_id);
return bit_cast<T>(Result);
}
template <typename Group, typename T, int Dimensions>
Expand Down
Loading

0 comments on commit 1f3f9b9

Please sign in to comment.