From 9d5ed2e32b7640ca39229ee3856e2914dd93dcd5 Mon Sep 17 00:00:00 2001
From: yucai-intel <108388355+yucai-intel@users.noreply.github.com>
Date: Mon, 30 Sep 2024 08:44:17 +0800
Subject: [PATCH] Add aten::trunc, aten::xlogy and their variants (#697)

---
 src/ATen/native/xpu/BinaryOps.cpp             |  2 +
 src/ATen/native/xpu/ReduceOps.cpp             |  4 +-
 src/ATen/native/xpu/UnaryOps.cpp              |  2 +
 src/ATen/native/xpu/XPUFallback.template      |  2 -
 .../native/xpu/sycl/BinaryMiscOpsKernels.cpp  | 24 ++++++++++-
 .../native/xpu/sycl/BinaryMiscOpsKernels.h    |  2 +
 .../native/xpu/sycl/UnaryFractionKernels.cpp  | 37 ++++++++++++++++
 .../native/xpu/sycl/UnaryFractionKernels.h    |  2 +
 test/xpu/xpu_test_utils.py                    |  2 +
 yaml/native/native_functions.yaml             | 43 +++++++++++++++++++
 10 files changed, 115 insertions(+), 5 deletions(-)

diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp
index 31c6dd984..2854e8b0a 100644
--- a/src/ATen/native/xpu/BinaryOps.cpp
+++ b/src/ATen/native/xpu/BinaryOps.cpp
@@ -11,6 +11,7 @@
 #include <...>
 #include <...>
 #include <...>
+#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
 #include <...>
 #include <...>
 #include <...>
@@ -51,6 +52,7 @@ REGISTER_XPU_DISPATCH(fmax_stub, &xpu::fmax_kernel);
 REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel);
 REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel);
 REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel);
+REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel);
 
 TORCH_IMPL_FUNC(add_out_xpu)
 (const Tensor& self,
diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index db72e5fbb..90c2ac642 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -298,8 +298,8 @@ void aminmax_impl(
     Tensor& min,
     Tensor& max) {
   auto dtype = self.scalar_type();
-  TensorIterator iter = make_reduction(
-      "aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
+  TensorIterator iter =
+      make_reduction("aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
   if (iter.numel() != 0) {
     native::xpu::aminmax_kernel(iter);
   }
diff --git a/src/ATen/native/xpu/UnaryOps.cpp b/src/ATen/native/xpu/UnaryOps.cpp
index 119b7bab9..394014141 100644
--- a/src/ATen/native/xpu/UnaryOps.cpp
+++ b/src/ATen/native/xpu/UnaryOps.cpp
@@ -78,5 +78,7 @@ REGISTER_XPU_DISPATCH(nan_to_num_stub, &xpu::nan_to_num_kernel);
 REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel);
 REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel);
 REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel);
+REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel);
+
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index fce5fa759..353e0e20f 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -281,7 +281,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
     "triangular_solve.X",
     "tril_indices",
     "triu_indices",
-    "trunc.out",
     "upsample_bicubic2d_backward.grad_input",
     "_upsample_bilinear2d_aa.out",
     "upsample_nearest3d.out",
@@ -292,7 +291,6 @@
     "upsample_trilinear3d.out",
     "_validate_compressed_sparse_indices",
     "vdot",
-    "xlogy.OutTensor",
     "_upsample_bicubic2d_aa.out",
   };
   for (auto& op_name : fallback_list) {
diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
index d96e5064e..ba1dfdcd8 100644
--- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
@@ -2,12 +2,12 @@
 #include <...>
 #include <...>
+#include <...>
 #include <...>
 
 #include <...>
 
 namespace at::native::xpu {
 
-
 template <typename scalar_t>
 struct MSEFunctor {
   scalar_t operator()(scalar_t a, scalar_t b) const {
@@ -72,4 +72,26 @@ void huber_kernel(TensorIterator& iter, double delta) {
   });
 }
 
+template <typename scalar_t>
+struct XlogyFunctor {
+  scalar_t operator()(scalar_t x, scalar_t y) const {
+    if (at::_isnan(y)) {
+      return NAN;
+    }
+    if (x == 0) {
+      return 0;
+    }
+    return x * std::log(y);
+  }
+};
+
+void xlogy_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      iter.common_dtype(),
+      "xlogy_xpu",
+      [&]() { gpu_kernel_with_scalars(iter, XlogyFunctor<scalar_t>()); });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
index ffe08be3e..bc2f10715 100644
--- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
@@ -10,4 +10,6 @@ TORCH_XPU_API void smooth_l1_kernel(TensorIteratorBase& iter, double beta);
 
 TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta);
 
+TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter);
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp b/src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
index a8551c262..156547dbd 100644
--- a/src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
+++ b/src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
@@ -180,4 +180,41 @@ void floor_kernel(TensorIteratorBase& iter) {
   });
 }
 
+// We manually overload trunc because std::trunc does not work with std::complex
+// types and ROCm.
+template <typename scalar_t>
+inline scalar_t trunc_wrapper(scalar_t a) {
+  return static_cast<scalar_t>(std::truncf(static_cast<float>(a)));
+}
+
+inline double trunc_wrapper(double a) {
+  return std::trunc(a);
+}
+
+inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
+  return c10::complex<float>(
+      std::truncf(static_cast<float>(a.real())),
+      std::truncf(static_cast<float>(a.imag())));
+}
+
+inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
+  return c10::complex<double>(
+      std::trunc(static_cast<double>(a.real())),
+      std::trunc(static_cast<double>(a.imag())));
+}
+
+template <typename scalar_t>
+struct TruncFunctor {
+  scalar_t operator()(scalar_t a) const {
+    return trunc_wrapper(a);
+  }
+};
+
+void trunc_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "trunc_xpu", [&]() {
+        gpu_kernel(iter, TruncFunctor<scalar_t>());
+      });
+}
+
 } // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/UnaryFractionKernels.h b/src/ATen/native/xpu/sycl/UnaryFractionKernels.h
index a3a2015df..881546b8c 100644
--- a/src/ATen/native/xpu/sycl/UnaryFractionKernels.h
+++ b/src/ATen/native/xpu/sycl/UnaryFractionKernels.h
@@ -18,4 +18,6 @@ TORCH_XPU_API void round_decimals_kernel(
 
 TORCH_XPU_API void frac_kernel(TensorIteratorBase& iter);
 
+TORCH_XPU_API void trunc_kernel(TensorIteratorBase& iter);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index d86a4503e..1e6696eea 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -199,6 +199,8 @@
     "sign",
     "signbit",
     "round",
+    "trunc",
+    "xlogy",
     "nn.functional.embedding_bag",
     "bucketize",
     "searchsorted",
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index 607cb38ed..234753397 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -4435,6 +4435,29 @@
     XPU: logit_out
   tags: pointwise
 
+- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
+  device_check: NoCheck   # TensorIterator
+  structured_delegate: xlogy.OutTensor
+  variants: function, method
+  tags: pointwise
+
+# xlogy: inplace variant
+- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  structured_delegate: xlogy.OutTensor
+  tags: pointwise
+
+# xlogy: out variant
+- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  device_check: NoCheck   # TensorIterator
+  structured: True
+  structured_inherits: TensorIteratorBase
+  variants: function
+  dispatch:
+    XPU: xlogy_out
+  tags: pointwise
+
 - func: erfinv(Tensor self) -> Tensor
   device_check: NoCheck   # TensorIterator
   structured_delegate: erfinv.out
@@ -4598,6 +4621,26 @@
     XPU: floor_out
   tags: pointwise
 
+- func: trunc(Tensor self) -> Tensor
+  structured_delegate: trunc.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: [core, pointwise]
+
+- func: trunc_(Tensor(a!) self) -> Tensor(a!)
+  structured_delegate: trunc.out
+  device_check: NoCheck   # TensorIterator
+  variants: function, method
+  tags: pointwise
+
+- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  structured_inherits: TensorIteratorBase
+  device_check: NoCheck   # TensorIterator
+  dispatch:
+    XPU: trunc_out
+  tags: pointwise
+
 - func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
   structured: True
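
A quick smoke test for reviewers, separate from the patch itself: the sketch below exercises the newly wired ops from Python. It assumes a PyTorch build that bundles these torch-xpu-ops kernels and exposes the torch.xpu accelerator API (present in recent PyTorch releases); the expected values in the comments follow the special cases XlogyFunctor encodes above.

# Smoke-test sketch for the ops this patch enables on XPU (assumes a
# build with these kernels; falls back to CPU for comparison otherwise).
import math

import torch

device = "xpu" if torch.xpu.is_available() else "cpu"

# trunc rounds toward zero, unlike floor, which rounds toward -inf.
a = torch.tensor([-1.7, -0.5, 0.5, 1.7], device=device)
print(torch.trunc(a))   # tensor([-1., -0., 0., 1.])
print(torch.floor(a))   # tensor([-2., -1., 0., 1.])
a.trunc_()              # in-place variant, dispatched through trunc.out

# xlogy(x, y) = x * log(y), with the two special cases from the kernel:
# NaN in y propagates, and x == 0 yields 0 even where log(y) is -inf/NaN.
x = torch.tensor([0.0, 0.0, 1.0, 2.0], device=device)
y = torch.tensor([0.0, math.nan, math.e, 4.0], device=device)
print(torch.xlogy(x, y))  # tensor([0.0000, nan, 1.0000, 2.7726])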