Skip to content

Commit

Permalink
add nansum
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel committed Jul 11, 2024
1 parent 0253fb9 commit 1c79c13
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 14 deletions.
34 changes: 34 additions & 0 deletions src/ATen/native/xpu/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,4 +803,38 @@ Tensor XPUNativeFunctions::amin(
return out;
}

Tensor& XPUNativeFunctions::nansum_out(
const Tensor& self,
at::OptionalIntArrayRef dim,
bool keepdim,
optional<ScalarType> opt_dtype,
Tensor& result) {
// For integral types, use existing sum as
// integral types don't have `Nan`.
if (c10::isIntegralType(self.scalar_type(), true)) {
return at::sum_out(result, self, dim, keepdim, opt_dtype);
}

auto out_dtype = infer_dtype_from_optional(self, opt_dtype, result);
result = resize_reduction(result, self, dim, keepdim, out_dtype);
auto iter = meta::make_reduction_from_out_ty(
self, result, dim, keepdim, result.scalar_type());

if (iter.numel() == 0) {
result = result.zero_();
} else {
native::xpu::nansum_kernel(iter);
}
return result;
}

Tensor XPUNativeFunctions::nansum(
const Tensor& self,
at::OptionalIntArrayRef dim,
bool keepdim,
std::optional<ScalarType> opt_dtype) {
Tensor result;
return XPUNativeFunctions::nansum_out(self, dim, keepdim, opt_dtype, result);
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"multinomial",
"nanmedian",
"nanmedian.dim_values",
"nansum",
"nan_to_num.out",
"nextafter.out",
"norm.out",
Expand Down
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/ReduceOpsKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ void mean_kernel(TensorIterator& iter);

void sum_kernel(TensorIterator& iter);

void nansum_kernel(TensorIterator& iter);

void std_var_kernel(TensorIterator& iter, double correction, bool take_sqrt);

} // namespace at::native::xpu
84 changes: 71 additions & 13 deletions src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <ATen/native/xpu/sycl/Reduce.h>
Expand All @@ -8,6 +9,36 @@ namespace at {
namespace native {
namespace xpu {

// The function `reduce_dispatch` below dispatches to the kernel based
// on the type of `iter`. It takes care of the common logic
// for handling Half-Precision floating types.
// Otherwise the functor `op` is called to dispatch to the kernel
// of relevant type.
//
// Note: Functor `op` should take care of all the types to be supported
// except for `at::Half` and `at::BFloat16`.
template <
template <
typename scalar_t,
typename acc_t = scalar_t,
typename out_t = scalar_t>
typename OpFunctor,
typename GeneralDispatcher>
static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
if (iter.dtype() == kHalf) {
return OpFunctor<at::Half, float>{}(iter);
} else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
// type promotion that does cast and reduction in a single kernel
return OpFunctor<at::Half, float, float>{}(iter);
} else if (iter.dtype() == kBFloat16) {
return OpFunctor<at::BFloat16, float>{}(iter);
} else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
// type promotion that does cast and reduction in a single kernel
return OpFunctor<at::BFloat16, float, float>{}(iter);
}
op(iter);
}

template <typename acc_t>
struct SumFunctor {
inline acc_t operator()(acc_t a, acc_t b) const {
Expand Down Expand Up @@ -36,22 +67,49 @@ struct sum_functor {
};

void sum_kernel(TensorIterator& iter) {
if (iter.dtype() == kHalf) {
return sum_functor<at::Half, float>{}(iter);
} else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
// type promotion that does cast and reduction in a single kernel
return sum_functor<at::Half, float, float>{}(iter);
} else if (iter.dtype() == kBFloat16) {
return sum_functor<at::BFloat16, float>{}(iter);
} else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
// type promotion that does cast and reduction in a single kernel
return sum_functor<at::BFloat16, float, float>{}(iter);
auto general_dispatcher = [](TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kBool, kComplexHalf, iter.dtype(), "sum_xpu", [&]() {
sum_functor<scalar_t>{}(iter);
});
};
reduce_dispatch<sum_functor>(iter, general_dispatcher);
}

template <
typename scalar_t,
typename acc_t = scalar_t,
typename out_t = scalar_t>
struct nansum_functor {
void operator()(TensorIterator& iter) {
gpu_reduce_kernel<scalar_t, out_t>(
iter, at::native::NanSumOps<acc_t, out_t>{});
}
};

AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
kBool, kComplexHalf, iter.dtype(), "sum_xpu", [&]() {
sum_functor<scalar_t>{}(iter);
template <typename scalar_t>
struct nansum_functor_complex {
void operator()(TensorIterator& iter) {
using acc_t = at::opmath_type<scalar_t>;
gpu_reduce_kernel<scalar_t, acc_t>(
iter, at::native::NanSumOps<acc_t, acc_t>{});
}
};

void nansum_kernel(TensorIterator& iter) {
auto general_dispatcher = [](TensorIterator& iter) {
auto dtype = iter.dtype();
if (at::isComplexType(dtype)) {
AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "nansum_xpu", [&]() {
nansum_functor_complex<scalar_t>{}(iter);
});
} else {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nansum_xpu", [&]() {
nansum_functor<scalar_t>{}(iter);
});
}
};
reduce_dispatch<nansum_functor>(iter, general_dispatcher);
}

} // namespace xpu
Expand Down
1 change: 1 addition & 0 deletions test/xpu/extended/run_test_with_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"test_compare_cpu_div_trunc_rounding_xpu_float16",
"test_compare_cpu_div_trunc_rounding_xpu_bfloat16",
"test_compare_cpu_addr_xpu_float16",
"test_compare_cpu_nansum_xpu_bfloat16",

# CUDA does not support the data type either
"test_compare_cpu_native_dropout_backward_xpu_bool",
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"sin",
"sqrt",
"sum",
"nansum",
"amin",
"amax",
"std",
Expand Down
2 changes: 2 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ supported:
- min.dim_min
- sum.dim_IntList
- sum.IntList_out
- nansum
- nansum.out
- mean.out
- mean.dim
- std.correction
Expand Down

0 comments on commit 1c79c13

Please sign in to comment.