Coding style
fengyuan14 committed Jul 12, 2024
1 parent 6b999b9 commit 365b12c
Showing 4 changed files with 65 additions and 58 deletions.
108 changes: 54 additions & 54 deletions src/ATen/native/xpu/ReduceOps.cpp
@@ -37,28 +37,43 @@ void impl_func_cum_ops(
   }
 }
 
-Tensor& XPUNativeFunctions::cumsum_out(
+static void cum_ops_meta(
+    const char* name,
     const Tensor& self,
     int64_t dim,
-    c10::optional<ScalarType> dtype,
+    std::optional<ScalarType> dtype,
     Tensor& result) {
   // Checking whether 'dim' is valid.
   maybe_wrap_dim(dim, self.dim());
 
   ScalarType out_dtype;
 
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
+  if (result.defined()) {
+    out_dtype = dtype.value_or(result.scalar_type());
+    at::xpu::resize_out(
+        result,
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
+    auto is_integral = at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
+    out_dtype = dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
+    result = at::xpu::create_out(
+        self.sizes(),
+        {},
+        self.options().dtype(out_dtype));
   }
 
+  namedinference::propagate_names(result, self);
+}
+
+Tensor& XPUNativeFunctions::cumsum_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumsum", self, dim, dtype, result);
+
   impl_func_cum_ops(self, dim, result, at::native::xpu::cumsum_kernel);
   return result;
 }
@@ -68,14 +83,40 @@ Tensor XPUNativeFunctions::cumsum(
     int64_t dim,
     c10::optional<ScalarType> dtype) {
   Tensor result;
-  return cumsum_out(self, dim, dtype, result);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, result);
 }
 
 Tensor& XPUNativeFunctions::cumsum_(
     Tensor& self,
     int64_t dim,
     c10::optional<ScalarType> dtype) {
-  return cumsum_out(self, dim, dtype, self);
+  return XPUNativeFunctions::cumsum_out(self, dim, dtype, self);
 }
+
+Tensor& XPUNativeFunctions::cumprod_out(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype,
+    Tensor& result) {
+  cum_ops_meta("cumprod", self, dim, dtype, result);
+
+  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
+  return result;
+}
+
+Tensor XPUNativeFunctions::cumprod(
+    const Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  Tensor result;
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
+}
+
+Tensor& XPUNativeFunctions::cumprod_(
+    Tensor& self,
+    int64_t dim,
+    c10::optional<ScalarType> dtype) {
+  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
+}
 
 static ScalarType infer_dtype_from_optional(
@@ -803,45 +844,4 @@ Tensor XPUNativeFunctions::amin(
   return out;
 }
 
-Tensor& XPUNativeFunctions::cumprod_out(
-    const Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype,
-    Tensor& result) {
-  // Checking whether 'dim' is valid.
-  maybe_wrap_dim(dim, self.dim());
-
-  ScalarType out_dtype;
-
-  if (!result.defined()) {
-    auto is_integral =
-        at::isIntegralType(self.scalar_type(), /*includeBool=*/true);
-    out_dtype =
-        dtype.value_or(is_integral ? ScalarType::Long : self.scalar_type());
-    result = at::empty_strided(
-        self.sizes(), self.strides(), self.options().dtype(out_dtype));
-  } else {
-    at::native::resize_output(result, self.sizes());
-    result.as_strided_(self.sizes(), self.strides());
-  }
-
-  impl_func_cum_ops(self, dim, result, at::native::xpu::cumprod_kernel);
-  return result;
-}
-
-Tensor XPUNativeFunctions::cumprod(
-    const Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype) {
-  Tensor result;
-  return XPUNativeFunctions::cumprod_out(self, dim, dtype, result);
-}
-
-Tensor& XPUNativeFunctions::cumprod_(
-    Tensor& self,
-    int64_t dim,
-    c10::optional<ScalarType> dtype) {
-  return XPUNativeFunctions::cumprod_out(self, dim, dtype, self);
-}
-
 } // namespace at
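
The net effect of the ReduceOps.cpp change is that cumsum and cumprod now share one cum_ops_meta helper for output setup: a defined result is resized in place and keeps its own dtype unless an explicit dtype is passed, while an undefined result is freshly allocated, with integral inputs (including bool) promoted to Long. A minimal sketch of that observable behavior from user code, assuming libtorch and exercising only the public API (this is not code from the commit):

  // dtype_rules.cpp: illustrates the output-dtype rules cum_ops_meta applies.
  #include <torch/torch.h>
  #include <iostream>

  int main() {
    auto x = torch::ones({4}, torch::kInt);

    // Undefined result: an integral input is accumulated in int64.
    auto y = torch::cumsum(x, /*dim=*/0);
    std::cout << y.dtype() << "\n";  // int64 (Long)

    // Defined out tensor: its existing dtype wins when no dtype is given.
    auto out = torch::empty({4}, torch::kFloat);
    torch::cumsum_out(out, x, /*dim=*/0);
    std::cout << out.dtype() << "\n";  // float
    return 0;
  }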
5 changes: 4 additions & 1 deletion src/ATen/native/xpu/sycl/CumprodKernel.h
@@ -4,6 +4,9 @@
 
 namespace at::native::xpu {
 
-void cumprod_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim);
+void launch_cumprod_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
 
 } // namespace at::native::xpu
5 changes: 4 additions & 1 deletion src/ATen/native/xpu/sycl/CumsumKernel.h
@@ -4,6 +4,9 @@
 
 namespace at::native::xpu {
 
-void cumsum_kernel_impl(const Tensor& result, const Tensor& self, int64_t dim);
+void launch_cumsum_kernel(
+    const Tensor& result,
+    const Tensor& self,
+    int64_t dim);
 
 } // namespace at::native::xpu
5 changes: 3 additions & 2 deletions src/ATen/native/xpu/sycl/ScanKernels.cpp
@@ -29,7 +29,7 @@ static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
 void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
 
-  cumsum_kernel_impl(*result_, self, dim);
+  launch_cumsum_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
@@ -39,10 +39,11 @@ void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
 void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
   auto result_ = contiguous_out_arg(result);
 
-  cumprod_kernel_impl(*result_, self, dim);
+  launch_cumprod_kernel(*result_, self, dim);
 
   if (!result.is_same(*result_)) {
     result.copy_(*result_);
   }
 }
+
 } // namespace at::native::xpu
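
Both wrappers in ScanKernels.cpp follow the contiguous-out-argument pattern: contiguous_out_arg hands back the caller's tensor when it is already usable, otherwise a contiguous temporary; the renamed launch_* kernel writes into that, and the data is copied back only when a temporary was actually used. The same pattern at the user level, as a minimal libtorch sketch (file and variable names are illustrative, not from the commit):

  // contiguous_out.cpp: the contiguous-out-argument pattern in miniature.
  #include <torch/torch.h>
  #include <iostream>

  int main() {
    auto self = torch::arange(6, torch::kFloat).reshape({2, 3});

    // A non-contiguous out tensor: a transposed view of its storage.
    auto base = torch::empty({3, 2}, torch::kFloat);
    auto out = base.t();

    // Compute into contiguous scratch if needed, then copy back through
    // the original strides, mirroring cumsum_kernel/cumprod_kernel above.
    auto work = out.is_contiguous() ? out : torch::empty(out.sizes(), out.options());
    torch::cumsum_out(work, self, /*dim=*/1);
    if (!work.is_same(out)) {
      out.copy_(work);
    }
    std::cout << out << "\n";
    return 0;
  }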
