Skip to content

Commit

Permalink
Re-define vector compatibility / size and generic types
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Dec 31, 2023
1 parent 9ec2ff5 commit a40d935
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 107 deletions.
113 changes: 84 additions & 29 deletions include/simsycl/detail/math_utils.hh
Original file line number Diff line number Diff line change
Expand Up @@ -5,70 +5,125 @@

namespace simsycl::detail {

struct undefined_num_elements {};

#define SIMSYCL_DETAIL_DEFINE_NUM_ELEMENTS_TRAIT(trait_name, concept_name, include_scalar) \
template<typename T> \
struct trait_name : std::conditional_t<concept_name<T> && include_scalar, std::integral_constant<int, 1>, \
undefined_num_elements> {}; \
\
template<concept_name DataT, int... Indices> \
struct trait_name<swizzled_vec<DataT, Indices...>> : std::integral_constant<int, sizeof...(Indices)> {}; \
\
template<concept_name DataT, int NumElements> \
struct trait_name<sycl::vec<DataT, NumElements>> : std::integral_constant<int, NumElements> {}; \
\
template<concept_name DataT, size_t NumElements> \
struct trait_name<sycl::marray<DataT, NumElements>> : std::integral_constant<int, static_cast<int>(NumElements)> { \
}; \
\
template<typename T> \
constexpr int trait_name##_v = trait_name<T>::value;


template<typename T>
concept SyclFloat = std::is_same_v<T, float> || std::is_same_v<T, double>
#if SIMSYCL_FEATURE_HALF_TYPE
|| std::is_same_v<T, sycl::half>
#endif
;

template<SyclFloat T>
struct num_elements<T> : std::integral_constant<int, 1> {};
template<typename T>
concept SyclInt = std::is_same_v<T, char> || std::is_same_v<T, signed char> || std::is_same_v<T, unsigned char>
|| std::is_same_v<T, short> || std::is_same_v<T, unsigned short> || std::is_same_v<T, int>
|| std::is_same_v<T, unsigned int> || std::is_same_v<T, long> || std::is_same_v<T, unsigned long>
|| std::is_same_v<T, long long> || std::is_same_v<T, unsigned long long>;

template<typename T>
concept SyclScalar = SyclFloat<T> || SyclInt<T>;

SIMSYCL_DETAIL_DEFINE_NUM_ELEMENTS_TRAIT(gen_float_num_elements, SyclFloat, true /* include scalar */)
SIMSYCL_DETAIL_DEFINE_NUM_ELEMENTS_TRAIT(gen_int_num_elements, SyclInt, true /* include scalar */)
SIMSYCL_DETAIL_DEFINE_NUM_ELEMENTS_TRAIT(non_scalar_float_num_elements, SyclFloat, false /* include scalar */)
SIMSYCL_DETAIL_DEFINE_NUM_ELEMENTS_TRAIT(non_scalar_int_num_elements, SyclInt, false /* include scalar */)

template<typename T>
concept GenFloat = gen_float_num_elements<T>::value >= 0;

template<typename T>
concept GeoFloat = gen_float_num_elements<T>::value > 0 && gen_float_num_elements<T>::value <= 4;

template<typename T>
concept GenFloat
= SyclFloat<T> || ((is_swizzle_v<T> || is_vec_v<T> || is_marray_v<T>)&&SyclFloat<typename T::element_type>);
concept GenInt = gen_int_num_elements<T>::value >= 0;

template<typename T>
concept GeoFloat = SyclFloat<T>
|| ((is_swizzle_v<T> || is_vec_v<T> || is_marray_v<T>)&&(num_elements_v<T> > 0 && num_elements_v<T> <= 4)
&& SyclFloat<typename T::value_type>);
concept Generic = GenFloat<T> || GenInt<T>;

template<typename T>
requires(is_vec_v<T> || is_swizzle_v<T> || is_marray_v<T>)
struct generic_num_elements : std::conditional_t<GenFloat<T>, gen_float_num_elements<T>, gen_int_num_elements<T>> {};

template<typename T>
constexpr int generic_num_elements_v = generic_num_elements<T>::value;

template<typename T>
concept NonScalarFloat = non_scalar_float_num_elements<T>::value >= 0;

template<typename T>
concept NonScalarInt = non_scalar_int_num_elements<T>::value >= 0;

template<typename T>
concept NonScalar = NonScalarFloat<T> || NonScalarInt<T>;

template<typename T>
struct non_scalar_num_elements
: std::conditional_t<NonScalarFloat<T>, non_scalar_float_num_elements<T>, non_scalar_int_num_elements<T>> {};

template<typename T>
constexpr int non_scalar_num_elements_v = non_scalar_num_elements<T>::value;


template<NonScalar T>
auto sum(const T &f) {
auto ret = f[0];
for(int i = 1; i < num_elements_v<T>; ++i) { ret += f[i]; }
for(int i = 1; i < non_scalar_num_elements_v<T>; ++i) { ret += f[i]; }
return ret;
}
template<SyclFloat T>

template<SyclScalar T>
auto sum(const T &f) {
return f;
}

template<typename T>
struct element_type {
using type = T;
};
template<typename T>
requires(is_vec_v<T> || is_swizzle_v<T>)
struct element_type {};

template<SyclScalar T>
struct element_type<T> {
using type = typename T::element_type;
using type = T;
};
template<typename T>
requires(is_marray_v<T>)

template<NonScalar T>
struct element_type<T> {
using type = typename T::value_type;
};

template<typename T>
using element_type_t = typename element_type<T>::type;

template<typename DataT, int NumElements>
template<typename DataT, size_t NumElements>
sycl::vec<DataT, NumElements> marray_to_vec(const sycl::marray<DataT, NumElements> &v) {
sycl::vec<DataT, NumElements> ret;
for(int i = 0; i < NumElements; ++i) { ret[i] = v[i]; }
for(size_t i = 0; i < NumElements; ++i) { ret[i] = v[i]; }
return ret;
}

template<typename VT, typename T>
requires(!is_marray_v<T>)
sycl::vec<element_type_t<VT>, num_elements_v<VT>> to_matching_vec(const T &v) {
return to_vec<element_type_t<VT>, num_elements_v<VT>>(v);
}
template<typename VT, typename T>
requires(is_marray_v<T>)
sycl::vec<element_type_t<VT>, num_elements_v<VT>> to_matching_vec(const T &v) {
return marray_to_vec<element_type_t<VT>, num_elements_v<VT>>(v);
template<Generic VT, typename T>
sycl::vec<element_type_t<VT>, generic_num_elements_v<VT>> to_matching_vec(const T &v) {
if constexpr(is_marray_v<T>) {
return marray_to_vec(v);
} else {
return to_vec<element_type_t<VT>, generic_num_elements_v<VT>>(v);
}
}

} // namespace simsycl::detail
4 changes: 2 additions & 2 deletions include/simsycl/sycl/device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ class device final : public detail::reference_type<device, detail::device_state>


template<aspect Aspect>
struct any_device_has: std::false_type {};
struct any_device_has : std::false_type {};

template<aspect Aspect>
struct all_devices_have: std::false_type {};
struct all_devices_have : std::false_type {};

template<aspect A>
inline constexpr bool any_device_has_v = any_device_has<A>::value;
Expand Down
4 changes: 2 additions & 2 deletions include/simsycl/sycl/info.hh
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ struct command_end : detail::info_descriptor<uint64_t> {};

namespace simsycl::sycl::info::kernel {

struct num_args: detail::info_descriptor<uint32_t> {};
struct attributes: detail::info_descriptor<std::string> {};
struct num_args : detail::info_descriptor<uint32_t> {};
struct attributes : detail::info_descriptor<std::string> {};

} // namespace simsycl::sycl::info::kernel

Expand Down
22 changes: 7 additions & 15 deletions include/simsycl/sycl/marray.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,15 @@

namespace simsycl::detail {

template<typename DataT, typename... ArgTN>
struct marray_init_arg_traits {};
template<typename DataT, typename VecLike>
struct marray_like_num_elements {};

template<typename DataT>
struct marray_init_arg_traits<DataT> {
static constexpr size_t num_elements = 0;
};
template<typename DataT, std::convertible_to<DataT> ElementT>
struct marray_like_num_elements<DataT, ElementT> : std::integral_constant<int, 1> {};

template<typename DataT, std::convertible_to<DataT> ElementT, typename... ArgTN>
struct marray_init_arg_traits<DataT, ElementT, ArgTN...> {
static constexpr size_t num_elements = 1 + marray_init_arg_traits<DataT, ArgTN...>::num_elements;
};
template<typename DataT, int N>
struct marray_like_num_elements<DataT, sycl::vec<DataT, N>> : std::integral_constant<int, N> {};

template<typename DataT, size_t N, typename... ArgTN>
struct marray_init_arg_traits<DataT, sycl::marray<DataT, N>, ArgTN...> {
static constexpr size_t num_elements = N + marray_init_arg_traits<DataT, ArgTN...>::num_elements;
};

template<typename T>
constexpr bool is_marray_v = false;
Expand Down Expand Up @@ -56,7 +48,7 @@ class marray {
}

template<typename... ArgTN>
requires(detail::marray_init_arg_traits<DataT, ArgTN...>::num_elements == NumElements)
requires((detail::marray_like_num_elements<DataT, ArgTN>::value + ...) == NumElements)
constexpr marray(const ArgTN &...args) {
init_with_offset<0>(args...);
}
Expand Down
83 changes: 29 additions & 54 deletions include/simsycl/sycl/vec.hh
Original file line number Diff line number Diff line change
Expand Up @@ -45,52 +45,29 @@ struct elem {

namespace simsycl::detail {

template<typename DataT, typename... ArgTN>
struct vec_init_arg_traits {};

template<typename DataT>
struct vec_init_arg_traits<DataT> {
static constexpr int num_elements = 0;
};

template<typename DataT, std::convertible_to<DataT> ElementT, typename... ArgTN>
struct vec_init_arg_traits<DataT, ElementT, ArgTN...> {
static constexpr int num_elements = 1 + vec_init_arg_traits<DataT, ArgTN...>::num_elements;
};

template<typename DataT, int N, typename... ArgTN>
struct vec_init_arg_traits<DataT, sycl::vec<DataT, N>, ArgTN...> {
static constexpr int num_elements = N + vec_init_arg_traits<DataT, ArgTN...>::num_elements;
};

template<typename DataT, int... Indices>
class swizzled_vec;

template<typename DataT, int NumElements>
constexpr size_t vec_alignment_v = std::min(size_t{64}, sizeof(DataT) * NumElements);

template<typename DataT>
constexpr size_t vec_alignment_v<DataT, 3> = std::min(size_t{64}, sizeof(DataT) * 4);
template<typename DataT, typename VecLike>
struct vec_like_num_elements {};

template<typename DataT, std::convertible_to<DataT> ElementT>
struct vec_like_num_elements<DataT, ElementT> : std::integral_constant<int, 1> {};

template<typename DataT, int... Indices>
class swizzled_vec;


template<typename T>
struct num_elements {
static_assert(sizeof(T) < 0, "num_elements instantiated with unsupported type");
struct vec_like_num_elements<DataT, swizzled_vec<DataT, Indices...>> : std::integral_constant<int, sizeof...(Indices)> {
};

template<typename T, int NumElements>
struct num_elements<sycl::vec<T, NumElements>> : std::integral_constant<int, NumElements> {};
template<typename DataT, int N>
struct vec_like_num_elements<DataT, sycl::vec<DataT, N>> : std::integral_constant<int, N> {};

template<typename T, int... Indices>
struct num_elements<detail::swizzled_vec<T, Indices...>> : std::integral_constant<int, sizeof...(Indices)> {};
template<typename T, typename DataT>
concept VecLike = vec_like_num_elements<DataT, T>::value > 0;

template<typename T, int NumElements>
struct num_elements<sycl::marray<T, NumElements>> : std::integral_constant<int, NumElements> {};

template<typename T>
static constexpr int num_elements_v = num_elements<T>::value;
template<typename T, typename DataT, int NumElements>
concept VecCompatible
= vec_like_num_elements<DataT, T>::value == 1 || vec_like_num_elements<DataT, T>::value == NumElements;


template<int... Is>
Expand Down Expand Up @@ -130,13 +107,6 @@ template<typename T>
static constexpr bool is_vec_v = is_vec<T>::value;


template<typename V1, typename V2>
static constexpr bool is_compatible_vector_v = requires {
num_elements_v<V1> == num_elements_v<V2>;
std::is_same_v<typename V1::element_type, typename V2::element_type>;
} || std::is_same_v<typename V1::element_type, V2>;


template<typename T>
struct is_swizzle : std::false_type {};

Expand Down Expand Up @@ -397,9 +367,9 @@ class swizzled_vec {
// operators

#define SIMSYCL_DETAIL_DEFINE_SWIZZLE_BINARY_COPY_OPERATOR(op, enable_if) \
template<typename LHS, typename RHS> \
template<VecCompatible<element_type, num_elements> LHS, VecCompatible<element_type, num_elements> RHS> \
friend constexpr auto operator op(const LHS &lhs, const RHS &rhs) \
requires(enable_if && is_compatible_vector_v<swizzled_vec, RHS> && is_compatible_vector_v<swizzled_vec, LHS> \
requires(enable_if \
&& (std::is_same_v<swizzled_vec, LHS> || (std::is_same_v<swizzled_vec, RHS> && !is_swizzle_v<LHS>))) \
{ \
return to_vec<element_type, num_elements>(lhs) op to_vec<element_type, num_elements>(rhs); \
Expand All @@ -418,9 +388,9 @@ class swizzled_vec {
#undef SIMSYCL_DETAIL_DEFINE_SWIZZLE_BINARY_COPY_OPERATOR

#define SIMSYCL_DETAIL_DEFINE_SWIZZLE_BINARY_INPLACE_OPERATOR(op, enable_if) \
template<typename RHS> \
template<VecCompatible<element_type, num_elements> RHS> \
friend constexpr swizzled_vec &operator op##=(swizzled_vec && lhs, const RHS & rhs) \
requires(enable_if && allow_assign && is_compatible_vector_v<swizzled_vec, RHS>) \
requires(enable_if && allow_assign) \
{ \
return lhs = to_vec<element_type, num_elements>(lhs) op to_vec<element_type, num_elements>(rhs); \
}
Expand All @@ -441,7 +411,7 @@ class swizzled_vec {
friend constexpr auto operator op(const swizzled_vec &v) \
requires(enable_if) \
{ \
return op to_vec(v); \
return op to_vec<element_type, num_elements>(v); \
}

SIMSYCL_DETAIL_DEFINE_SWIZZLE_UNARY_COPY_OPERATOR(+, true)
Expand Down Expand Up @@ -476,10 +446,9 @@ class swizzled_vec {
#undef SIMSYCL_DETAIL_DEFINE_SWIZZLE_UNARY_POSTFIX_OPERATOR

#define SIMSYCL_DETAIL_DEFINE_SWIZZLE_COMPARISON_OPERATOR(op) \
template<typename LHS, typename RHS> \
template<VecCompatible<element_type, num_elements> LHS, VecCompatible<element_type, num_elements> RHS> \
friend constexpr auto operator op(const LHS &lhs, const RHS &rhs) \
requires(is_compatible_vector_v<swizzled_vec, RHS> && is_compatible_vector_v<swizzled_vec, LHS> \
&& (std::is_same_v<swizzled_vec, LHS> || (std::is_same_v<swizzled_vec, RHS> && !is_swizzle_v<LHS>))) \
requires(std::is_same_v<swizzled_vec, LHS> || (std::is_same_v<swizzled_vec, RHS> && !is_swizzle_v<LHS>)) \
{ \
return to_vec<element_type, num_elements>(lhs) op to_vec<element_type, num_elements>(rhs); \
}
Expand All @@ -506,6 +475,12 @@ class swizzled_vec {
ReferenceDataT *m_elems;
};

template<typename DataT, int NumElements>
constexpr size_t vec_alignment_v = std::min(size_t{64}, sizeof(DataT) * NumElements);

template<typename DataT>
constexpr size_t vec_alignment_v<DataT, 3> = std::min(size_t{64}, sizeof(DataT) * 4);

} // namespace simsycl::detail

namespace simsycl::sycl {
Expand All @@ -529,8 +504,8 @@ class alignas(detail::vec_alignment_v<DataT, NumElements>) vec {
for(int i = 0; i < NumElements; ++i) { m_elems[i] = arg; }
}

template<typename... ArgTN>
requires(detail::vec_init_arg_traits<DataT, ArgTN...>::num_elements == NumElements)
template<detail::VecLike<DataT>... ArgTN>
requires((detail::vec_like_num_elements<DataT, ArgTN>::value + ...) == NumElements)
constexpr vec(const ArgTN &...args) {
init_with_offset<0>(args...);
}
Expand Down
10 changes: 5 additions & 5 deletions test/vec_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ bool check_bool_vec(sycl::vec<bool, Dimensions> a) {
}

TEST_CASE("Compile time vector operations work as expected", "[vec]") {
CHECK(detail::num_elements_v<float> == 1);
CHECK(detail::num_elements_v<sycl::vec<float, 1>> == 1);
CHECK(detail::num_elements_v<sycl::vec<double, 2>> == 2);
CHECK(detail::num_elements_v<sycl::vec<int, 3>> == 3);
CHECK(detail::num_elements_v<sycl::vec<float, 4>> == 4);
CHECK(detail::generic_num_elements_v<float> == 1);
CHECK(detail::generic_num_elements_v<sycl::vec<float, 1>> == 1);
CHECK(detail::generic_num_elements_v<sycl::vec<double, 2>> == 2);
CHECK(detail::generic_num_elements_v<sycl::vec<int, 3>> == 3);
CHECK(detail::generic_num_elements_v<sycl::vec<float, 4>> == 4);
}

TEST_CASE("Basic vector operations work as expected", "[vec]") {
Expand Down

0 comments on commit a40d935

Please sign in to comment.