Skip to content

Commit

Permalink
Add aten::i0 and its variants. (#1026)
Browse files Browse the repository at this point in the history
- [x] i0.out
- [x] i0
- [x] i0_
  • Loading branch information
Kanya-Mo authored Oct 30, 2024
1 parent 38969d9 commit 43dfdbb
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel);
REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel);
REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel);
REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel);
REGISTER_XPU_DISPATCH(i0_stub, &xpu::i0_kernel);
REGISTER_XPU_DISPATCH(special_i0e_stub, &xpu::i0e_kernel);
REGISTER_XPU_DISPATCH(special_i1_stub, &xpu::i1_kernel);
REGISTER_XPU_DISPATCH(special_i1e_stub, &xpu::i1e_kernel);
Expand Down
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"frexp.Tensor_out",
"_fused_moving_avg_obs_fq_helper",
"geqrf",
"i0.out",
"igammac.out",
"igamma.out",
"index_reduce.out",
Expand Down
17 changes: 17 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
});
}

template <typename scalar_t>
struct I0Functor {
scalar_t operator()(scalar_t a) const {
using opmath_t = at::opmath_type<scalar_t>;
return calc_i0<opmath_t>(a);
}
};

void i0_kernel(TensorIteratorBase& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(
ScalarType::Half,
ScalarType::BFloat16,
iter.common_dtype(),
"i0_xpu",
[&]() { gpu_kernel(iter, I0Functor<scalar_t>()); });
}

template <typename scalar_t>
struct I0eFunctor {
scalar_t operator()(scalar_t a) const {
Expand Down
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ TORCH_XPU_API void logit_kernel(
TensorIteratorBase& iter,
const Scalar& eps_scalar);

TORCH_XPU_API void i0_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void i0e_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void i1_kernel(TensorIteratorBase& iter);
Expand Down
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"hardswish",
"nn.functional.hardshrink",
"nn.functional.mish",
"i0",
"index_add",
"index_fill",
"index_put",
Expand Down
17 changes: 17 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6865,6 +6865,23 @@
- func: index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor
variants: function, method

- func: i0(Tensor self) -> Tensor
structured_delegate: i0.out
variants: function, method
tags: pointwise

- func: i0_(Tensor(a!) self) -> Tensor(a!)
structured_delegate: i0.out
variants: function, method
tags: pointwise

- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
dispatch:
XPU: i0_out
tags: pointwise

- func: special_i0e(Tensor self) -> Tensor
python_module: special
variants: function
Expand Down

0 comments on commit 43dfdbb

Please sign in to comment.