From 43dfdbbbccd6a95739c00873a1db87fab0b3863c Mon Sep 17 00:00:00 2001 From: Kanya-Mo <167922169+Kanya-Mo@users.noreply.github.com> Date: Wed, 30 Oct 2024 01:50:37 -0700 Subject: [PATCH] Add aten::i0 and its variants. (#1026) - [x] i0.out - [x] i0 - [x] i0_ --- src/ATen/native/xpu/UnaryOps.cpp | 1 + src/ATen/native/xpu/XPUFallback.template | 1 - .../native/xpu/sycl/UnarySpecialOpsKernels.cpp | 17 +++++++++++++++++ .../native/xpu/sycl/UnarySpecialOpsKernels.h | 2 ++ test/xpu/xpu_test_utils.py | 1 + yaml/native/native_functions.yaml | 17 +++++++++++++++++ 6 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/UnaryOps.cpp b/src/ATen/native/xpu/UnaryOps.cpp index bae3c3c39..2be0fd33c 100644 --- a/src/ATen/native/xpu/UnaryOps.cpp +++ b/src/ATen/native/xpu/UnaryOps.cpp @@ -80,6 +80,7 @@ REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel); REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel); REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel); REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel); +REGISTER_XPU_DISPATCH(i0_stub, &xpu::i0_kernel); REGISTER_XPU_DISPATCH(special_i0e_stub, &xpu::i0e_kernel); REGISTER_XPU_DISPATCH(special_i1_stub, &xpu::i1_kernel); REGISTER_XPU_DISPATCH(special_i1e_stub, &xpu::i1e_kernel); diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 68af798e2..65fa5d667 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -171,7 +171,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "frexp.Tensor_out", "_fused_moving_avg_obs_fq_helper", "geqrf", - "i0.out", "igammac.out", "igamma.out", "index_reduce.out", diff --git a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp index 6889e13c0..71300b743 100644 --- a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp +++ b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp @@ -176,6 +176,23 @@ void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) { }); } +template +struct I0Functor { + scalar_t operator()(scalar_t a) const { + using opmath_t = at::opmath_type; + return calc_i0(a); + } +}; + +void i0_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + ScalarType::Half, + ScalarType::BFloat16, + iter.common_dtype(), + "i0_xpu", + [&]() { gpu_kernel(iter, I0Functor()); }); +} + template struct I0eFunctor { scalar_t operator()(scalar_t a) const { diff --git a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h index 16518cf2d..c85d47411 100644 --- a/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h +++ b/src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h @@ -18,6 +18,8 @@ TORCH_XPU_API void logit_kernel( TensorIteratorBase& iter, const Scalar& eps_scalar); +TORCH_XPU_API void i0_kernel(TensorIteratorBase& iter); + TORCH_XPU_API void i0e_kernel(TensorIteratorBase& iter); TORCH_XPU_API void i1_kernel(TensorIteratorBase& iter); diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index 7a2ddb75f..95807265d 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -82,6 +82,7 @@ "hardswish", "nn.functional.hardshrink", "nn.functional.mish", + "i0", "index_add", "index_fill", "index_put", diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index 1283b9f24..ef036d5f8 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -6865,6 +6865,23 @@ - func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor variants: function, method +- func: i0(Tensor self) -> Tensor + structured_delegate: i0.out + variants: function, method + tags: pointwise + +- func: i0_(Tensor(a!) self) -> Tensor(a!) + structured_delegate: i0.out + variants: function, method + tags: pointwise + +- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + structured: True + structured_inherits: TensorIteratorBase + dispatch: + XPU: i0_out + tags: pointwise + - func: special_i0e(Tensor self) -> Tensor python_module: special variants: function