Skip to content

Commit

Permalink
Move kernel pieces of cumsum/cumprod into unfied file. (#681)
Browse files Browse the repository at this point in the history
1. Fixing compilation error in BUILD_SEPARATE_OPS mode.
2. Although we aligned with file naming of PyTorch in-tree, putting host
only code in a file under ../xpu/sycl doesn't align with our design.

Signed-off-by: Feng Yuan <[email protected]>
  • Loading branch information
fengyuan14 authored Aug 5, 2024
1 parent f6d0f77 commit 2f75c47
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 49 deletions.
18 changes: 18 additions & 0 deletions src/ATen/native/xpu/sycl/CumprodKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,22 @@ void launch_cumprod_kernel(
});
}

static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
if (tensor.is_contiguous()) {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
}
return c10::MaybeOwned<Tensor>::owned(
at::empty(tensor.sizes(), tensor.options()));
}

void cumprod_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
auto result_ = contiguous_out_arg(result);

launch_cumprod_kernel(*result_, self, dim);

if (!result.is_same(*result_)) {
result.copy_(*result_);
}
}

} // namespace at::native::xpu
18 changes: 18 additions & 0 deletions src/ATen/native/xpu/sycl/CumsumKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,22 @@ void launch_cumsum_kernel(
});
}

static c10::MaybeOwned<Tensor> contiguous_out_arg(const Tensor& tensor) {
if (tensor.is_contiguous()) {
return c10::MaybeOwned<Tensor>::borrowed(tensor);
}
return c10::MaybeOwned<Tensor>::owned(
at::empty(tensor.sizes(), tensor.options()));
}

void cumsum_kernel(const Tensor& result, const Tensor& self, int64_t dim) {
auto result_ = contiguous_out_arg(result);

launch_cumsum_kernel(*result_, self, dim);

if (!result.is_same(*result_)) {
result.copy_(*result_);
}
}

} // namespace at::native::xpu
49 changes: 0 additions & 49 deletions src/ATen/native/xpu/sycl/ScanKernels.cpp

This file was deleted.

0 comments on commit 2f75c47

Please sign in to comment.