From 365b12c757539155e2160d940d12aab11cbb04b3 Mon Sep 17 00:00:00 2001
From: Feng Yuan <feng1.yuan@intel.com>
Date: Fri, 12 Jul 2024 23:32:33 +0800
Subject: [PATCH] Coding style

---
 src/ATen/native/xpu/ReduceOps.cpp        | 108 +++++++++++------------
 src/ATen/native/xpu/sycl/CumprodKernel.h |   5 +-
 src/ATen/native/xpu/sycl/CumsumKernel.h  |   5 +-
 src/ATen/native/xpu/sycl/ScanKernels.cpp |   5 +-
 4 files changed, 65 insertions(+), 58 deletions(-)

diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index 9cfc7d4e0..2adb3c577 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -37,28 +37,43 @@ void impl_func_cum_ops(
   }
 }
 
-Tensor& XPUNativeFunctions::cumsum_out(
+static void cum_ops_meta(
+    const char* name,
     const Tensor& self,
     int64_t dim,
-    c10::optional<ScalarType> dtype,
+    std::optional<ScalarType> dtype,
     Tensor& result) {
   // Checking whether 'dim' is valid.
   maybe_wrap_dim(dim, self.dim());
 
   ScalarType out_dtype;
 
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
+  if (result.defined()) {
+    out_dtype = dtype.value_or(result.scalar_type());
+    at::xpu::resize_out(
+        result,
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
+    auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+    out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+    result = at::xpu::create_out(
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   }
 
+  namedinference::propagate_names(result, self);
+}
+
+Tensor& XPUNativeFunctions::cumsum_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumsum", self, dim, dtype, result);
+
   impl_func_cum_ops(self, dim, result, at::native::xpu::cumsum_kernel);
   return result;
 }
@@ -68,14 +83,40 @@ Tensor XPUNativeFunctions::cumsum(
     int64_t dim,
     c10::optional<ScalarType> dtype) {
   Tensor result;
-  return cumsum_out(self, dim, dtype, result);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, result);
 }
 
 Tensor& XPUNativeFunctions::cumsum_(
     Tensor& self,
     int64_t dim,
     c10::optional<ScalarType> dtype) {
-  return cumsum_out(self, dim, dtype, self);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, self);
+}
+
+Tensor& XPUNativeFunctions::cumprod_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumprod", self, dim, dtype, result);
+
+  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
+  return result;
+}
+
+Tensor XPUNativeFunctions::cumprod(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  Tensor result;
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
+}
+
+Tensor& XPUNativeFunctions::cumprod_(
+    Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
 }
 
 static ScalarType infer_dtype_from_optional(
@@ -803,45 +844,4 @@ Tensor XPUNativeFunctions::amin(
   return out;
 }
 
-Tensor& XPUNativeFunctions::cumprod_out(
-    const Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype,
-    Tensor& result) {
-  // Checking whether 'dim' is valid.
-  maybe_wrap_dim(dim, self.dim());
-
-  ScalarType out_dtype;
-
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
-  } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
-  }
-
-  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
-  return result;
-}
-
-Tensor XPUNativeFunctions::cumprod(
-    const Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype) {
-  Tensor result;
-  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
-}
-
-Tensor& XPUNativeFunctions::cumprod_(
-    Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype) {
-  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
-}
-
 } // namespace at
diff --git a/src/ATen/native/xpu/sycl/CumprodKernel.h b/src/ATen/native/xpu/sycl/CumprodKernel.h
index a6a48ff92..71e8c7693 100644
--- a/src/ATen/native/xpu/sycl/CumprodKernel.h
+++ b/src/ATen/native/xpu/sycl/CumprodKernel.h
@@ -4,6 +4,9 @@
 
 namespace at::native::xpu {
 
-void cumprod_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim);
+void launch_cumprod_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
 
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/CumsumKernel.h b/src/ATen/native/xpu/sycl/CumsumKernel.h
index f74b84984..79c299608 100644
--- a/src/ATen/native/xpu/sycl/CumsumKernel.h
+++ b/src/ATen/native/xpu/sycl/CumsumKernel.h
@@ -4,6 +4,9 @@
 
 namespace at::native::xpu {
 
-void cumsum_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim);
+void launch_cumsum_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
 
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ScanKernels.cpp b/src/ATen/native/xpu/sycl/ScanKernels.cpp
index c5a694e72..ad97dc4b4 100644
--- a/src/ATen/native/xpu/sycl/ScanKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ScanKernels.cpp
@@ -29,7 +29,7 @@ static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
 
-  cumsum_kernel_impl(*result_, self, dim);
+  launch_cumsum_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
@@ -39,10 +39,11 @@ void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
 void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
 
-  cumprod_kernel_impl(*result_, self, dim);
+  launch_cumprod_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
   }
 }
+
 } // namespace at::native::xpu