diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index f9b512723..21e9359e3 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -803,4 +803,38 @@ Tensor XPUNativeFunctions::amin(
   return out;
 }
 
+Tensor& XPUNativeFunctions::nansum_out(
+    const Tensor& self,
+    at::OptionalIntArrayRef dim,
+    bool keepdim,
+    std::optional<ScalarType> opt_dtype,
+    Tensor& result) {
+  // For integral types, use the existing sum, as
+  // integral types cannot hold `NaN`.
+  if (c10::isIntegralType(self.scalar_type(), true)) {
+    return at::sum_out(result, self, dim, keepdim, opt_dtype);
+  }
+
+  auto out_dtype = infer_dtype_from_optional(self, opt_dtype, result);
+  result = resize_reduction(result, self, dim, keepdim, out_dtype);
+  auto iter = meta::make_reduction_from_out_ty(
+      self, result, dim, keepdim, result.scalar_type());
+
+  if (iter.numel() == 0) {
+    result = result.zero_();
+  } else {
+    native::xpu::nansum_kernel(iter);
+  }
+  return result;
+}
+
+Tensor XPUNativeFunctions::nansum(
+    const Tensor& self,
+    at::OptionalIntArrayRef dim,
+    bool keepdim,
+    std::optional<ScalarType> opt_dtype) {
+  Tensor result;
+  return XPUNativeFunctions::nansum_out(self, dim, keepdim, opt_dtype, result);
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 7bfdd6abd..1db72bb11 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -285,7 +285,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "multinomial",
       "nanmedian",
       "nanmedian.dim_values",
-      "nansum",
       "nan_to_num.out",
       "nextafter.out",
       "norm.out",
diff --git a/src/ATen/native/xpu/sycl/ReduceOpsKernels.h b/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
index 955b055e9..432b8379a 100644
--- a/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
@@ -14,6 +14,8 @@ void mean_kernel(TensorIterator& iter);
 
 void sum_kernel(TensorIterator& iter);
 
+void nansum_kernel(TensorIterator& iter);
+
 void std_var_kernel(TensorIterator& iter, double correction, bool take_sqrt);
 
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
index 728a75582..09db67b6b 100644
--- a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
@@ -1,5 +1,6 @@
 #include <ATen/Dispatch.h>
 #include <ATen/NumericUtils.h>
+#include <ATen/OpMathType.h>
 #include <ATen/native/SharedReduceOps.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/xpu/sycl/Reduce.h>
@@ -8,6 +9,36 @@ namespace at {
 namespace native {
 namespace xpu {
 
+// The function `reduce_dispatch` below dispatches to the kernel based
+// on the dtype of `iter`. It takes care of the common logic for
+// handling the half-precision floating-point types; otherwise, the
+// functor `op` is called to dispatch to the kernel of the relevant
+// type.
+//
+// Note: functor `op` should take care of all the types to be supported
+// except for `at::Half` and `at::BFloat16`.
+template <
+    template <
+        typename scalar_t,
+        typename acc_t = scalar_t,
+        typename out_t = scalar_t>
+    typename OpFunctor,
+    typename GeneralDispatcher>
+static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
+  if (iter.dtype() == kHalf) {
+    return OpFunctor<at::Half, float>{}(iter);
+  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
+    // type promotion that does cast and reduction in a single kernel
+    return OpFunctor<at::Half, float, float>{}(iter);
+  } else if (iter.dtype() == kBFloat16) {
+    return OpFunctor<at::BFloat16, float>{}(iter);
+  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
+    // type promotion that does cast and reduction in a single kernel
+    return OpFunctor<at::BFloat16, float, float>{}(iter);
+  }
+  op(iter);
+}
+
 template <typename acc_t>
 struct SumFunctor {
   inline acc_t operator()(acc_t a, acc_t b) const {
     return a + b;
   }
 };
@@ -36,22 +67,49 @@ struct sum_functor {
 };
 
 void sum_kernel(TensorIterator& iter) {
-  if (iter.dtype() == kHalf) {
-    return sum_functor<at::Half, float>{}(iter);
-  } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return sum_functor<at::Half, float, float>{}(iter);
-  } else if (iter.dtype() == kBFloat16) {
-    return sum_functor<at::BFloat16, float>{}(iter);
-  } else if (iter.dtype(1) == kBFloat16 && iter.dtype() == kFloat) {
-    // type promotion that does cast and reduction in a single kernel
-    return sum_functor<at::BFloat16, float, float>{}(iter);
-  }
-
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      kBool, kComplexHalf, iter.dtype(), "sum_xpu", [&]() {
-        sum_functor<scalar_t>{}(iter);
-      });
+  auto general_dispatcher = [](TensorIterator& iter) {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+        kBool, kComplexHalf, iter.dtype(), "sum_xpu", [&]() {
+          sum_functor<scalar_t>{}(iter);
+        });
+  };
+  reduce_dispatch<sum_functor>(iter, general_dispatcher);
+}
+
+template <
+    typename scalar_t,
+    typename acc_t = scalar_t,
+    typename out_t = scalar_t>
+struct nansum_functor {
+  void operator()(TensorIterator& iter) {
+    gpu_reduce_kernel<scalar_t, out_t>(
+        iter, at::native::NanSumOps<acc_t, out_t>{});
+  }
+};
+
+template <typename scalar_t>
+struct nansum_functor_complex {
+  void operator()(TensorIterator& iter) {
+    using acc_t = at::opmath_type<scalar_t>;
+    gpu_reduce_kernel<scalar_t, scalar_t>(
+        iter, at::native::NanSumOps<acc_t, scalar_t>{});
+  }
+};
+
+void nansum_kernel(TensorIterator& iter) {
+  auto general_dispatcher = [](TensorIterator& iter) {
+    auto dtype = iter.dtype();
+    if (at::isComplexType(dtype)) {
+      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "nansum_xpu", [&]() {
+        nansum_functor_complex<scalar_t>{}(iter);
+      });
+    } else {
+      AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "nansum_xpu", [&]() {
+        nansum_functor<scalar_t>{}(iter);
+      });
+    }
+  };
+  reduce_dispatch<nansum_functor>(iter, general_dispatcher);
 }
 
 } // namespace xpu
diff --git a/test/xpu/extended/run_test_with_skip.py b/test/xpu/extended/run_test_with_skip.py
index 943d46465..0498f070f 100644
--- a/test/xpu/extended/run_test_with_skip.py
+++ b/test/xpu/extended/run_test_with_skip.py
@@ -31,6 +31,7 @@
     "test_compare_cpu_div_trunc_rounding_xpu_float16",
     "test_compare_cpu_div_trunc_rounding_xpu_bfloat16",
     "test_compare_cpu_addr_xpu_float16",
+    "test_compare_cpu_nansum_xpu_bfloat16",
 
     # CUDA does not support the data type either
     "test_compare_cpu_native_dropout_backward_xpu_bool",
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 35c29d96b..b7f6741f2 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -98,6 +98,7 @@
     "sin",
     "sqrt",
     "sum",
+    "nansum",
     "amin",
     "amax",
     "std",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2ecc6790b..5b625f397 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -267,6 +267,8 @@ supported:
   - min.dim_min
   - sum.dim_IntList
   - sum.IntList_out
+  - nansum
+  - nansum.out
   - mean.out
   - mean.dim
   - std.correction
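
Usage note: for reviewers who want to smoke-test the new path, a minimal sketch (assuming a PyTorch build that includes this patch and a machine with an available `xpu` device; tensor values and shapes are illustrative only, not part of the test suite):

    # Minimal smoke test for the nansum XPU kernels added above.
    import torch

    x = torch.tensor([1.0, float("nan"), 2.0], device="xpu")
    print(torch.nansum(x))                       # NaN treated as zero -> 3.0
    print(torch.nansum(x, dim=0, keepdim=True))  # reduced over dim 0, shape (1,)

    # Integral dtypes cannot hold NaN, so nansum_out short-circuits to at::sum_out.
    i = torch.arange(6, device="xpu").reshape(2, 3)
    assert torch.equal(torch.nansum(i, dim=1), i.sum(dim=1))

The integral assert exercises the `c10::isIntegralType` branch in `XPUNativeFunctions::nansum_out`, while the float calls go through `nansum_kernel` and its `reduce_dispatch` half-precision handling.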