From f2241385fee5108b68a2f24617ef7af009540a41 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:31:13 +0800 Subject: [PATCH] Add aten::special_scaled_modified_bessel_* and its variants (#1038) - [x] special_scaled_modified_bessel_k0 - [x] special_scaled_modified_bessel_k0.out - [x] special_scaled_modified_bessel_k1 - [x] special_scaled_modified_bessel_k1.out - [x] special_xlog1py - [x] special_xlog1py.out - [x] special_zeta - [x] special_zeta.out - [x] special_entr - [x] special_entr.out - [x] special_erfcx - [x] special_erfcx.out Co-authored-by: Yutao Xu --- src/ATen/native/xpu/Bessel.cpp | 33 +++++-- src/ATen/native/xpu/BinaryOps.cpp | 7 +- src/ATen/native/xpu/UnaryOps.cpp | 3 + src/ATen/native/xpu/XPUFallback.template | 6 -- src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp | 2 +- src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp | 4 +- src/ATen/native/xpu/sycl/BesselY0Kernel.cpp | 2 +- src/ATen/native/xpu/sycl/BesselY1Kernel.cpp | 4 +- .../native/xpu/sycl/BinaryMiscOpsKernels.cpp | 22 +++++ .../native/xpu/sycl/BinaryMiscOpsKernels.h | 2 + .../xpu/sycl/ScaledModifiedBesselK0Kernel.cpp | 25 +++++ .../xpu/sycl/ScaledModifiedBesselK0Kernel.h | 9 ++ .../xpu/sycl/ScaledModifiedBesselK1Kernel.cpp | 25 +++++ .../xpu/sycl/ScaledModifiedBesselK1Kernel.h | 9 ++ .../xpu/sycl/UnarySpecialOpsKernels.cpp | 36 +++++++ .../native/xpu/sycl/UnarySpecialOpsKernels.h | 4 + src/ATen/native/xpu/sycl/ZetaKernel.cpp | 26 ++++++ src/ATen/native/xpu/sycl/ZetaKernel.h | 9 ++ yaml/native/native_functions.yaml | 93 +++++++++++++++++++ 19 files changed, 300 insertions(+), 21 deletions(-) create mode 100644 src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.cpp create mode 100644 src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h create mode 100644 src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.cpp create mode 100644 src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h create mode 100644 src/ATen/native/xpu/sycl/ZetaKernel.cpp create mode 100644 src/ATen/native/xpu/sycl/ZetaKernel.h diff --git a/src/ATen/native/xpu/Bessel.cpp b/src/ATen/native/xpu/Bessel.cpp index 6bbaeef91..536cd91e4 100644 --- a/src/ATen/native/xpu/Bessel.cpp +++ b/src/ATen/native/xpu/Bessel.cpp @@ -1,7 +1,7 @@ -#include +#include #include #include -#include +#include #include #include #include @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include namespace at { @@ -18,10 +20,27 @@ REGISTER_XPU_DISPATCH(special_bessel_j0_stub, &xpu::bessel_j0_kernel); REGISTER_XPU_DISPATCH(special_bessel_j1_stub, &xpu::bessel_j1_kernel); REGISTER_XPU_DISPATCH(special_bessel_y0_stub, &xpu::bessel_y0_kernel); REGISTER_XPU_DISPATCH(special_bessel_y1_stub, &xpu::bessel_y1_kernel); -REGISTER_XPU_DISPATCH(special_modified_bessel_i0_stub, &xpu::modified_bessel_i0_kernel); -REGISTER_XPU_DISPATCH(special_modified_bessel_i1_stub, &xpu::modified_bessel_i1_kernel); -REGISTER_XPU_DISPATCH(special_modified_bessel_k0_stub, &xpu::modified_bessel_k0_kernel); -REGISTER_XPU_DISPATCH(special_modified_bessel_k1_stub, &xpu::modified_bessel_k1_kernel); -REGISTER_XPU_DISPATCH(special_spherical_bessel_j0_stub, &xpu::spherical_bessel_j0_kernel); +REGISTER_XPU_DISPATCH( + special_modified_bessel_i0_stub, + &xpu::modified_bessel_i0_kernel); +REGISTER_XPU_DISPATCH( + special_modified_bessel_i1_stub, + &xpu::modified_bessel_i1_kernel); +REGISTER_XPU_DISPATCH( + special_modified_bessel_k0_stub, + &xpu::modified_bessel_k0_kernel); +REGISTER_XPU_DISPATCH( + special_modified_bessel_k1_stub, + &xpu::modified_bessel_k1_kernel); +REGISTER_XPU_DISPATCH( + special_spherical_bessel_j0_stub, + &xpu::spherical_bessel_j0_kernel); +REGISTER_XPU_DISPATCH( + special_scaled_modified_bessel_k0_stub, + &xpu::scaled_modified_bessel_k0_kernel); +REGISTER_XPU_DISPATCH( + special_scaled_modified_bessel_k1_stub, + &xpu::scaled_modified_bessel_k1_kernel); + } // namespace native } // namespace at diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp index 030526f05..d9a1dc2a4 100644 --- a/src/ATen/native/xpu/BinaryOps.cpp +++ b/src/ATen/native/xpu/BinaryOps.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -23,9 +24,9 @@ #include #include #include -#include -#include #include +#include +#include namespace at { namespace native { @@ -65,6 +66,8 @@ REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel); REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel); REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel); REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel); +REGISTER_XPU_DISPATCH(xlog1py_stub, &xpu::xlog1py_kernel); +REGISTER_XPU_DISPATCH(zeta_stub, &xpu::zeta_kernel); REGISTER_XPU_DISPATCH( hermite_polynomial_h_stub, &xpu::hermite_polynomial_h_kernel); diff --git a/src/ATen/native/xpu/UnaryOps.cpp b/src/ATen/native/xpu/UnaryOps.cpp index 2be0fd33c..680fc0dbc 100644 --- a/src/ATen/native/xpu/UnaryOps.cpp +++ b/src/ATen/native/xpu/UnaryOps.cpp @@ -61,6 +61,7 @@ REGISTER_XPU_DISPATCH(acos_stub, &xpu::acos_kernel); REGISTER_XPU_DISPATCH(acosh_stub, &xpu::acosh_kernel); REGISTER_XPU_DISPATCH(erf_stub, &xpu::erf_kernel); REGISTER_XPU_DISPATCH(erfc_stub, &xpu::erfc_kernel); + REGISTER_XPU_DISPATCH(erfinv_stub, &xpu::erfinv_kernel); REGISTER_XPU_DISPATCH(exp2_stub, &xpu::exp2_kernel); REGISTER_XPU_DISPATCH(expm1_stub, &xpu::expm1_kernel); @@ -86,6 +87,8 @@ REGISTER_XPU_DISPATCH(special_i1_stub, &xpu::i1_kernel); REGISTER_XPU_DISPATCH(special_i1e_stub, &xpu::i1e_kernel); REGISTER_XPU_DISPATCH(special_ndtri_stub, &xpu::ndtri_kernel); REGISTER_XPU_DISPATCH(special_log_ndtr_stub, &xpu::log_ndtr_kernel); +REGISTER_XPU_DISPATCH(special_erfcx_stub, &xpu::erfcx_kernel); +REGISTER_XPU_DISPATCH(special_entr_stub, &xpu::entr_kernel); } // namespace native } // namespace at diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 65f3b571c..9cf344188 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -201,12 +201,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "_segment_reduce_backward", "sinc.out", "special_airy_ai.out", - "special_entr.out", - "special_erfcx.out", - "special_scaled_modified_bessel_k0.out", - "special_scaled_modified_bessel_k1.out", - "special_xlog1py.out", - "special_zeta.out", "_thnn_fused_gru_cell", "_to_sparse", "_to_sparse_csr", diff --git a/src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp b/src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp index 40e5ec602..7ca99812b 100644 --- a/src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp +++ b/src/ATen/native/xpu/sycl/BesselJ0Kernel.cpp @@ -21,4 +21,4 @@ void bessel_j0_kernel(TensorIteratorBase& iter) { }); } -} +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp b/src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp index 2c2066e0b..c4cb1203c 100644 --- a/src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp +++ b/src/ATen/native/xpu/sycl/BesselJ1Kernel.cpp @@ -1,7 +1,7 @@ #include -#include #include #include +#include #include #include @@ -24,4 +24,4 @@ void bessel_j1_kernel(TensorIteratorBase& iter) { }); } -} +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BesselY0Kernel.cpp b/src/ATen/native/xpu/sycl/BesselY0Kernel.cpp index a4e1a43f7..feb1ac9db 100644 --- a/src/ATen/native/xpu/sycl/BesselY0Kernel.cpp +++ b/src/ATen/native/xpu/sycl/BesselY0Kernel.cpp @@ -21,4 +21,4 @@ void bessel_y0_kernel(TensorIteratorBase& iter) { }); } -} +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BesselY1Kernel.cpp b/src/ATen/native/xpu/sycl/BesselY1Kernel.cpp index 879ab876c..ded21d104 100644 --- a/src/ATen/native/xpu/sycl/BesselY1Kernel.cpp +++ b/src/ATen/native/xpu/sycl/BesselY1Kernel.cpp @@ -1,7 +1,7 @@ #include -#include #include #include +#include #include #include @@ -21,4 +21,4 @@ void bessel_y1_kernel(TensorIteratorBase& iter) { }); } -} +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp index ba1dfdcd8..334d0172c 100644 --- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp @@ -94,4 +94,26 @@ void xlogy_kernel(TensorIteratorBase& iter) { [&]() { gpu_kernel_with_scalars(iter, XlogyFunctor()); }); } +template +struct Xlog1pyFunctor { + scalar_t operator()(scalar_t x, scalar_t y) const { + if (at::_isnan(y)) { + return NAN; + } + if (x == 0) { + return 0; + } + return x * std::log1p(y); + } +}; + +void xlog1py_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + iter.common_dtype(), + "xlog1py_xpu", + [&]() { gpu_kernel_with_scalars(iter, Xlog1pyFunctor()); }); +} + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h index bc2f10715..e09f94bce 100644 --- a/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h +++ b/src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h @@ -12,4 +12,6 @@ TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta); TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter); +TORCH_XPU_API void xlog1py_kernel(TensorIteratorBase& iter); + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.cpp b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.cpp new file mode 100644 index 000000000..98fca9e7a --- /dev/null +++ b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.cpp @@ -0,0 +1,25 @@ +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct ScaledModifiedBesselK0Functor { + scalar_t operator()(scalar_t a) const { + return scaled_modified_bessel_k0_forward(a); + } +}; + +void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES( + iter.common_dtype(), "scaled_modified_bessel_k0_xpu", [&]() { + gpu_kernel(iter, ScaledModifiedBesselK0Functor()); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h new file mode 100644 index 000000000..a16d9f3cd --- /dev/null +++ b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK0Kernel.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API void scaled_modified_bessel_k0_kernel(TensorIteratorBase& iter); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.cpp b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.cpp new file mode 100644 index 000000000..ec16bc9a2 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.cpp @@ -0,0 +1,25 @@ +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct ScaledModifiedBesselK1Functor { + scalar_t operator()(scalar_t a) const { + return scaled_modified_bessel_k1_forward(a); + } +}; + +void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES( + iter.common_dtype(), "scaled_modified_bessel_k1_xpu", [&]() { + gpu_kernel(iter, ScaledModifiedBesselK1Functor()); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h new file mode 100644 index 000000000..051751180 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ScaledModifiedBesselK1Kernel.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API void scaled_modified_bessel_k1_kernel(TensorIteratorBase& iter); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp index 71300b743..ca1bb3890 100644 --- a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp @@ -270,4 +270,40 @@ void log_ndtr_kernel(TensorIteratorBase& iter) { }); } +template +struct EntrFunctor { + scalar_t operator()(scalar_t x) const { + if (at::_isnan(x)) { + return x; + } else if (x > 0) { + return -x * std::log(x); + } else if (x == 0) { + return 0; + } + return static_cast(-std::numeric_limits::infinity()); + } +}; + +void entr_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + iter.common_dtype(), + "entr_xpu", + [&]() { gpu_kernel(iter, EntrFunctor()); }); +} + +template +struct ErfcxFunctor { + scalar_t operator()(scalar_t a) const { + return calc_erfcx(a); + } +}; + +void erfcx_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_xpu", [&]() { + gpu_kernel(iter, ErfcxFunctor()); + }); +} + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h index c85d47411..a0dc34eb0 100644 --- a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h +++ b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h @@ -30,4 +30,8 @@ TORCH_XPU_API void ndtri_kernel(TensorIteratorBase& iter); TORCH_XPU_API void log_ndtr_kernel(TensorIteratorBase& iter); +TORCH_XPU_API void entr_kernel(TensorIteratorBase& iter); + +TORCH_XPU_API void erfcx_kernel(TensorIteratorBase& iter); + } // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ZetaKernel.cpp b/src/ATen/native/xpu/sycl/ZetaKernel.cpp new file mode 100644 index 000000000..94b8c7dca --- /dev/null +++ b/src/ATen/native/xpu/sycl/ZetaKernel.cpp @@ -0,0 +1,26 @@ +#include +#include +#include +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct ZetaFunctor { + scalar_t operator()(scalar_t x, scalar_t q) const { + return zeta(x, q); + } +}; + +constexpr char zeta_name[] = "zeta"; +void zeta_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_xpu", [&]() { + gpu_kernel_with_scalars(iter, ZetaFunctor()); + }); +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/ZetaKernel.h b/src/ATen/native/xpu/sycl/ZetaKernel.h new file mode 100644 index 000000000..fdb4e3118 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ZetaKernel.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace at::native::xpu { + +TORCH_XPU_API void zeta_kernel(TensorIteratorBase& iter); + +} // namespace at::native::xpu \ No newline at end of file diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 3a1ba30b9..c55f3c904 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -6865,6 +6865,99 @@ variants: function tags: pointwise +- func: special_scaled_modified_bessel_k0(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k0.out + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + XPU: special_scaled_modified_bessel_k0_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k1(Tensor x) -> Tensor + python_module: special + structured_delegate: special_scaled_modified_bessel_k1.out + variants: function + tags: pointwise + +- func: special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + XPU: special_scaled_modified_bessel_k1_out + python_module: special + structured_inherits: TensorIteratorBase + structured: True + variants: function + tags: pointwise + +- func: special_xlog1py(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + structured_delegate: special_xlog1py.out + tags: pointwise + +- func: special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + XPU: special_xlog1py_out + tags: pointwise + +- func: special_zeta(Tensor self, Tensor other) -> Tensor + device_check: NoCheck # TensorIterator + python_module: special + variants: function + structured_delegate: special_zeta.out + tags: pointwise + +- func: special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + device_check: NoCheck # TensorIterator + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + XPU: special_zeta_out + tags: pointwise + +- func: special_entr(Tensor self) -> Tensor + structured_delegate: special_entr.out + python_module: special + variants: function + tags: pointwise + +- func: special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + python_module: special + variants: function + dispatch: + XPU: special_entr_out + tags: pointwise + +- func: special_erfcx(Tensor self) -> Tensor + python_module: special + variants: function + structured_delegate: special_erfcx.out + tags: pointwise + +- func: special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + structured: True + structured_inherits: TensorIteratorBase + dispatch: + XPU: special_erfcx_out + tags: pointwise + - func: poisson(Tensor self, Generator? generator=None) -> Tensor device_check: NoCheck # TensorIterator dispatch: