Skip to content

Commit

Permalink
[SYCL] Don't use sycl::vec::vector_t in built-in functions
Browse files Browse the repository at this point in the history
* `vector_t` is expected to be removed in
  KhronosGroup/SYCL-Docs#676
* we aren't required to use it here
* `operator vector_t`/`vec(vector_t)` are implemented as a simple
  `bit_cast` anyway, can use it explicitly as well.
  • Loading branch information
aelovikov-intel committed Jan 6, 2025
1 parent 62ce674 commit 450820c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 37 deletions.
11 changes: 1 addition & 10 deletions sycl/include/sycl/detail/builtins/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ template <typename T> auto convert_arg(T &&x) {
using result_type = std::conditional_t<N == 1, converted_elem_type,
converted_elem_type
__attribute__((ext_vector_type(N)))>;
// TODO: We should have this bit_cast impl inside vec::convert.
return bit_cast<result_type>(static_cast<typename no_cv_ref::vector_t>(x));
return bit_cast<result_type>(x);
} else if constexpr (is_swizzle_v<no_cv_ref>) {
return convert_arg(simplify_if_swizzle_t<no_cv_ref>{x});
} else {
Expand All @@ -104,14 +103,6 @@ template <typename T> auto convert_arg(T &&x) {
return convertToOpenCLType(std::forward<T>(x));
}
}

template <typename RetTy, typename T> auto convert_result(T &&x) {
if constexpr (is_vec_v<RetTy>) {
return bit_cast<typename RetTy::vector_t>(x);
} else {
return std::forward<T>(x);
}
}
#endif
} // namespace builtins

Expand Down
32 changes: 7 additions & 25 deletions sycl/include/sycl/detail/builtins/integer_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ BUILTIN_CREATE_ENABLER(builtin_enable_suint32, default_ret_type,
NUM_ARGS, NAME, builtin_enable_integer_t, [](auto... xs) { \
using ret_ty = \
detail::builtin_enable_integer_t<NUM_ARGS##_TEMPLATE_TYPE>; \
return detail::builtins::convert_result<ret_ty>( \
__spirv_ocl_##NAME(xs...)); \
return bit_cast<ret_ty>(__spirv_ocl_##NAME(xs...)); \
})
#else
#define BUILTIN_GENINT(NUM_ARGS, NAME) \
Expand All @@ -54,11 +53,10 @@ BUILTIN_CREATE_ENABLER(builtin_enable_suint32, default_ret_type,
NUM_ARGS, NAME, builtin_enable_integer_t, [](auto... xs) { \
using ret_ty = \
detail::builtin_enable_integer_t<NUM_ARGS##_TEMPLATE_TYPE>; \
using detail::builtins::convert_result; \
if constexpr (std::is_signed_v<detail::get_elem_type_t<T0>>) \
return convert_result<ret_ty>(__spirv_ocl_s_##NAME(xs...)); \
return bit_cast<ret_ty>(__spirv_ocl_s_##NAME(xs...)); \
else \
return convert_result<ret_ty>(__spirv_ocl_u_##NAME(xs...)); \
return bit_cast<ret_ty>(__spirv_ocl_u_##NAME(xs...)); \
})
#else
#define BUILTIN_GENINT_SU(NUM_ARGS, NAME) BUILTIN_GENINT(NUM_ARGS, NAME)
Expand All @@ -67,15 +65,14 @@ BUILTIN_CREATE_ENABLER(builtin_enable_suint32, default_ret_type,
#if __SYCL_DEVICE_ONLY__
DEVICE_IMPL_TEMPLATE(ONE_ARG, abs, builtin_enable_integer_t, [](auto x) {
using ret_ty = detail::builtin_enable_integer_t<T0>;
using detail::builtins::convert_result;
if constexpr (std::is_signed_v<detail::get_elem_type_t<T0>>)
// SPIR-V builtin returns unsigned type, SYCL's return type is signed
// with the following restriction:
// > The behavior is undefined if the result cannot be represented by
// > the return type
return convert_result<ret_ty>(bit_cast<T0>(__spirv_ocl_s_abs(x)));
return bit_cast<ret_ty>(__spirv_ocl_s_abs(x));
else
return convert_result<ret_ty>(__spirv_ocl_u_abs(x));
return bit_cast<ret_ty>(__spirv_ocl_u_abs(x));
})
#else
BUILTIN_GENINT_SU(ONE_ARG, abs)
Expand All @@ -87,25 +84,10 @@ BUILTIN_GENINT_SU(TWO_ARGS, add_sat)
DEVICE_IMPL_TEMPLATE(
TWO_ARGS, abs_diff, builtin_enable_integer_t, [](auto... xs) {
using ret_ty = detail::builtin_enable_integer_t<T0>;
using detail::builtins::convert_result;
if constexpr (std::is_signed_v<detail::get_elem_type_t<T0>>) {
// SPIRV built-in returns [vector of] unsigned type(s).
auto ret = __spirv_ocl_s_abs_diff(xs...);
if constexpr (detail::is_vec_v<T0>) {
// SYCL 2020 revision 8's abs_diff returns T0 (or corresponding vec in
// case of a swizzle). The only way to produce signed ext_vector_type
// from unsigned is with C-style case. Also note that element type of
// sycl::vec and ext_vector_type might be different, e.g.
// sycl::vec<char, N>::vector_t is
// signed char __attribute__((ext_vector_type(N))).
//
// TODO: Shouldn't be different from "abs" above.
return convert_result<ret_ty>((typename T0::vector_t)(ret));
} else {
return convert_result<ret_ty>(ret);
}
return bit_cast<ret_ty>(__spirv_ocl_s_abs_diff(xs...));
} else {
return convert_result<ret_ty>(__spirv_ocl_u_abs_diff(xs...));
return bit_cast<ret_ty>(__spirv_ocl_u_abs_diff(xs...));
}
})
#else
Expand Down
3 changes: 1 addition & 2 deletions sycl/include/sycl/detail/builtins/relational_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,7 @@ DEVICE_IMPL_TEMPLATE(
THREE_ARGS, bitselect, builtin_enable_bitselect_t, [](auto... xs) {
using ret_ty =
detail::builtin_enable_bitselect_t<THREE_ARGS_TEMPLATE_TYPE>;
using detail::builtins::convert_result;
return convert_result<ret_ty>(__spirv_ocl_bitselect(xs...));
return bit_cast<ret_ty>(__spirv_ocl_bitselect(xs...));
})
#else
HOST_IMPL_TEMPLATE(THREE_ARGS, bitselect, builtin_enable_bitselect_t, rel,
Expand Down

0 comments on commit 450820c

Please sign in to comment.