diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index f9b512723..2adb3c577 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -37,28 +37,43 @@ void impl_func_cum_ops(
   }
 }
 
-Tensor& XPUNativeFunctions::cumsum_out(
+static void cum_ops_meta(
+    const char* name,
     const Tensor& self,
     int64_t dim,
-    c10::optional<ScalarType> dtype,
+    std::optional<ScalarType> dtype,
     Tensor& result) {
   // Checking whether 'dim' is valid.
   maybe_wrap_dim(dim, self.dim());
   ScalarType out_dtype;
 
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
+  if (result.defined()) {
+    out_dtype = dtype.value_or(result.scalar_type());
+    at::xpu::resize_out(
+        result,
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
+    auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+    out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+    result = at::xpu::create_out(
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   }
 
+  namedinference::propagate_names(result, self);
+}
+
+Tensor& XPUNativeFunctions::cumsum_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumsum", self, dim, dtype, result);
+
   impl_func_cum_ops(self, dim, result, at::native::xpu::cumsum_kernel);
   return result;
 }
@@ -68,14 +83,40 @@ Tensor XPUNativeFunctions::cumsum(
     int64_t dim,
     c10::optional<ScalarType> dtype) {
   Tensor result;
-  return cumsum_out(self, dim, dtype, result);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, result);
 }
 
 Tensor& XPUNativeFunctions::cumsum_(
     Tensor& self,
     int64_t dim,
     c10::optional<ScalarType> dtype) {
-  return cumsum_out(self, dim, dtype, self);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, self);
+}
+
+Tensor& XPUNativeFunctions::cumprod_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumprod", self, dim, dtype, result);
+
+  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
+  return result;
+}
+
+Tensor XPUNativeFunctions::cumprod(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  Tensor result;
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
+}
+
+Tensor& XPUNativeFunctions::cumprod_(
+    Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
 }
 
 static ScalarType infer_dtype_from_optional(
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 4c2440a7f..43e3d7649 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -178,7 +178,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_ctc_loss_backward",
       "_cummax_helper",
       "_cummin_helper",
-      "cumprod.out",
       "digamma.out",
       "dot",
       "_efficient_attention_forward",
diff --git a/src/ATen/native/xpu/sycl/CumprodKernel.cpp b/src/ATen/native/xpu/sycl/CumprodKernel.cpp
new file mode 100644
index 000000000..6c129183d
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/CumprodKernel.cpp
@@ -0,0 +1,24 @@
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
+
+#include <ATen/native/xpu/sycl/ScanUtils.h>
+
+namespace at::native::xpu {
+
+void launch_cumprod_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+      ScalarType::Half,
+      ScalarType::BFloat16,
+      self.scalar_type(),
+      "cumprod_xpu",
+      [&]() {
+        scalar_t init = 1;
+        scan(
+            result, self, dim, init, std::multiplies<scalar_t>());
+      });
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/CumprodKernel.h b/src/ATen/native/xpu/sycl/CumprodKernel.h
new file mode 100644
index 000000000..71e8c7693
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/CumprodKernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at::native::xpu {
+
+void launch_cumprod_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/CumsumKernel.cpp b/src/ATen/native/xpu/sycl/CumsumKernel.cpp
new file mode 100644
index 000000000..8a89b2231
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/CumsumKernel.cpp
@@ -0,0 +1,24 @@
+#include <ATen/Dispatch.h>
+#include <ATen/core/Tensor.h>
+
+#include <ATen/native/xpu/sycl/ScanUtils.h>
+
+namespace at::native::xpu {
+
+void launch_cumsum_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
+      ScalarType::Half,
+      ScalarType::BFloat16,
+      self.scalar_type(),
+      "cumsum_xpu",
+      [&]() {
+        scalar_t init = 0;
+        scan(
+            result, self, dim, init, std::plus<scalar_t>());
+      });
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/CumsumKernel.h b/src/ATen/native/xpu/sycl/CumsumKernel.h
new file mode 100644
index 000000000..79c299608
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/CumsumKernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/core/Tensor.h>
+
+namespace at::native::xpu {
+
+void launch_cumsum_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ScanKernels.cpp b/src/ATen/native/xpu/sycl/ScanKernels.cpp
index d39710a3d..ad97dc4b4 100644
--- a/src/ATen/native/xpu/sycl/ScanKernels.cpp
+++ b/src/ATen/native/xpu/sycl/ScanKernels.cpp
@@ -13,6 +13,9 @@
 #include
 #endif
 
+#include <ATen/native/xpu/sycl/CumprodKernel.h>
+#include <ATen/native/xpu/sycl/CumsumKernel.h>
+
 namespace at::native::xpu {
 
 static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
@@ -26,19 +29,21 @@ static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      ScalarType::Half,
-      ScalarType::BFloat16,
-      self.scalar_type(),
-      "cumsum_xpu",
-      [&]() {
-        scalar_t init = 0;
-        scan(
-            *result_, self, dim, init, std::plus<scalar_t>());
-      });
+  launch_cumsum_kernel(*result_, self, dim);
+
+  if (!result.is_same(*result_)) {
+    result.copy_(*result_);
+  }
+}
+
+void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
+  auto result_ = contiguous_out_arg(result);
+
+  launch_cumprod_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
   }
 }
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/ScanKernels.h b/src/ATen/native/xpu/sycl/ScanKernels.h
index 426d5b441..af632fa83 100644
--- a/src/ATen/native/xpu/sycl/ScanKernels.h
+++ b/src/ATen/native/xpu/sycl/ScanKernels.h
@@ -5,4 +5,6 @@ namespace at::native::xpu {
 
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim);
 
+void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index aea77c9b8..e481437e2 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -58,6 +58,7 @@
     "clamp_min",
     "clone",
    "copy",
+    "cumprod",
     "cumsum",
"eq", "fill", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 9fbcf0ba7..b40c5afe7 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -15,6 +15,9 @@ supported: - cumsum - cumsum.out - cumsum_ + - cumprod + - cumprod.out + - cumprod_ - sub.Tensor - sub_.Tensor - sub.out