Add aten::cumprod #522

Merged: 8 commits, merged on Jul 16, 2024.
67 changes: 54 additions & 13 deletions src/ATen/native/xpu/ReduceOps.cpp
@@ -37,28 +37,43 @@ void impl_func_cum_ops(
   }
 }
 
-Tensor& XPUNativeFunctions::cumsum_out(
+static void cum_ops_meta(
+    const char* name,
     const Tensor& self,
     int64_t dim,
-    c10::optional<ScalarType> dtype,
+    std::optional<ScalarType> dtype,
     Tensor& result) {
   // Checking whether 'dim' is valid.
   maybe_wrap_dim(dim, self.dim());
 
   ScalarType out_dtype;
 
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
+  if (result.defined()) {
+    out_dtype = dtype.value_or(result.scalar_type());
+    at::xpu::resize_out(
+        result,
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
+    auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+    out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+    result = at::xpu::create_out(
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   }
 
+  namedinference::propagate_names(result, self);
+}
+
+Tensor& XPUNativeFunctions::cumsum_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumsum", self, dim, dtype, result);
+
   impl_func_cum_ops(self, dim, result, at::native::xpu::cumsum_kernel);
   return result;
 }
@@ -68,14 +83,40 @@ Tensor XPUNativeFunctions::cumsum(
     int64_t dim,
     c10::optional<ScalarType> dtype) {
   Tensor result;
-  return cumsum_out(self, dim, dtype, result);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, result);
 }
 
 Tensor& XPUNativeFunctions::cumsum_(
     Tensor& self,
     int64_t dim,
     c10::optional<ScalarType> dtype) {
-  return cumsum_out(self, dim, dtype, self);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, self);
 }
 
+Tensor& XPUNativeFunctions::cumprod_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumprod", self, dim, dtype, result);
+
+  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
+  return result;
+}
+
+Tensor XPUNativeFunctions::cumprod(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  Tensor result;
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
+}
+
+Tensor& XPUNativeFunctions::cumprod_(
+    Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
+}
+
 static ScalarType infer_dtype_from_optional(
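
Aside (illustrative, not part of the diff): a minimal libtorch-style sketch of calling the cumulative ops this hunk wires up. It assumes an XPU-enabled libtorch build with at least one visible XPU device; the dtype behavior shown (integral inputs accumulating in int64 unless a dtype is passed) mirrors the dtype.value_or(...) default in cum_ops_meta.

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Assumes an XPU-enabled libtorch build and at least one XPU device.
  auto t = at::arange(1, 7, at::device(at::kXPU).dtype(at::kInt)).reshape({2, 3});

  // Integral input, no explicit dtype: accumulates in int64 (ScalarType::Long),
  // matching the default chosen in cum_ops_meta.
  auto p = at::cumprod(t, /*dim=*/1);

  // An explicit dtype overrides that default.
  auto pf = at::cumprod(t, /*dim=*/1, at::kFloat);

  // The out-variant reuses (and resizes) a preallocated result tensor.
  auto out = at::empty({0}, t.options().dtype(at::kLong));
  at::cumprod_out(out, t, /*dim=*/1);

  std::cout << p.cpu() << "\n" << pf.cpu() << "\n" << out.cpu() << std::endl;
  return 0;
}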
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -178,7 +178,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "_ctc_loss_backward",
       "_cummax_helper",
       "_cummin_helper",
-      "cumprod.out",
       "digamma.out",
       "dot",
       "_efficient_attention_forward",
24 changes: 24 additions & 0 deletions src/ATen/native/xpu/sycl/CumprodKernel.cpp
@@ -0,0 +1,24 @@
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

#include <ATen/native/xpu/sycl/ScanUtils.h>

namespace at::native::xpu {

void launch_cumprod_kernel(
    const Tensor& result,
    const Tensor& self,
    int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self.scalar_type(),
      "cumprod_xpu",
      [&]() {
        scalar_t init = 1;
        scan<INCLUSIVE_TYPE, scalar_t, scalar_t>(
            result, self, dim, init, std::multiplies<scalar_t>());
      });
}

} // namespace at::native::xpu
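
Aside (illustrative, not part of the PR): the kernel above performs an inclusive scan along `dim` with identity 1 and multiplication as the combine operator (the `scan<INCLUSIVE_TYPE, ...>` helper from ScanUtils.h is assumed to implement exactly that). A serial sketch of the same semantics on a plain C++ container, just to pin down what an inclusive product scan computes; the SYCL kernel parallelizes this along the chosen dimension.

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<long> in{2, 3, 4, 5};
  std::vector<long> out(in.size());

  // Inclusive scan with init = 1 and multiplication: out[i] = in[0] * ... * in[i].
  std::inclusive_scan(
      in.begin(), in.end(), out.begin(), std::multiplies<long>(), 1L);

  for (auto v : out)
    std::cout << v << ' ';  // prints: 2 6 24 120
  std::cout << '\n';
  return 0;
}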
12 changes: 12 additions & 0 deletions src/ATen/native/xpu/sycl/CumprodKernel.h
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/core/TensorBase.h>

namespace at::native::xpu {

void launch_cumprod_kernel(
    const Tensor& result,
    const Tensor& self,
    int64_t dim);

} // namespace at::native::xpu
24 changes: 24 additions & 0 deletions src/ATen/native/xpu/sycl/CumsumKernel.cpp
@@ -0,0 +1,24 @@
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

#include <ATen/native/xpu/sycl/ScanUtils.h>

namespace at::native::xpu {

void launch_cumsum_kernel(
    const Tensor& result,
    const Tensor& self,
    int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      self.scalar_type(),
      "cumsum_xpu",
      [&]() {
        scalar_t init = 0;
        scan<INCLUSIVE_TYPE, scalar_t, scalar_t>(
            result, self, dim, init, std::plus<scalar_t>());
      });
}

} // namespace at::native::xpu
12 changes: 12 additions & 0 deletions src/ATen/native/xpu/sycl/CumsumKernel.h
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/core/TensorBase.h>

namespace at::native::xpu {

void launch_cumsum_kernel(
    const Tensor& result,
    const Tensor& self,
    int64_t dim);

} // namespace at::native::xpu
25 changes: 15 additions & 10 deletions src/ATen/native/xpu/sycl/ScanKernels.cpp
@@ -13,6 +13,9 @@
 #include <ATen/ops/empty_like.h>
 #endif
 
+#include <ATen/native/xpu/sycl/CumprodKernel.h>
+#include <ATen/native/xpu/sycl/CumsumKernel.h>
+
 namespace at::native::xpu {
 
 static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
@@ -26,19 +29,21 @@ static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
-      ScalarType::Half,
-      ScalarType::BFloat16,
-      self.scalar_type(),
-      "cumsum_xpu",
-      [&]() {
-        scalar_t init = 0;
-        scan<INCLUSIVE_TYPE, scalar_t, scalar_t>(
-            *result_, self, dim, init, std::plus<scalar_t>());
-      });
+  launch_cumsum_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
   }
 }
 
+void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
+  auto result_ = contiguous_out_arg(result);
+
+  launch_cumprod_kernel(*result_, self, dim);
+
+  if (!result.is_same(*result_)) {
+    result.copy_(*result_);
+  }
+}
+
 } // namespace at::native::xpu
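
Aside (illustrative, not the PR's code): both kernels above use the "contiguous out argument" pattern, where a kernel that requires contiguous memory writes into a contiguous buffer and the result is copied back only if a temporary had to be materialized. A CPU-only libtorch sketch of that pattern, with a plain copy standing in for the actual launch_cum*_kernel call; all names in it are hypothetical.

#include <ATen/ATen.h>
#include <iostream>

// Sketch of the contiguous_out_arg() idea: borrow `result` when it is already
// contiguous, otherwise allocate a contiguous temporary of the same shape.
void run_on_contiguous_out(const at::Tensor& self, at::Tensor& result) {
  at::Tensor work = result.is_contiguous()
      ? result
      : at::empty(result.sizes(), result.options());  // temporary buffer

  // Stand-in for the real kernel launch, which needs contiguous memory.
  work.copy_(self);

  // Write back into the caller's (possibly non-contiguous) result tensor.
  if (!result.is_same(work)) {
    result.copy_(work);
  }
}

int main() {
  auto src = at::arange(6, at::kFloat).reshape({2, 3});
  auto dst = at::empty({3, 2}, at::kFloat).t();  // non-contiguous 2x3 view
  run_on_contiguous_out(src, dst);
  std::cout << dst << std::endl;
  return 0;
}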
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/ScanKernels.h
@@ -5,4 +5,6 @@ namespace at::native::xpu {
 
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim);
 
+void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim);
+
 } // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -58,6 +58,7 @@
     "clamp_min",
     "clone",
    "copy",
+    "cumprod",
     "cumsum",
     "eq",
     "fill",
3 changes: 3 additions & 0 deletions yaml/xpu_functions.yaml
@@ -15,6 +15,9 @@ supported:
   - cumsum
   - cumsum.out
   - cumsum_
+  - cumprod
+  - cumprod.out
+  - cumprod_
   - sub.Tensor
   - sub_.Tensor
   - sub.out