diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 10f7f0eec..25acd873a 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -263,7 +263,8 @@ static inline void group_reduce( // uint32_t SIMD = sg.get_local_range()[0]; #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + val = bin_op( + val, static_cast(sycl::shift_group_left(sg, val, i))); } if (sub_group_num == 1) { if (lane_id == 0) { @@ -294,7 +295,8 @@ static inline void group_reduce( } #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + val = bin_op( + val, static_cast(sycl::shift_group_left(sg, val, i))); if (i >= ((sub_group_num + 1) >> 1)) break; } @@ -450,10 +452,10 @@ struct BatchNormCollectStatisticsKernelFunctor // one value per subgroup #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - stat_accscalar_t o_avg = sg.shuffle_xor(avg, i); - int o_n = sg.shuffle_xor(n, i); + stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i); + int o_n = sycl::permute_group_by_xor(sg, n, i); stat_accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n); - var_n += sg.shuffle_xor(var_n, i) + + var_n += sycl::permute_group_by_xor(sg, var_n, i) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; @@ -481,10 +483,10 @@ struct BatchNormCollectStatisticsKernelFunctor } #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - stat_accscalar_t o_avg = sg.shuffle_xor(avg, i); - int o_n = sg.shuffle_xor(n, i); + stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i); + int o_n = sycl::permute_group_by_xor(sg, n, i); stat_accscalar_t factor = 1.0f / fmaxf(1.0f, n + o_n); - var_n += sg.shuffle_xor(var_n, i) + + var_n += sycl::permute_group_by_xor(sg, var_n, i) + (avg - o_avg) * (avg - o_avg) * n * o_n * factor; avg = (n * avg + o_n * o_avg) * factor; n += o_n; diff --git a/src/ATen/native/xpu/sycl/DistanceKernels.cpp b/src/ATen/native/xpu/sycl/DistanceKernels.cpp index 8bd61bdd3..eb0f1f50e 100644 --- a/src/ATen/native/xpu/sycl/DistanceKernels.cpp +++ b/src/ATen/native/xpu/sycl/DistanceKernels.cpp @@ -120,7 +120,7 @@ scalar_t subgroup_reduce_agg_without_broadcast_impl( #pragma unroll for (int offset = (SG_SIZE >> 1); offset > 0; offset >>= 1) { - F::agg(value, sg.shuffle_down(value, offset)); + F::agg(value, sycl::shift_group_left(sg, value, offset)); } return value; } diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp index 622e99ffe..8dafdc8d2 100644 --- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp @@ -3,10 +3,10 @@ #include #include #include -#include #include #include #include +#include #include #include @@ -18,23 +18,23 @@ template < typename index_t, typename res_t> struct WelfordOpsXPU - : public at::native::WelfordOps { + : public WelfordOps { sycl::nd_item<1>& item; public: - using acc_t = typename at::native:: - WelfordOps::acc_t; + using acc_t = + typename WelfordOps::acc_t; inline acc_t shfl_down(acc_t acc, int offset) const { auto sg = item.get_sub_group(); return { - sg.shuffle_down(acc.mean, offset), - sg.shuffle_down(acc.m2, offset), - sg.shuffle_down(acc.n, offset), - sg.shuffle_down(acc.nf, offset)}; + sycl::shift_group_left(sg, acc.mean, offset), + sycl::shift_group_left(sg, acc.m2, offset), + sycl::shift_group_left(sg, acc.n, offset), + sycl::shift_group_left(sg, acc.nf, offset)}; } WelfordOpsXPU(acc_scalar_t correction, bool take_sqrt, sycl::nd_item<1>& item) - : at::native::WelfordOps( + : WelfordOps( correction, take_sqrt), item(item) {} @@ -43,7 +43,7 @@ struct WelfordOpsXPU template struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using T_ACC = acc_type_device; - using WelfordType = at::native::WelfordData; + using WelfordType = WelfordData; using WelfordOp = WelfordOpsXPU>; diff --git a/src/ATen/native/xpu/sycl/Norm.h b/src/ATen/native/xpu/sycl/Norm.h index 9aee941cb..36d1282a3 100644 --- a/src/ATen/native/xpu/sycl/Norm.h +++ b/src/ATen/native/xpu/sycl/Norm.h @@ -39,8 +39,10 @@ static inline void norm_group_reduce( // uint32_t SIMD = sg.get_local_range()[0]; #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - sum1 = bin_op(sum1, static_cast(sg.shuffle_down(sum1, i))); - sum2 = bin_op(sum2, static_cast(sg.shuffle_down(sum2, i))); + sum1 = bin_op( + sum1, static_cast(sycl::shift_group_left(sg, sum1, i))); + sum2 = bin_op( + sum2, static_cast(sycl::shift_group_left(sg, sum2, i))); } if (sub_group_num == 1) { sum1 = sycl::group_broadcast(sg, sum1, 0); @@ -73,8 +75,10 @@ static inline void norm_group_reduce( } #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - sum1 = bin_op(sum1, static_cast(sg.shuffle_down(sum1, i))); - sum2 = bin_op(sum2, static_cast(sg.shuffle_down(sum2, i))); + sum1 = bin_op( + sum1, static_cast(sycl::shift_group_left(sg, sum1, i))); + sum2 = bin_op( + sum2, static_cast(sycl::shift_group_left(sg, sum2, i))); if (i >= ((sub_group_num + 1) >> 1)) break; } diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h index 1be3a5e93..276b21175 100644 --- a/src/ATen/native/xpu/sycl/Reduce.h +++ b/src/ATen/native/xpu/sycl/Reduce.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -50,7 +51,7 @@ inline at::detail::Array group_reduce( for (int offset = 1; offset < sg_size; offset <<= 1) { #pragma unroll(out_vec_sz) for (int i = 0; i < out_vec_sz; ++i) { - arg_t other = sg.shuffle_down(value[i], offset); + arg_t other = sycl::shift_group_left(sg, value[i], offset); value[i] = combine(value[i], other); } } @@ -71,7 +72,7 @@ inline at::detail::Array group_reduce( for (int offset = 1; offset < sg_range; offset <<= 1) { #pragma unroll(out_vec_sz) for (int i = 0; i < out_vec_sz; ++i) { - arg_t other = sg.shuffle_down(value[i], offset); + arg_t other = sycl::shift_group_left(sg, value[i], offset); value[i] = combine(value[i], other); } } @@ -132,7 +133,7 @@ inline at::detail::Array group_x_reduce( for (int offset = 1; offset < dim_x; offset <<= 1) { #pragma unroll(out_vec_sz) for (int i = 0; i < out_vec_sz; ++i) { - arg_t other = sg.shuffle_down(value[i], offset); + arg_t other = sycl::shift_group_left(sg, value[i], offset); value[i] = combine(value[i], other); } } @@ -541,11 +542,11 @@ struct ReduceOp { (const scalar_t*)((const char*)src + base_offsets1); value = item_reduce(pos, input_slice); } - // TODO: Currently, there are bugs with shuffle_down when the arg_t is a - // pair for half dtype, We temporarily workaround to do + // TODO: Currently, there are bugs with sycl::shift_group_left when the + // arg_t is a pair for half dtype, We temporarily workaround to do // "reduce_for_compound_dtype" function. constexpr bool is_pair = - std::is_same, arg_t>::value; + std::is_same, arg_t>::value; auto combine = [=](arg1_t value, arg2_t other) -> arg1_t { return ops.combine(value, other); @@ -832,8 +833,8 @@ struct ReduceOp { return value_list[0]; } - // TODO: Currently, there are bugs with shuffle_down when the arg_t is a - // pair with half dtype, We temporarily workaround to do + // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t + // is a pair with half dtype, We temporarily workaround to do // "reduce_for_compound_dtype" function. template at::detail::Array group_reduce_for_compound_dtype( @@ -850,7 +851,7 @@ struct ReduceOp { for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) { #pragma unroll(output_vec_size) for (int i = 0; i < output_vec_size; ++i) { - arg_t other = sg.shuffle_down(value[i], offset); + arg_t other = sycl::shift_group_left(sg, value[i], offset); value[i] = ops.combine(value[i], other); } } @@ -875,12 +876,13 @@ struct ReduceOp { #pragma unroll(output_vec_size) for (int i = 0; i < output_vec_size; ++i) { // Shuffle down separately for first and second pair. - std::pair - other = std::pair< - typename arg_t::first_type, - typename arg_t::second_type>( - sg.shuffle_down(value[i].first, offset), - sg.shuffle_down(value[i].second, offset)); + at::xpu:: + pair + other = at::xpu::pair< + typename arg_t::first_type, + typename arg_t::second_type>( + sycl::shift_group_left(sg, value[i].first, offset), + sycl::shift_group_left(sg, value[i].second, offset)); value[i] = ops.combine(value[i], other); } } @@ -907,8 +909,8 @@ struct ReduceOp { return value; } - // TODO: Currently, there are bugs with shuffle_down when the arg_t is a - // pair for half dtype, We temporarily workaround to do + // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t + // is a pair for half dtype, We temporarily workaround to do // "reduce_for_compound_dtype" function. template at::detail::Array group_x_reduce_for_compound_dtype( @@ -947,11 +949,11 @@ struct ReduceOp { for (int offset = 1; offset < dim_x; offset <<= 1) { #pragma unroll(output_vec_size) for (int i = 0; i < output_vec_size; ++i) { - std::pair - other = std:: + at::xpu::pair + other = xpu:: pair( - sg.shuffle_down(value[i].first, offset), - sg.shuffle_down(value[i].second, offset)); + sycl::shift_group_left(sg, value[i].first, offset), + sycl::shift_group_left(sg, value[i].second, offset)); value[i] = ops.combine(value[i], other); } } @@ -1028,7 +1030,8 @@ struct ReduceOp { // Currently implemented for max of two outputs template - void set_results(const std::pair x, const index_t base_offset) const { + void set_results(const at::xpu::pair x, const index_t base_offset) + const { if (noutputs >= 1) { auto res0 = (T1*)((char*)dst[0] + base_offset); *res0 = x.first; @@ -1121,9 +1124,10 @@ struct ReduceOp { decltype(combine), output_vec_size>(pos, shared_memory, value, combine); if (config.should_group_x_reduce()) { - // TODO: workaround because sg.shuffle_down will fail on `half` dtype. + // TODO: workaround because sycl::shift_group_left will fail on `half` + // dtype. constexpr bool is_pair = - std::is_same, arg_t>::value; + std::is_same, arg_t>::value; if constexpr (is_pair) { value = group_x_reduce_for_compound_dtype( pos, value, shared_memory); diff --git a/src/ATen/native/xpu/sycl/ReduceAMinMaxKernel.cpp b/src/ATen/native/xpu/sycl/ReduceAMinMaxKernel.cpp index 04c847a4d..121761053 100644 --- a/src/ATen/native/xpu/sycl/ReduceAMinMaxKernel.cpp +++ b/src/ATen/native/xpu/sycl/ReduceAMinMaxKernel.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include namespace at::native::xpu { @@ -12,7 +12,7 @@ void _min_max_values_kernel_xpu_impl(TensorIterator& iter) { gpu_reduce_kernel( iter, MinMaxOps{}, - std::pair( + at::xpu::pair( at::numeric_limits::upper_bound(), at::numeric_limits::lower_bound())); } diff --git a/src/ATen/native/xpu/sycl/ReduceArgMaxKernel.cpp b/src/ATen/native/xpu/sycl/ReduceArgMaxKernel.cpp index c41ad0c04..1e18b2e5b 100644 --- a/src/ATen/native/xpu/sycl/ReduceArgMaxKernel.cpp +++ b/src/ATen/native/xpu/sycl/ReduceArgMaxKernel.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include namespace at { namespace native { @@ -14,7 +14,8 @@ void argmax_kernel_impl(TensorIterator& iter) { gpu_reduce_kernel( iter, ArgMaxOps{}, - std::pair(at::numeric_limits::lower_bound(), 0)); + at::xpu::pair( + at::numeric_limits::lower_bound(), 0)); }; void argmax_kernel(TensorIterator& iter) { diff --git a/src/ATen/native/xpu/sycl/ReduceArgMinKernel.cpp b/src/ATen/native/xpu/sycl/ReduceArgMinKernel.cpp index 2f6c38152..3c9f8453d 100644 --- a/src/ATen/native/xpu/sycl/ReduceArgMinKernel.cpp +++ b/src/ATen/native/xpu/sycl/ReduceArgMinKernel.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include namespace at::native::xpu { @@ -12,7 +12,8 @@ void argmin_kernel_impl(TensorIterator& iter) { gpu_reduce_kernel( iter, ArgMinOps{}, - std::pair(at::numeric_limits::upper_bound(), 0)); + at::xpu::pair( + at::numeric_limits::upper_bound(), 0)); }; void argmin_kernel(TensorIterator& iter) { diff --git a/src/ATen/native/xpu/sycl/ReduceMaxValuesKernels.cpp b/src/ATen/native/xpu/sycl/ReduceMaxValuesKernels.cpp index 0dfa7fd52..16095056c 100644 --- a/src/ATen/native/xpu/sycl/ReduceMaxValuesKernels.cpp +++ b/src/ATen/native/xpu/sycl/ReduceMaxValuesKernels.cpp @@ -1,7 +1,7 @@ #include #include -#include #include +#include #include #include @@ -38,7 +38,7 @@ void max_kernel(TensorIterator& iter) { gpu_reduce_kernel( iter, MaxOps{}, - std::pair( + at::xpu::pair( at::numeric_limits::lower_bound(), 0)); }); } diff --git a/src/ATen/native/xpu/sycl/ReduceMinValuesKernels.cpp b/src/ATen/native/xpu/sycl/ReduceMinValuesKernels.cpp index fff7207c6..2a0ce889c 100644 --- a/src/ATen/native/xpu/sycl/ReduceMinValuesKernels.cpp +++ b/src/ATen/native/xpu/sycl/ReduceMinValuesKernels.cpp @@ -1,6 +1,6 @@ #include -#include #include +#include #include #include @@ -37,7 +37,7 @@ void min_kernel(TensorIterator& iter) { gpu_reduce_kernel( iter, MinOps{}, - std::pair( + at::xpu::pair( at::numeric_limits::upper_bound(), 0)); }); } diff --git a/src/ATen/native/xpu/sycl/ReduceMomentKernels.cpp b/src/ATen/native/xpu/sycl/ReduceMomentKernels.cpp index 54ae01273..6d0e75680 100644 --- a/src/ATen/native/xpu/sycl/ReduceMomentKernels.cpp +++ b/src/ATen/native/xpu/sycl/ReduceMomentKernels.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include namespace at { namespace native { @@ -18,7 +18,7 @@ void std_var_template( // This is necessary to lower register usage that leads to register spills. using accscalar_t = at::acc_type_device; using ops_t = - WelfordOps>; + WelfordOps>; ops_t ops(static_cast(correction_opt), take_sqrt); gpu_reduce_kernel(iter, ops, typename ops_t::acc_t{}); } diff --git a/src/ATen/native/xpu/sycl/ReduceNormKernel.cpp b/src/ATen/native/xpu/sycl/ReduceNormKernel.cpp index 4aac0cceb..658f2e21b 100644 --- a/src/ATen/native/xpu/sycl/ReduceNormKernel.cpp +++ b/src/ATen/native/xpu/sycl/ReduceNormKernel.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include diff --git a/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h index 35f1d54a5..8729f31dd 100644 --- a/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h +++ b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h @@ -6,7 +6,7 @@ template inline T GroupReduceSumSGSizeEqualstoNumSG(sg_t& sg, T val) { auto sg_size = sg.get_local_range()[0]; for (int offset = (sg_size >> 1); offset > 0; offset >>= 1) { - val += sg.shuffle_down(val, offset); + val += sycl::shift_group_left(sg, val, offset); } return val; } diff --git a/src/ATen/native/xpu/sycl/SharedReduceOps.h b/src/ATen/native/xpu/sycl/SharedReduceOps.h new file mode 100644 index 000000000..0e63cc6ed --- /dev/null +++ b/src/ATen/native/xpu/sycl/SharedReduceOps.h @@ -0,0 +1,416 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#define MAX(X, Y) max_impl(X, Y) +#define MIN(X, Y) min_impl(X, Y) + +#define device_sqrt std::sqrt +#define compat_pow std::pow + +namespace at { +namespace native { +namespace xpu { + +template +struct WelfordData { + scalar_t mean; + scalar_t m2; + index_t n; + scalar_t nf; + + WelfordData() : mean(0), m2(0), n(0), nf(0) {} + + WelfordData(scalar_t mean, scalar_t m2, index_t n, scalar_t nf) + : mean(mean), m2(m2), n(n), nf(nf) {} +}; + +template < + typename scalar_t, + typename acc_scalar_t, + typename index_t, + typename res_t> +struct WelfordOps { + acc_scalar_t correction; + bool take_sqrt; + + public: + using acc_t = WelfordData; + inline acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + // We accumulate n in index_t to avoid cumulative rounding error, but still + // need nf for use in combine where int32 may overflow. + index_t new_n = acc.n + 1; + acc_scalar_t new_nf = static_cast(new_n); + acc_scalar_t delta = data - acc.mean; + acc_scalar_t new_mean = acc.mean + delta / new_nf; + acc_scalar_t new_delta = data - new_mean; + return { + new_mean, + acc.m2 + delta * new_delta, + new_n, + new_nf, + }; + } + inline acc_t combine(acc_t a, acc_t b) const { + if (a.nf == 0) { + return b; + } + if (b.nf == 0) { + return a; + } + acc_scalar_t delta = b.mean - a.mean; + acc_scalar_t new_count = a.nf + b.nf; + acc_scalar_t nb_over_n = b.nf / new_count; + return { + a.mean + delta * nb_over_n, + a.m2 + b.m2 + delta * delta * a.nf * nb_over_n, + // setting acc.n as -1 since acc.n might not be able to represent the + // count correctly within its range, setting it to -1 to avoid confusion + -1, + new_count}; + } + inline res_t project(acc_t acc) const __ubsan_ignore_float_divide_by_zero__ { + const auto mean = static_cast(acc.mean); + const auto divisor = acc.nf > correction ? acc.nf - correction : 0; + const auto var = acc.m2 / divisor; + res_t results(take_sqrt ? device_sqrt(var) : var, mean); + return results; + } + + static C10_DEVICE acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + + WelfordOps(acc_scalar_t correction, bool take_sqrt) + : correction(correction), take_sqrt(take_sqrt) {} +}; + +template < + typename scalar_t, + typename acc_t = scalar_t, + typename factor_t = acc_t, + typename out_t = acc_t> +struct MeanOps { + factor_t factor; + + inline acc_t reduce(acc_t a, scalar_t b, int64_t /*idx*/) const { + return combine(a, static_cast(b)); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline out_t project(acc_t a) const { + return a * factor; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + + MeanOps(factor_t factor) : factor(factor) {} +}; + +// This accumulator template is used to calculate the minimum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct AbsMinOps { + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MIN(acc, static_cast(std::abs(data))); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return MIN(a, b); + } + + inline out_t project(acc_t a) const { + return a; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +// This accumulator template is used to calculate the maximum absolute value of +// a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct AbsMaxOps { + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return MAX(acc, static_cast(std::abs(data))); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return MAX(a, b); + } + + inline out_t project(acc_t a) const { + return a; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +// This accumulator template is used to calculate the norm of the absolute value +// of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct NormOps { + acc_t norm_; + + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + compat_pow(static_cast(std::abs(data)), norm_); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline out_t project(acc_t a) const { + return compat_pow(a, static_cast(1.0) / norm_); + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } + + NormOps(acc_t norm_) : norm_(norm_) {} +}; + +// This accumulator template is used to calculate the order zero norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct NormZeroOps { + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + + (data == static_cast(0) ? static_cast(0) + : static_cast(1)); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline out_t project(acc_t a) const { + return a; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +// This accumulator template is used to calculate the order one norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct NormOneOps { + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + return acc + static_cast(std::abs(data)); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline out_t project(acc_t a) const { + return a; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +template +struct AbsSwitch {}; + +template +inline acc_t abs_if_complex(scalar_t data, AbsSwitch) { + return static_cast(data); +} + +template +inline acc_t abs_if_complex(std::complex data, AbsSwitch) { + return static_cast(std::abs(data)); +} + +template +inline acc_t abs_if_complex(c10::complex data, AbsSwitch) { + return static_cast(std::abs(data)); +} + +// This accumulator template is used to calculate the order two norm of the +// absolute value of a set of numbers. +// `scalar_t` is the type of the input and `acc_t` is the type of the +// accumulated value. These types differ for complex number input support. +template +struct NormTwoOps { + inline acc_t reduce(acc_t acc, scalar_t data, int64_t /*idx*/) const { + acc_t data_ = abs_if_complex(data, AbsSwitch()); + return acc + data_ * data_; + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline out_t project(acc_t a) const { + return device_sqrt(a); + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +template +struct NanSumOps { + inline acc_t reduce(acc_t a, data_t b, int64_t /*idx*/) const { + return a + (at::_isnan(b) ? acc_t{0.} : acc_t{b}); + } + + inline acc_t combine(acc_t a, acc_t b) const { + return a + b; + } + + inline data_t project(acc_t a) const { + return data_t{a}; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +namespace detail { + +template +struct LessOrNan { + bool operator()(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else min(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a < b); + } +}; + +template +struct GreaterOrNan { + bool operator()(scalar_t a, scalar_t b, int64_t idx_a, int64_t idx_b) const { + // If (a == b), then choose the one with lower idx, else max(a, b) + if (at::_isnan(a)) { + if (at::_isnan(b)) { + return idx_a < idx_b; + } + return true; + } + return (a == b) ? idx_a < idx_b : (a > b); + } +}; + +template +struct MinMaxReductionOps { + using scalar_t = typename binary_function_traits::arg1_t; + using index_t = int64_t; + using arg_t = at::xpu::pair; + + static arg_t project(arg_t arg) { + return arg; + } + + static arg_t reduce(arg_t arg, scalar_t val, int64_t idx) { + return comp_t{}(arg.first, val, arg.second, idx) ? arg : arg_t(val, idx); + } + + static arg_t combine(arg_t a, arg_t b) { + return comp_t{}(a.first, b.first, a.second, b.second) ? a : b; + } + + static arg_t translate_idx(arg_t a, int64_t base_idx) { + return {a.first, a.second + base_idx}; + } +}; + +template +struct ArgReductionOps : public MinMaxReductionOps { + using typename MinMaxReductionOps::scalar_t; + using typename MinMaxReductionOps::index_t; + using typename MinMaxReductionOps::arg_t; + + static index_t project(arg_t arg) { + return arg.second; + } +}; + +} // namespace detail + +template +struct ArgMaxOps + : public detail::ArgReductionOps> {}; + +template +struct ArgMinOps : public detail::ArgReductionOps> { +}; + +template +struct MinOps : public detail::MinMaxReductionOps> { +}; + +template +struct MaxOps + : public detail::MinMaxReductionOps> {}; + +template +struct MinMaxOps { + using acc_t = at::xpu::pair; + inline acc_t reduce(acc_t acc, scalar_t data, index_t /*idx*/) const { + return combine(acc, {data, data}); + } + + inline acc_t combine(acc_t a, acc_t b) const { + auto min_val = + (at::_isnan(a.first) || a.first < b.first) ? a.first : b.first; + auto max_val = + (at::_isnan(a.second) || a.second > b.second) ? a.second : b.second; + + return {min_val, max_val}; + } + + inline acc_t project(acc_t acc) const { + return acc; + } + + static acc_t translate_idx(acc_t acc, int64_t /*base_idx*/) { + return acc; + } +}; + +} // namespace xpu +} // namespace native +} // namespace at + +#undef MAX +#undef MIN diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp index 8db72165a..78062e21a 100644 --- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp +++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp @@ -40,7 +40,8 @@ static inline void softmax_group_reduce( // uint32_t SIMD = sg.get_local_range()[0]; #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + val = bin_op( + val, static_cast(sycl::shift_group_left(sg, val, i))); } if (sub_group_num == 1) { val = sycl::group_broadcast(sg, val, 0); @@ -68,7 +69,8 @@ static inline void softmax_group_reduce( } #pragma unroll for (int i = 1; i < SIMD; i <<= 1) { - val = bin_op(val, static_cast(sg.shuffle_down(val, i))); + val = bin_op( + val, static_cast(sycl::shift_group_left(sg, val, i))); if (i >= ((sub_group_num + 1) >> 1)) break; } diff --git a/src/comm/XPUPair.h b/src/comm/XPUPair.h new file mode 100644 index 000000000..0c36f8625 --- /dev/null +++ b/src/comm/XPUPair.h @@ -0,0 +1,78 @@ +#pragma once + +#include + +namespace at::xpu { + +template +struct pair { + typedef T1 first_type; + typedef T2 second_type; + + first_type first; + second_type second; + + // default constructor + pair(void) : first(), second() {} + + inline pair(const T1& x, const T2& y) : first(x), second(y) {} + + template + inline pair(const pair& p) : first(p.first), second(p.second) {} + + template + pair(const std::pair& p) : first(p.first), second(p.second) {} +}; + +template +bool operator==(const pair& x, const pair& y) { + return x.first == y.first && x.second == y.second; +} + +template +inline bool operator<(const pair& x, const pair& y) { + return x.first < y.first || (!(y.first < x.first) && x.second < y.second); +} + +template +inline bool operator!=(const pair& x, const pair& y) { + return !(x == y); +} + +template +inline bool operator>(const pair& x, const pair& y) { + return y < x; +} + +template +bool operator<=(const pair& x, const pair& y) { + return !(y < x); +} + +template +bool operator>=(const pair& x, const pair& y) { + return !(x < y); +} + +template +inline pair make_pair(T1 x, T2 y) { + return pair(x, y); +} + +template +inline auto& get(pair& p) { + if constexpr (N == 0) + return p.first; + else + return p.second; +} + +template +inline const auto& get(const pair& p) { + if constexpr (N == 0) + return p.first; + else + return p.second; +} + +} // namespace at::xpu