Skip to content

Commit

Permalink
Add aten::ceil (#463)
Browse files Browse the repository at this point in the history
- ceil.out
  - ceil
  - ceil_

---------

Co-authored-by: Feng Yuan <[email protected]>
  • Loading branch information
hjhee and fengyuan14 authored Jul 5, 2024
1 parent 3fc911b commit 6781c4a
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 1 deletion.
36 changes: 36 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,40 @@ Tensor& XPUNativeFunctions::erfc_out(const Tensor& self, Tensor& out) {
return out;
}

// Builds the unary TensorIterator shared by ceil / ceil_ / ceil.out.
// Complex inputs are rejected up front; ceil has no meaning for them.
TensorIterator ceil_meta(const Tensor& self, Tensor& out) {
  TORCH_CHECK(
      !self.is_complex(), "ceil is not supported for complex inputs");
  TensorIterator it;
  it.build_borrowing_unary_op(out, self);
  return it;
}

// Out-of-place ceil. Integral tensors are already integer-valued, so a
// plain clone suffices; otherwise run the XPU kernel through a fresh
// iterator and hand back its output tensor.
Tensor XPUNativeFunctions::ceil(const Tensor& self) {
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {
    return self.clone();
  }
  Tensor result;
  auto it = ceil_meta(self, result);
  native::xpu::ceil_kernel(it);
  return it.output();
}

// In-place ceil. For integral dtypes ceil is the identity, so the kernel
// launch is skipped entirely and `self` is returned untouched.
Tensor& XPUNativeFunctions::ceil_(Tensor& self) {
  const bool is_integral =
      c10::isIntegralType(self.scalar_type(), /*includeBool=*/false);
  if (!is_integral) {
    auto it = ceil_meta(self, self);
    native::xpu::ceil_kernel(it);
  }
  return self;
}

// ceil.out: write ceil(self) into the caller-provided `out`.
Tensor& XPUNativeFunctions::ceil_out(const Tensor& self, Tensor& out) {
  const auto dtype = self.scalar_type();
  if (c10::isIntegralType(dtype, /*includeBool=*/false)) {
    // Integral values are their own ceiling — copy the input through.
    // NOTE(review): this fast path relies on `out` being compatible with
    // `self` (copy_ enforces that); presumably the iterator path below
    // resizes `out` itself — confirm against build_borrowing_unary_op.
    out.copy_(self);
    return out;
  }
  auto it = ceil_meta(self, out);
  native::xpu::ceil_kernel(it);
  return out;
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"bitwise_right_shift.Tensor_out",
"cauchy_",
"_cdist_backward",
"ceil.out",
"channel_shuffle",
"cholesky",
"cholesky_inverse",
Expand Down
21 changes: 21 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,25 @@ void reciprocal_kernel(TensorIteratorBase& iter) {
[&]() { gpu_kernel(iter, ReciprocalFunctor<scalar_t>()); });
}

// Elementwise ceiling functor for real scalar types.
template <typename T>
struct CeilFunctor {
  T operator()(const T x) const {
    return std::ceil(x);
  }
};

// Complex specialization: applies ceil independently to the real and
// imaginary components.
// NOTE(review): currently unreachable through ceil_kernel, whose dispatch
// covers only floating dtypes, and the op-level meta function rejects
// complex inputs — presumably kept for future complex support; confirm.
template <typename T>
struct CeilFunctor<c10::complex<T>> {
  c10::complex<T> operator()(const c10::complex<T> a) const {
    return c10::complex<T>(std::ceil(a.real()), std::ceil(a.imag()));
  }
};

// Launches the elementwise ceil kernel on an already-built iterator.
// Dispatch covers the floating dtypes plus Half and BFloat16; integral
// tensors are short-circuited by the op-level callers, so no other
// dtypes are expected to reach this point.
void ceil_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "ceil_xpu", [&]() {
        gpu_kernel(iter, CeilFunctor<scalar_t>());
      });
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ namespace at::native::xpu {

void reciprocal_kernel(TensorIteratorBase& iter);

void ceil_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
3 changes: 3 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,6 @@ supported:
- randperm.generator_out
- _amp_foreach_non_finite_check_and_unscale_
- _amp_update_scale_
- ceil
- ceil_
- ceil.out

0 comments on commit 6781c4a

Please sign in to comment.