Skip to content

Commit e109ce3

Browse files
committed
Adds sycl::vec overloads to abs, cos, expm1, log, log1p, and sqrt
1 parent 5ec9fd5 commit e109ce3

File tree

6 files changed

+122
-6
lines changed

6 files changed

+122
-6
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/abs.hpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,15 @@ namespace py = pybind11;
5252
namespace td_ns = dpctl::tensor::type_dispatch;
5353

5454
using dpctl::tensor::type_utils::is_complex;
55+
using dpctl::tensor::type_utils::vec_cast;
5556

5657
template <typename argT, typename resT> struct AbsFunctor
5758
{
5859

5960
using is_constant = typename std::false_type;
6061
// constexpr resT constant_value = resT{};
61-
using supports_vec = typename std::false_type;
62+
using supports_vec = typename std::negation<
63+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6264
using supports_sg_loadstore = typename std::negation<
6365
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466

@@ -87,6 +89,40 @@ template <typename argT, typename resT> struct AbsFunctor
8789
}
8890
}
8991

92+
template <int vec_sz>
93+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
94+
{
95+
if constexpr (std::is_integral<argT>::value) {
96+
if constexpr (std::is_same_v<argT, bool> ||
97+
std::is_unsigned<argT>::value) {
98+
return in;
99+
}
100+
else {
101+
auto const &res_vec = sycl::abs(in);
102+
using deducedT = typename std::remove_cv_t<
103+
std::remove_reference_t<decltype(res_vec)>>::element_type;
104+
if constexpr (std::is_same_v<resT, deducedT>) {
105+
return res_vec;
106+
}
107+
else {
108+
109+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
110+
}
111+
}
112+
}
113+
else {
114+
auto const &res_vec = sycl::fabs(in);
115+
using deducedT = typename std::remove_cv_t<
116+
std::remove_reference_t<decltype(res_vec)>>::element_type;
117+
if constexpr (std::is_same_v<resT, deducedT>) {
118+
return res_vec;
119+
}
120+
else {
121+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
122+
}
123+
}
124+
}
125+
90126
private:
91127
template <typename realT> realT cabs(std::complex<realT> const &z) const
92128
{

dpctl/tensor/libtensor/include/kernels/elementwise_functions/cos.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
template <typename argT, typename resT> struct CosFunctor
5556
{
@@ -59,7 +60,8 @@ template <typename argT, typename resT> struct CosFunctor
5960
// constant value, if constant
6061
// constexpr resT constant_value = resT{};
6162
// is function defined for sycl::vec
62-
using supports_vec = typename std::false_type;
63+
using supports_vec = typename std::negation<
64+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6365
// do both argTy and resTy support subgroup store/load operation
6466
using supports_sg_loadstore = typename std::negation<
6567
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -165,6 +167,20 @@ template <typename argT, typename resT> struct CosFunctor
165167
return std::cos(in);
166168
}
167169
}
170+
171+
template <int vec_sz>
172+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
173+
{
174+
auto const &res_vec = sycl::cos(in);
175+
using deducedT = typename std::remove_cv_t<
176+
std::remove_reference_t<decltype(res_vec)>>::element_type;
177+
if constexpr (std::is_same_v<resT, deducedT>) {
178+
return res_vec;
179+
}
180+
else {
181+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
182+
}
183+
}
168184
};
169185

170186
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252

5353
using dpctl::tensor::type_utils::is_complex;
54+
using dpctl::tensor::type_utils::vec_cast;
5455

5556
template <typename argT, typename resT> struct Expm1Functor
5657
{
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Expm1Functor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support subgroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -132,6 +134,20 @@ template <typename argT, typename resT> struct Expm1Functor
132134
return std::expm1(in);
133135
}
134136
}
137+
138+
template <int vec_sz>
139+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
140+
{
141+
auto const &res_vec = sycl::expm1(in);
142+
using deducedT = typename std::remove_cv_t<
143+
std::remove_reference_t<decltype(res_vec)>>::element_type;
144+
if constexpr (std::is_same_v<resT, deducedT>) {
145+
return res_vec;
146+
}
147+
else {
148+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
149+
}
150+
}
135151
};
136152

137153
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace py = pybind11;
5151
namespace td_ns = dpctl::tensor::type_dispatch;
5252

5353
using dpctl::tensor::type_utils::is_complex;
54+
using dpctl::tensor::type_utils::vec_cast;
5455

5556
template <typename argT, typename resT> struct LogFunctor
5657
{
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct LogFunctor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support subgroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -79,6 +81,20 @@ template <typename argT, typename resT> struct LogFunctor
7981
return std::log(in);
8082
}
8183
}
84+
85+
template <int vec_sz>
86+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
87+
{
88+
auto const &res_vec = sycl::log(in);
89+
using deducedT = typename std::remove_cv_t<
90+
std::remove_reference_t<decltype(res_vec)>>::element_type;
91+
if constexpr (std::is_same_v<resT, deducedT>) {
92+
return res_vec;
93+
}
94+
else {
95+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
96+
}
97+
}
8298
};
8399

84100
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/log1p.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ namespace py = pybind11;
5050
namespace td_ns = dpctl::tensor::type_dispatch;
5151

5252
using dpctl::tensor::type_utils::is_complex;
53+
using dpctl::tensor::type_utils::vec_cast;
5354

5455
// TODO: evaluate precision against alternatives
5556
template <typename argT, typename resT> struct Log1pFunctor
@@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Log1pFunctor
6061
// constant value, if constant
6162
// constexpr resT constant_value = resT{};
6263
// is function defined for sycl::vec
63-
using supports_vec = typename std::false_type;
64+
using supports_vec = typename std::negation<
65+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6466
// do both argTy and resTy support subgroup store/load operation
6567
using supports_sg_loadstore = typename std::negation<
6668
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -99,6 +101,20 @@ template <typename argT, typename resT> struct Log1pFunctor
99101
return std::log1p(in);
100102
}
101103
}
104+
105+
template <int vec_sz>
106+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
107+
{
108+
auto const &res_vec = sycl::log1p(in);
109+
using deducedT = typename std::remove_cv_t<
110+
std::remove_reference_t<decltype(res_vec)>>::element_type;
111+
if constexpr (std::is_same_v<resT, deducedT>) {
112+
return res_vec;
113+
}
114+
else {
115+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
116+
}
117+
}
102118
};
103119

104120
template <typename argTy,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sqrt.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ namespace py = pybind11;
5353
namespace td_ns = dpctl::tensor::type_dispatch;
5454

5555
using dpctl::tensor::type_utils::is_complex;
56+
using dpctl::tensor::type_utils::vec_cast;
5657

5758
template <typename argT, typename resT> struct SqrtFunctor
5859
{
@@ -62,7 +63,8 @@ template <typename argT, typename resT> struct SqrtFunctor
6263
// constant value, if constant
6364
// constexpr resT constant_value = resT{};
6465
// is function defined for sycl::vec
65-
using supports_vec = typename std::false_type;
66+
using supports_vec = typename std::negation<
67+
std::disjunction<is_complex<resT>, is_complex<argT>>>;
6668
// do both argTy and resTy support subgroup store/load operation
6769
using supports_sg_loadstore = typename std::negation<
6870
std::disjunction<is_complex<resT>, is_complex<argT>>>;
@@ -95,6 +97,20 @@ template <typename argT, typename resT> struct SqrtFunctor
9597
}
9698
}
9799

100+
template <int vec_sz>
101+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
102+
{
103+
auto const &res_vec = sycl::sqrt(in);
104+
using deducedT = typename std::remove_cv_t<
105+
std::remove_reference_t<decltype(res_vec)>>::element_type;
106+
if constexpr (std::is_same_v<resT, deducedT>) {
107+
return res_vec;
108+
}
109+
else {
110+
return vec_cast<resT, deducedT, vec_sz>(res_vec);
111+
}
112+
}
113+
98114
private:
99115
template <typename T> std::complex<T> csqrt(std::complex<T> const &z) const
100116
{

0 commit comments

Comments
 (0)