From 2e064efc2303eaf193836fdef67c50ce00f920a4 Mon Sep 17 00:00:00 2001
From: yucai
Date: Thu, 18 Jul 2024 08:34:46 +0000
Subject: [PATCH] add prod

---
Review notes (text here is dropped by `git am`; not part of the commit):
- Restored template arguments that were stripped by angle-bracket
  mangling: std::optional<ScalarType>, ProdFunctor<acc_t>,
  prod_functor<bool>, prod_functor<c10::complex<at::Half>>,
  gpu_reduce_kernel<...>, func_wrapper<...>, prod_functor<scalar_t>.
- FIX prod(self, dtype) [all-dims overload]: the original called only
  prod_meta (allocation) and returned the output without ever launching
  the kernel, yielding uninitialized data. It now allocates the 0-dim
  result, builds a full-reduction iterator, and runs prod_kernel
  (fill_(1) for empty inputs), mirroring prod_out.
- FIX prod(self, dim, keepdim, dtype): removed the redundant prod_meta
  call; prod_out performs the meta step itself.
- Removed the stray ';' after prod_kernel's closing brace and the dead
  commented-out allocation code.
- All per-hunk line counts are unchanged, so hunk headers and the
  diffstat remain valid.
- NOTE(review): the author e-mail on the From: line was lost in the
  same mangling and could not be reconstructed; fill in before `git am`.

 src/ATen/native/xpu/ReduceOps.cpp             | 68 ++++++++++++++++---
 src/ATen/native/xpu/XPUFallback.template      |  2 -
 src/ATen/native/xpu/sycl/ReduceOpsKernels.h   |  2 +
 .../native/xpu/sycl/ReduceSumProdKernels.cpp  | 50 ++++++++++++++
 test/xpu/xpu_test_utils.py                    |  1 +
 yaml/xpu_functions.yaml                       |  3 +
 6 files changed, 114 insertions(+), 12 deletions(-)

diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index 2adb3c577..c984ad1d6 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -51,17 +51,14 @@ static void cum_ops_meta(
   if (result.defined()) {
     out_dtype = dtype.value_or(result.scalar_type());
     at::xpu::resize_out(
-        result,
-        self.sizes(),
-        {},
-        self.options().dtype(out_dtype));
+        result, self.sizes(), {}, self.options().dtype(out_dtype));
   } else {
-    auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::xpu::create_out(
-        self.sizes(),
-        {},
-        self.options().dtype(out_dtype));
+    auto is_integral =
+        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+    out_dtype =
+        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+    result =
+        at::xpu::create_out(self.sizes(), {}, self.options().dtype(out_dtype));
   }
 
   namedinference::propagate_names(result, self);
@@ -260,6 +257,57 @@ Tensor XPUNativeFunctions::mean(
   return out;
 }
 
+Tensor& prod_meta(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    std::optional<ScalarType> dtype,
+    Tensor& out) {
+  auto out_dtype = infer_dtype_from_optional(self, dtype, out);
+  out = resize_reduction(out, self, dim, keepdim, out_dtype);
+  return out;
+}
+
+Tensor& XPUNativeFunctions::prod_out(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    std::optional<ScalarType> dtype,
+    Tensor& result) {
+  result = prod_meta(self, dim, keepdim, dtype, result);
+  // Build the reduction iterator directly for the XPU device.
+  auto iter = at::meta::make_reduction_from_out_ty(
+      self, result, dim, keepdim, result.scalar_type());
+  if (iter.numel() == 0) {
+    result.fill_(1);
+  } else {
+    native::xpu::prod_kernel(iter);
+  }
+  return result;
+}
+
+Tensor XPUNativeFunctions::prod(
+    const Tensor& self,
+    std::optional<ScalarType> opt_dtype) {
+  auto dtype = at::native::get_dtype_from_self(self, opt_dtype, true);
+  // Full reduction over all dims -> 0-dim output; the kernel must run.
+  Tensor result = at::empty({}, self.options().dtype(dtype));
+  auto iter = at::meta::make_reduction_from_out_ty(self, result, IntArrayRef{}, false, dtype);
+  if (iter.numel() == 0) { result.fill_(1); } else { native::xpu::prod_kernel(iter); }
+  return result;
+}
+
+Tensor XPUNativeFunctions::prod(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    std::optional<ScalarType> dtype) {
+  Tensor out;
+  // prod_out performs the meta step itself; a separate prod_meta call is redundant.
+  out = XPUNativeFunctions::prod_out(self, dim, keepdim, dtype, out);
+  return out;
+}
+
 inline TensorIterator get_allany_iter(
     const Tensor& self,
     const Tensor& result,
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 561142654..c16f36044 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -274,8 +274,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "polygamma.out",
       "_prelu_kernel",
       "_prelu_kernel_backward",
-      "prod",
-      "prod.int_out",
       "put_",
       "renorm.out",
       "repeat_interleave.Tensor",
diff --git a/src/ATen/native/xpu/sycl/ReduceOpsKernels.h b/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
index 955b055e9..fcb3f3143 100644
--- a/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/ReduceOpsKernels.h
@@ -14,6 +14,8 @@ void mean_kernel(TensorIterator& iter);
 
 void sum_kernel(TensorIterator& iter);
 
+void prod_kernel(TensorIterator& iter);
+
 void std_var_kernel(TensorIterator& iter, double correction, bool take_sqrt);
 
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
index 728a75582..e9136fc25 100644
--- a/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
@@ -54,6 +54,56 @@ void sum_kernel(TensorIterator& iter) {
   });
 }
 
+template <typename acc_t>
+struct ProdFunctor {
+  inline acc_t operator()(acc_t a, acc_t b) const {
+    return a * b;
+  }
+};
+
+template <>
+struct ProdFunctor<bool> {
+  inline bool operator()(bool a, bool b) const {
+    return a && b;
+  }
+};
+
+template <
+    typename scalar_t,
+    typename acc_t = scalar_t,
+    typename out_t = scalar_t>
+struct prod_functor {
+  void operator()(TensorIterator& iter) {
+    gpu_reduce_kernel<scalar_t, out_t>(
+        iter, func_wrapper<out_t>(ProdFunctor<acc_t>()), 1.);
+  }
+};
+
+template <>
+struct prod_functor<bool> {
+  void operator()(TensorIterator& iter) {
+    gpu_reduce_kernel<bool, bool>(
+        iter, func_wrapper<bool>(ProdFunctor<bool>()), 1);
+  }
+};
+
+template <>
+struct prod_functor<c10::complex<at::Half>> {
+  void operator()(TensorIterator& iter) {
+    using scalar_t = c10::complex<at::Half>;
+    using acc_t = at::opmath_type<scalar_t>;
+    gpu_reduce_kernel<scalar_t, scalar_t>(
+        iter, func_wrapper<scalar_t>(ProdFunctor<acc_t>()), acc_t{1.});
+  }
+};
+
+void prod_kernel(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+      kComplexHalf, kBool, iter.dtype(), "prod_xpu", [&]() {
+        prod_functor<scalar_t>{}(iter);
+      });
+}
+
 } // namespace xpu
 } // namespace native
 } // namespace at
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index c44dad1d7..d82f78eed 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -124,6 +124,7 @@
     "atanh",
     "sqrt",
     "sum",
+    "prod",
     "amin",
     "amax",
     "std",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 0d6d3c79f..46e69cfc6 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -296,6 +296,9 @@ supported:
   - min.dim_min
   - sum.dim_IntList
   - sum.IntList_out
+  - prod
+  - prod.int_out
+  - prod.dim_int
   - mean.out
   - mean.dim
   - std.correction