Skip to content

Commit

Permalink
Add aten::trunc, aten::xlogy and their variants (#697)
Browse files Browse the repository at this point in the history
  • Loading branch information
yucai-intel authored Sep 30, 2024
1 parent 459f92c commit 9d5ed2e
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <ATen/native/xpu/sycl/BinaryKernels.h>
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
Expand Down Expand Up @@ -51,6 +52,7 @@ REGISTER_XPU_DISPATCH(fmax_stub, &xpu::fmax_kernel);
REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel);
REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel);
REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel);
REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel);

TORCH_IMPL_FUNC(add_out_xpu)
(const Tensor& self,
Expand Down
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/ReduceOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,8 @@ void aminmax_impl(
Tensor& min,
Tensor& max) {
auto dtype = self.scalar_type();
TensorIterator iter = make_reduction(
"aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
TensorIterator iter =
make_reduction("aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
if (iter.numel() != 0) {
native::xpu::aminmax_kernel(iter);
}
Expand Down
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,7 @@ REGISTER_XPU_DISPATCH(nan_to_num_stub, &xpu::nan_to_num_kernel);
REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel);
REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel);
REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel);
REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel);

} // namespace native
} // namespace at
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"triangular_solve.X",
"tril_indices",
"triu_indices",
"trunc.out",
"upsample_bicubic2d_backward.grad_input",
"_upsample_bilinear2d_aa.out",
"upsample_nearest3d.out",
Expand All @@ -292,7 +291,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"upsample_trilinear3d.out",
"_validate_compressed_sparse_indices",
"vdot",
"xlogy.OutTensor",
"_upsample_bicubic2d_aa.out",
};
for (auto& op_name : fallback_list) {
Expand Down
24 changes: 23 additions & 1 deletion src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#include <ATen/native/TensorIterator.h>
#include <comm/xpu_aten.h>

#include <ATen/NumericUtils.h>
#include <ATen/native/xpu/sycl/Loops.h>

#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>

namespace at::native::xpu {

template <typename scalar_t>
struct MSEFunctor {
scalar_t operator()(scalar_t a, scalar_t b) const {
Expand Down Expand Up @@ -72,4 +72,26 @@ void huber_kernel(TensorIterator& iter, double delta) {
});
}

// Computes x * log(y) with the xlogy conventions:
//   - a NaN in y propagates to the result, even when x == 0;
//   - xlogy(0, y) == 0 for any non-NaN y (including y == 0).
template <typename scalar_t>
struct XlogyFunctor {
  scalar_t operator()(scalar_t x, scalar_t y) const {
    if (at::_isnan(y)) {
      return NAN;
    }
    return x == 0 ? static_cast<scalar_t>(0)
                  : static_cast<scalar_t>(x * std::log(y));
  }
};

// Dispatches the element-wise xlogy kernel over floating dtypes, including
// Half and BFloat16, using the iterator's common dtype for type promotion.
// gpu_kernel_with_scalars presumably also covers the Tensor-Scalar overloads
// without materializing a scalar tensor — TODO confirm against Loops.h.
void xlogy_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "xlogy_xpu",
      [&]() { gpu_kernel_with_scalars(iter, XlogyFunctor<scalar_t>()); });
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ TORCH_XPU_API void smooth_l1_kernel(TensorIteratorBase& iter, double beta);

TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta);

TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
37 changes: 37 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,41 @@ void floor_kernel(TensorIteratorBase& iter) {
});
}

// We manually overload trunc because std::trunc does not work with
// std::complex types and ROCm.
template <typename scalar_t>
inline scalar_t trunc_wrapper(scalar_t a) {
  // Round toward zero in float precision; this path serves the
  // reduced-precision floating types (Half/BFloat16), which have no
  // std::trunc overload of their own.
  const float truncated = std::truncf(static_cast<float>(a));
  return static_cast<scalar_t>(truncated);
}

// double overload: truncate in full double precision, avoiding the
// float round-trip of the generic template.
inline double trunc_wrapper(double a) {
  return std::trunc(a);
}

// Complex truncation applies trunc independently to the real and
// imaginary components. (a.real()/a.imag() already return float, so no
// extra cast is needed.)
inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
  const float re = std::truncf(a.real());
  const float im = std::truncf(a.imag());
  return c10::complex<float>(re, im);
}

// Double-precision complex variant: component-wise std::trunc.
inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
  const double re = std::trunc(a.real());
  const double im = std::trunc(a.imag());
  return c10::complex<double>(re, im);
}

// Element-wise functor: rounds each value toward zero by forwarding to
// the appropriate trunc_wrapper overload for scalar_t.
template <typename scalar_t>
struct TruncFunctor {
  scalar_t operator()(scalar_t a) const {
    return trunc_wrapper(a);
  }
};

// Dispatches the unary trunc kernel over floating dtypes, including Half
// and BFloat16.
// NOTE(review): the c10::complex trunc_wrapper overloads above are not
// reachable from this dispatch set (no complex dtypes dispatched) —
// confirm whether complex support is intended here as it is on CUDA.
void trunc_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "trunc_xpu", [&]() {
        gpu_kernel(iter, TruncFunctor<scalar_t>());
      });
}

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ TORCH_XPU_API void round_decimals_kernel(

TORCH_XPU_API void frac_kernel(TensorIteratorBase& iter);

TORCH_XPU_API void trunc_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
2 changes: 2 additions & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,8 @@
"sign",
"signbit",
"round",
"trunc",
"xlogy",
"nn.functional.embedding_bag",
"bucketize",
"searchsorted",
Expand Down
43 changes: 43 additions & 0 deletions yaml/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4435,6 +4435,29 @@
XPU: logit_out
tags: pointwise

# xlogy: out-of-place variant (function and method)
- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
  device_check: NoCheck # TensorIterator
  structured_delegate: xlogy.OutTensor
  variants: function, method
  tags: pointwise

# xlogy: inplace variant
- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  variants: function, method
  structured_delegate: xlogy.OutTensor
  tags: pointwise

# xlogy: out variant (structured entry point; XPU dispatches to xlogy_out)
- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck # TensorIterator
  structured: True
  structured_inherits: TensorIteratorBase
  variants: function
  dispatch:
    XPU: xlogy_out
  tags: pointwise
- func: erfinv(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
structured_delegate: erfinv.out
Expand Down Expand Up @@ -4598,6 +4621,26 @@
XPU: floor_out
tags: pointwise

# trunc: out-of-place variant (function and method)
- func: trunc(Tensor self) -> Tensor
  structured_delegate: trunc.out
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: [core, pointwise]

# trunc: inplace variant
- func: trunc_(Tensor(a!) self) -> Tensor(a!)
  structured_delegate: trunc.out
  device_check: NoCheck # TensorIterator
  variants: function, method
  tags: pointwise

# trunc: out variant (structured entry point; XPU dispatches to trunc_out)
- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  structured_inherits: TensorIteratorBase
  device_check: NoCheck # TensorIterator
  dispatch:
    XPU: trunc_out
  tags: pointwise

- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn
structured: True
Expand Down

0 comments on commit 9d5ed2e

Please sign in to comment.