
Commit

Merge branch 'main' into chao/check
Chao1Han authored Aug 2, 2024
2 parents 8f2dc0a + fb8e6e9 commit 5b2c82e
Showing 24 changed files with 848 additions and 76 deletions.
65 changes: 65 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
@@ -530,6 +530,71 @@ Tensor& XPUNativeFunctions::sigmoid_out(const Tensor& self, Tensor& out) {
return out;
}

Tensor XPUNativeFunctions::sign(const Tensor& self) {
TORCH_CHECK(
!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
Tensor out;
TensorIterator iter;
iter.build_borrowing_unary_op(out, self);
native::xpu::sign_kernel(iter);
return iter.output();
}

Tensor& XPUNativeFunctions::sign_(Tensor& self) {
TORCH_CHECK(
!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
TensorIterator iter;
iter.build_borrowing_unary_op(self, self);
native::xpu::sign_kernel(iter);
return self;
}

Tensor& XPUNativeFunctions::sign_out(const Tensor& self, Tensor& out) {
TORCH_CHECK(
!self.is_complex(),
"Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
TensorIterator iter;
iter.build_borrowing_unary_op(out, self);
native::xpu::sign_kernel(iter);
return out;
}

Tensor XPUNativeFunctions::signbit(const Tensor& self) {
TORCH_CHECK(
!self.is_complex(), "signbit is not implemented for complex tensors.");

Tensor out;
TensorIterator iter;
iter.build_borrowing_unary_force_boolean_op(out, self);

if (self.dtype() == at::kBool) {
iter.output().fill_(false);
} else {
native::xpu::signbit_kernel(iter);
}
return iter.output();
}

Tensor& XPUNativeFunctions::signbit_out(const Tensor& self, Tensor& out) {
TORCH_CHECK(
!self.is_complex(), "signbit is not implemented for complex tensors.");
TORCH_CHECK(
out.dtype() == at::kBool,
"signbit does not support non-boolean outputs.");

TensorIterator iter;
iter.build_borrowing_unary_force_boolean_op(out, self);

if (self.dtype() == at::kBool) {
out.fill_(false);
} else {
native::xpu::signbit_kernel(iter);
}
return out;
}

Tensor& XPUNativeFunctions::logit_out(
const Tensor& self,
std::optional<double> eps,
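The hunk above wires sign, sign_, sign_out, signbit and signbit_out to native XPU kernels. As an illustration only (not part of this commit), a minimal smoke test of the new entry points could look like the sketch below; it assumes an XPU-enabled ATen build with a visible XPU device.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Illustrative only: assumes an XPU-enabled build with a visible XPU device.
  at::Tensor x =
      at::tensor({-2.0f, 0.0f, 3.5f}, at::TensorOptions().device(at::kXPU));
  at::Tensor s = at::sign(x);     // [-1., 0., 1.]
  at::Tensor b = at::signbit(x);  // [true, false, false], dtype kBool
  std::cout << s.cpu() << "\n" << b.cpu() << std::endl;
  // Complex inputs are rejected by the TORCH_CHECKs above, which point users
  // to torch.sgn instead.
  return 0;
}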
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
@@ -245,8 +245,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_scaled_mm",
"segment_reduce",
"_segment_reduce_backward",
"signbit.out",
"sign.out",
"sinc.out",
"special_airy_ai.out",
"special_bessel_j0.out",
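Dropping "sign.out" and "signbit.out" from this list means those ops no longer take the generic CPU fallback path and instead dispatch to the new kernels in UnaryOps.cpp. The actual registration code sits outside this hunk; purely as a rough sketch of the general mechanism (the helper name xpu_fallback_boxed and the trimmed list are hypothetical), a string list like this is typically bound to ATen's boxed CPU fallback along these lines:

#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Hypothetical boxed fallback: copy XPU inputs to CPU, run the CPU kernel,
// and copy the results back.
static void xpu_fallback_boxed(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, XPU, m) {
  static const char* fallback_ops[] = {
      "sinc.out", // still listed after this commit
      // "sign.out" and "signbit.out" are gone: they now hit
      // XPUNativeFunctions::sign_out / signbit_out directly.
  };
  for (const char* name : fallback_ops) {
    m.impl(name, torch::CppFunction::makeFromBoxedFunction<&xpu_fallback_boxed>());
  }
}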
18 changes: 10 additions & 8 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -263,7 +263,8 @@ static inline void group_reduce(
// uint32_t SIMD = sg.get_local_range()[0];
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
val = bin_op(
val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
}
if (sub_group_num == 1) {
if (lane_id == 0) {
@@ -294,7 +295,8 @@ }
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
val = bin_op(
val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
if (i >= ((sub_group_num + 1) >> 1))
break;
}
@@ -450,10 +452,10 @@ struct BatchNormCollectStatisticsKernelFunctor
// one value per subgroup
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
int o_n = sg.shuffle_xor(n, i);
stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
int o_n = sycl::permute_group_by_xor(sg, n, i);
stat_accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n);
var_n += sg.shuffle_xor(var_n, i) +
var_n += sycl::permute_group_by_xor(sg, var_n, i) +
(avg - o_avg) * (avg - o_avg) * n * o_n * factor;
avg = (n * avg + o_n * o_avg) * factor;
n += o_n;
@@ -481,10 +483,10 @@ }
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
int o_n = sg.shuffle_xor(n, i);
stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
int o_n = sycl::permute_group_by_xor(sg, n, i);
stat_accscalar_t factor = 1.0f / fmaxf(1.0f, n + o_n);
var_n += sg.shuffle_xor(var_n, i) +
var_n += sycl::permute_group_by_xor(sg, var_n, i) +
(avg - o_avg) * (avg - o_avg) * n * o_n * factor;
avg = (n * avg + o_n * o_avg) * factor;
n += o_n;
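These hunks replace the deprecated sub_group member calls with their SYCL 2020 free-function equivalents: shuffle_down becomes sycl::shift_group_left and shuffle_xor becomes sycl::permute_group_by_xor, which the statistics kernel uses as an XOR butterfly to merge per-lane count/mean/variance partials. A self-contained sketch of that butterfly pattern, simplified to a plain sum (illustrative only, not code from the kernel; assumes a power-of-two subgroup size):

#include <sycl/sycl.hpp>
#include <iostream>

int main() {
  sycl::queue q;
  constexpr int N = 32;
  float* data = sycl::malloc_shared<float>(N, q);
  for (int i = 0; i < N; ++i) data[i] = float(i);

  q.parallel_for(
       sycl::nd_range<1>(sycl::range<1>(N), sycl::range<1>(N)),
       [=](sycl::nd_item<1> it) {
         auto sg = it.get_sub_group();
         float val = data[it.get_global_id(0)];
         // Exchange with the lane whose id differs in bit `mask`, then combine;
         // after log2(subgroup size) steps every lane holds the subgroup sum.
         for (int mask = 1; mask < (int)sg.get_local_range()[0]; mask <<= 1) {
           val += sycl::permute_group_by_xor(sg, val, mask);
         }
         data[it.get_global_id(0)] = val;
       })
      .wait();

  std::cout << data[0] << std::endl;  // sum over the first subgroup
  sycl::free(data, q);
  return 0;
}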
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/DistanceKernels.cpp
@@ -120,7 +120,7 @@ scalar_t subgroup_reduce_agg_without_broadcast_impl(

#pragma unroll
for (int offset = (SG_SIZE >> 1); offset > 0; offset >>= 1) {
F::agg(value, sg.shuffle_down(value, offset));
F::agg(value, sycl::shift_group_left(sg, value, offset));
}
return value;
}
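The same migration applies to the distance kernels, where the aggregation loop walks the offset down from SG_SIZE/2 so that lane 0 ends up with the aggregate. A standalone version of that updated idiom, simplified to a plain sum and meant to be called from inside a kernel (illustrative names, not the functor-based helper above):

#include <sycl/sycl.hpp>

template <int SG_SIZE>
inline float subgroup_sum_lane0(sycl::sub_group sg, float value) {
  // Halve the offset each step. Lanes that would read past the end of the
  // subgroup receive unspecified data, so only lane 0 is guaranteed to hold
  // the complete sum after the loop.
#pragma unroll
  for (int offset = SG_SIZE >> 1; offset > 0; offset >>= 1) {
    value += sycl::shift_group_left(sg, value, offset);
  }
  return value;
}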
20 changes: 10 additions & 10 deletions src/ATen/native/xpu/sycl/GroupNormKernels.cpp
@@ -3,10 +3,10 @@
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/GroupReduceUtils.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <ATen/native/xpu/sycl/SharedReduceOps.h>
#include <comm/MemoryFormat.h>
#include <comm/XPUMathCompat.h>

@@ -18,23 +18,23 @@ template <
typename index_t,
typename res_t>
struct WelfordOpsXPU
: public at::native::WelfordOps<scalar_t, acc_scalar_t, index_t, res_t> {
: public WelfordOps<scalar_t, acc_scalar_t, index_t, res_t> {
sycl::nd_item<1>& item;

public:
using acc_t = typename at::native::
WelfordOps<scalar_t, acc_scalar_t, index_t, res_t>::acc_t;
using acc_t =
typename WelfordOps<scalar_t, acc_scalar_t, index_t, res_t>::acc_t;
inline acc_t shfl_down(acc_t acc, int offset) const {
auto sg = item.get_sub_group();
return {
sg.shuffle_down(acc.mean, offset),
sg.shuffle_down(acc.m2, offset),
sg.shuffle_down(acc.n, offset),
sg.shuffle_down(acc.nf, offset)};
sycl::shift_group_left(sg, acc.mean, offset),
sycl::shift_group_left(sg, acc.m2, offset),
sycl::shift_group_left(sg, acc.n, offset),
sycl::shift_group_left(sg, acc.nf, offset)};
}

WelfordOpsXPU(acc_scalar_t correction, bool take_sqrt, sycl::nd_item<1>& item)
: at::native::WelfordOps<scalar_t, acc_scalar_t, index_t, res_t>(
: WelfordOps<scalar_t, acc_scalar_t, index_t, res_t>(
correction,
take_sqrt),
item(item) {}
@@ -43,7 +43,7 @@ struct WelfordOpsXPU
template <typename T, int SIMD>
struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
using T_ACC = acc_type_device<T, kXPU>;
using WelfordType = at::native::WelfordData<T_ACC, int64_t>;
using WelfordType = WelfordData<T_ACC, int64_t>;
using WelfordOp =
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;

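WelfordOpsXPU keeps its structure; the change is that each accumulator field (mean, m2, n, nf) now travels through sycl::shift_group_left individually, and the Welford types come from the local sycl/SharedReduceOps.h header rather than ATen/native. A sketch of that field-by-field shuffle plus the usual Welford merge, with a simplified three-field accumulator (illustrative device-side helpers, not the real WelfordData/WelfordOps definitions):

#include <sycl/sycl.hpp>

// Simplified stand-in for the Welford accumulator used above.
struct WelfordAcc {
  float mean = 0.f;
  float m2 = 0.f;
  float n = 0.f;
};

// Field-by-field shuffle, mirroring WelfordOpsXPU::shfl_down, which moves
// acc.mean, acc.m2, acc.n and acc.nf through sycl::shift_group_left one
// member at a time.
inline WelfordAcc shfl_down(sycl::sub_group sg, const WelfordAcc& a, int offset) {
  return {
      sycl::shift_group_left(sg, a.mean, offset),
      sycl::shift_group_left(sg, a.m2, offset),
      sycl::shift_group_left(sg, a.n, offset)};
}

// Chan et al. parallel update: merge two (mean, M2, count) summaries.
inline WelfordAcc combine(const WelfordAcc& a, const WelfordAcc& b) {
  float n = a.n + b.n;
  if (n == 0.f) return a;
  float delta = b.mean - a.mean;
  return {
      a.mean + delta * (b.n / n),
      a.m2 + b.m2 + delta * delta * (a.n * b.n / n),
      n};
}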
12 changes: 8 additions & 4 deletions src/ATen/native/xpu/sycl/Norm.h
@@ -39,8 +39,10 @@ static inline void norm_group_reduce(
// uint32_t SIMD = sg.get_local_range()[0];
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
sum1 = bin_op(sum1, static_cast<accscalar_t>(sg.shuffle_down(sum1, i)));
sum2 = bin_op(sum2, static_cast<accscalar_t>(sg.shuffle_down(sum2, i)));
sum1 = bin_op(
sum1, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum1, i)));
sum2 = bin_op(
sum2, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum2, i)));
}
if (sub_group_num == 1) {
sum1 = sycl::group_broadcast(sg, sum1, 0);
@@ -73,8 +75,10 @@ }
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
sum1 = bin_op(sum1, static_cast<accscalar_t>(sg.shuffle_down(sum1, i)));
sum2 = bin_op(sum2, static_cast<accscalar_t>(sg.shuffle_down(sum2, i)));
sum1 = bin_op(
sum1, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum1, i)));
sum2 = bin_op(
sum2, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum2, i)));
if (i >= ((sub_group_num + 1) >> 1))
break;
}
52 changes: 28 additions & 24 deletions src/ATen/native/xpu/sycl/Reduce.h
@@ -13,6 +13,7 @@
#include <c10/macros/Macros.h>
#include <comm/DeviceProperties.h>
#include <comm/SYCLContext.h>
#include <comm/XPUPair.h>
#include <functional>
#include <iosfwd>
#include <type_traits>
@@ -50,7 +51,7 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
for (int offset = 1; offset < sg_size; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sg.shuffle_down(value[i], offset);
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -71,7 +72,7 @@
for (int offset = 1; offset < sg_range; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sg.shuffle_down(value[i], offset);
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -132,7 +133,7 @@ inline at::detail::Array<arg_t, out_vec_sz> group_x_reduce(
for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sg.shuffle_down(value[i], offset);
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -541,11 +542,11 @@ struct ReduceOp {
(const scalar_t*)((const char*)src + base_offsets1);
value = item_reduce<output_vec_size>(pos, input_slice);
}
// TODO: Currently, there are bugs with shuffle_down when the arg_t is a
// pair for half dtype, We temporarily workaround to do
// TODO: Currently, there are bugs with sycl::shift_group_left when the
// arg_t is a pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
constexpr bool is_pair =
std::is_same<std::pair<scalar_t, int64_t>, arg_t>::value;
std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;

auto combine = [=](arg1_t value, arg2_t other) -> arg1_t {
return ops.combine(value, other);
@@ -832,8 +833,8 @@
return value_list[0];
}

// TODO: Currently, there are bugs with shuffle_down when the arg_t is a
// pair with half dtype, We temporarily workaround to do
// TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
// is a pair with half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
template <int output_vec_size>
at::detail::Array<arg_t, output_vec_size> group_reduce_for_compound_dtype(
@@ -850,7 +851,7 @@
for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) {
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
arg_t other = sg.shuffle_down(value[i], offset);
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = ops.combine(value[i], other);
}
}
@@ -875,12 +876,13 @@
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
// Shuffle down separately for first and second pair.
std::pair<typename arg_t::first_type, typename arg_t::second_type>
other = std::pair<
typename arg_t::first_type,
typename arg_t::second_type>(
sg.shuffle_down(value[i].first, offset),
sg.shuffle_down(value[i].second, offset));
at::xpu::
pair<typename arg_t::first_type, typename arg_t::second_type>
other = at::xpu::pair<
typename arg_t::first_type,
typename arg_t::second_type>(
sycl::shift_group_left(sg, value[i].first, offset),
sycl::shift_group_left(sg, value[i].second, offset));
value[i] = ops.combine(value[i], other);
}
}
@@ -907,8 +909,8 @@
return value;
}

// TODO: Currently, there are bugs with shuffle_down when the arg_t is a
// pair for half dtype, We temporarily workaround to do
// TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
// is a pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
template <int output_vec_size>
at::detail::Array<arg_t, output_vec_size> group_x_reduce_for_compound_dtype(
@@ -947,11 +949,11 @@
for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
std::pair<typename arg_t::first_type, typename arg_t::second_type>
other = std::
at::xpu::pair<typename arg_t::first_type, typename arg_t::second_type>
other = xpu::
pair<typename arg_t::first_type, typename arg_t::second_type>(
sg.shuffle_down(value[i].first, offset),
sg.shuffle_down(value[i].second, offset));
sycl::shift_group_left(sg, value[i].first, offset),
sycl::shift_group_left(sg, value[i].second, offset));
value[i] = ops.combine(value[i], other);
}
}
@@ -1028,7 +1030,8 @@

// Currently implemented for max of two outputs
template <class T1, class T2>
void set_results(const std::pair<T1, T2> x, const index_t base_offset) const {
void set_results(const at::xpu::pair<T1, T2> x, const index_t base_offset)
const {
if (noutputs >= 1) {
auto res0 = (T1*)((char*)dst[0] + base_offset);
*res0 = x.first;
@@ -1121,9 +1124,10 @@
decltype(combine),
output_vec_size>(pos, shared_memory, value, combine);
if (config.should_group_x_reduce()) {
// TODO: workaround because sg.shuffle_down will fail on `half` dtype.
// TODO: workaround because sycl::shift_group_left will fail on `half`
// dtype.
constexpr bool is_pair =
std::is_same<std::pair<scalar_t, int64_t>, arg_t>::value;
std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;
if constexpr (is_pair) {
value = group_x_reduce_for_compound_dtype<output_vec_size>(
pos, value, shared_memory);
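Across Reduce.h the (value, index) accumulators switch from std::pair to at::xpu::pair from the newly included comm/XPUPair.h, and the TODO comments note that pair-typed accumulators of half dtype still need the component-wise shuffle workaround. One plausible motivation for a custom pair in device code, shown below purely as an illustration (simple_pair is a hypothetical stand-in, not the real at::xpu::pair), is trivial copyability, which SYCL group shuffles require of the shuffled type:

#include <cstdint>
#include <iostream>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for at::xpu::pair: a plain aggregate with no
// user-provided special member functions.
template <typename T1, typename T2>
struct simple_pair {
  T1 first;
  T2 second;
};

int main() {
  std::cout << std::boolalpha;
  // On mainstream standard libraries this prints "false": std::pair declares
  // its own operator=, so it is not trivially copyable.
  std::cout << std::is_trivially_copyable<std::pair<float, std::int64_t>>::value
            << "\n";
  // An aggregate pair of trivially copyable members prints "true".
  std::cout << std::is_trivially_copyable<simple_pair<float, std::int64_t>>::value
            << std::endl;
  return 0;
}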
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/ReduceAMinMaxKernel.cpp
@@ -1,9 +1,9 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <ATen/native/xpu/sycl/Reduce.h>
#include <ATen/native/xpu/sycl/SharedReduceOps.h>

namespace at::native::xpu {

@@ -12,7 +12,7 @@ void _min_max_values_kernel_xpu_impl(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, scalar_t>(
iter,
MinMaxOps<scalar_t, scalar_t, int32_t>{},
std::pair<scalar_t, scalar_t>(
at::xpu::pair<scalar_t, scalar_t>(
at::numeric_limits<scalar_t>::upper_bound(),
at::numeric_limits<scalar_t>::lower_bound()));
}
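Apart from the include switching to the local sycl/SharedReduceOps.h, the functional content here is unchanged: the identity handed to gpu_reduce_kernel is still (upper bound, lower bound), just spelled as at::xpu::pair. The ordering is deliberate: the running minimum starts at the type's upper bound and the running maximum at its lower bound, so combining the identity with any element x yields (x, x). A serial illustration with std:: types as stand-ins for at::numeric_limits:

#include <algorithm>
#include <iostream>
#include <limits>
#include <utility>

using MinMax = std::pair<float, float>;  // (running min, running max)

MinMax combine(const MinMax& acc, float x) {
  return {std::min(acc.first, x), std::max(acc.second, x)};
}

int main() {
  // Stand-ins for at::numeric_limits<scalar_t>::upper_bound()/lower_bound().
  MinMax identity{std::numeric_limits<float>::infinity(),
                  -std::numeric_limits<float>::infinity()};
  MinMax r = combine(combine(combine(identity, 2.f), -5.f), 1.f);
  std::cout << "min=" << r.first << " max=" << r.second << std::endl;  // min=-5 max=2
  return 0;
}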
5 changes: 3 additions & 2 deletions src/ATen/native/xpu/sycl/ReduceArgMaxKernel.cpp
@@ -1,9 +1,9 @@
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <ATen/native/xpu/sycl/Reduce.h>
#include <ATen/native/xpu/sycl/SharedReduceOps.h>

namespace at {
namespace native {
@@ -14,7 +14,8 @@ void argmax_kernel_impl(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, int64_t>(
iter,
ArgMaxOps<acc_t>{},
std::pair<acc_t, int64_t>(at::numeric_limits<acc_t>::lower_bound(), 0));
at::xpu::pair<acc_t, int64_t>(
at::numeric_limits<acc_t>::lower_bound(), 0));
};

void argmax_kernel(TensorIterator& iter) {
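Here too, apart from the SharedReduceOps.h include moving to the sycl/ copy, the change is only the pair type of the identity: argmax still starts from the accumulator's lower bound paired with index 0. A serial analogue of what that identity and the combine step do (illustrative only, not the ArgMaxOps implementation; std::numeric_limits stands in for at::numeric_limits):

#include <cstdint>
#include <iostream>
#include <limits>
#include <utility>
#include <vector>

std::pair<float, std::int64_t> argmax_serial(const std::vector<float>& xs) {
  // Identity in the spirit of (at::numeric_limits<acc_t>::lower_bound(), 0).
  std::pair<float, std::int64_t> acc{std::numeric_limits<float>::lowest(), 0};
  for (std::int64_t i = 0; i < (std::int64_t)xs.size(); ++i) {
    if (xs[i] > acc.first) {
      acc = {xs[i], i};  // strictly greater, so ties keep the earliest index
    }
  }
  return acc;
}

int main() {
  auto r = argmax_serial({0.5f, 3.0f, -1.0f, 3.0f});
  std::cout << "max=" << r.first << " at index " << r.second << std::endl;  // max=3 at index 1
  return 0;
}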