Skip to content

Commit 9d5ed2e

Browse files
authored
Add aten::trunc, aten::xlogy and thieir variants (#697)
1 parent 459f92c commit 9d5ed2e

File tree

10 files changed

+115
-5
lines changed

10 files changed

+115
-5
lines changed

src/ATen/native/xpu/BinaryOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <ATen/native/xpu/sycl/BinaryKernels.h>
1212
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
1313
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
14+
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
1415
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
1516
#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
1617
#include <ATen/native/xpu/sycl/CopysignKernel.h>
@@ -51,6 +52,7 @@ REGISTER_XPU_DISPATCH(fmax_stub, &xpu::fmax_kernel);
5152
REGISTER_XPU_DISPATCH(fmin_stub, &xpu::fmin_kernel);
5253
REGISTER_XPU_DISPATCH(lshift_stub, &xpu::lshift_kernel);
5354
REGISTER_XPU_DISPATCH(rshift_stub, &xpu::rshift_kernel);
55+
REGISTER_XPU_DISPATCH(xlogy_stub, &xpu::xlogy_kernel);
5456

5557
TORCH_IMPL_FUNC(add_out_xpu)
5658
(const Tensor& self,

src/ATen/native/xpu/ReduceOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,8 @@ void aminmax_impl(
298298
Tensor& min,
299299
Tensor& max) {
300300
auto dtype = self.scalar_type();
301-
TensorIterator iter = make_reduction(
302-
"aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
301+
TensorIterator iter =
302+
make_reduction("aminmax_xpu", min, max, self, dim_opt, keepdim, dtype);
303303
if (iter.numel() != 0) {
304304
native::xpu::aminmax_kernel(iter);
305305
}

src/ATen/native/xpu/UnaryOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,7 @@ REGISTER_XPU_DISPATCH(nan_to_num_stub, &xpu::nan_to_num_kernel);
7878
REGISTER_XPU_DISPATCH(round_stub, &xpu::round_kernel);
7979
REGISTER_XPU_DISPATCH(round_decimals_stub, &xpu::round_decimals_kernel);
8080
REGISTER_XPU_DISPATCH(floor_stub, &xpu::floor_kernel);
81+
REGISTER_XPU_DISPATCH(trunc_stub, &xpu::trunc_kernel);
82+
8183
} // namespace native
8284
} // namespace at

src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
281281
"triangular_solve.X",
282282
"tril_indices",
283283
"triu_indices",
284-
"trunc.out",
285284
"upsample_bicubic2d_backward.grad_input",
286285
"_upsample_bilinear2d_aa.out",
287286
"upsample_nearest3d.out",
@@ -292,7 +291,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
292291
"upsample_trilinear3d.out",
293292
"_validate_compressed_sparse_indices",
294293
"vdot",
295-
"xlogy.OutTensor",
296294
"_upsample_bicubic2d_aa.out",
297295
};
298296
for (auto& op_name : fallback_list) {

src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
#include <ATen/native/TensorIterator.h>
33
#include <comm/xpu_aten.h>
44

5+
#include <ATen/NumericUtils.h>
56
#include <ATen/native/xpu/sycl/Loops.h>
67

78
#include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
89

910
namespace at::native::xpu {
10-
1111
template <typename scalar_t>
1212
struct MSEFunctor {
1313
scalar_t operator()(scalar_t a, scalar_t b) const {
@@ -72,4 +72,26 @@ void huber_kernel(TensorIterator& iter, double delta) {
7272
});
7373
}
7474

75+
template <typename scalar_t>
76+
struct XlogyFunctor {
77+
scalar_t operator()(scalar_t x, scalar_t y) const {
78+
if (at::_isnan(y)) {
79+
return NAN;
80+
}
81+
if (x == 0) {
82+
return 0;
83+
}
84+
return x * std::log(y);
85+
}
86+
};
87+
88+
void xlogy_kernel(TensorIteratorBase& iter) {
89+
AT_DISPATCH_FLOATING_TYPES_AND2(
90+
at::ScalarType::Half,
91+
at::ScalarType::BFloat16,
92+
iter.common_dtype(),
93+
"xlogy_xpu",
94+
[&]() { gpu_kernel_with_scalars(iter, XlogyFunctor<scalar_t>()); });
95+
}
96+
7597
} // namespace at::native::xpu

src/ATen/native/xpu/sycl/BinaryMiscOpsKernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ TORCH_XPU_API void smooth_l1_kernel(TensorIteratorBase& iter, double beta);
1010

1111
TORCH_XPU_API void huber_kernel(TensorIterator& iter, double delta);
1212

13+
TORCH_XPU_API void xlogy_kernel(TensorIteratorBase& iter);
14+
1315
} // namespace at::native::xpu

src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,4 +180,41 @@ void floor_kernel(TensorIteratorBase& iter) {
180180
});
181181
}
182182

183+
// We manually overload trunc because std::trunc does not work with std::complex
184+
// types and ROCm.
185+
template <typename scalar_t>
186+
inline scalar_t trunc_wrapper(scalar_t a) {
187+
return static_cast<scalar_t>(std::truncf(static_cast<float>(a)));
188+
}
189+
190+
inline double trunc_wrapper(double a) {
191+
return std::trunc(a);
192+
}
193+
194+
inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
195+
return c10::complex<float>(
196+
std::truncf(static_cast<float>(a.real())),
197+
std::truncf(static_cast<float>(a.imag())));
198+
}
199+
200+
inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
201+
return c10::complex<double>(
202+
std::trunc(static_cast<double>(a.real())),
203+
std::trunc(static_cast<double>(a.imag())));
204+
}
205+
206+
template <typename scalar_t>
207+
struct TruncFunctor {
208+
scalar_t operator()(scalar_t a) const {
209+
return trunc_wrapper(a);
210+
}
211+
};
212+
213+
void trunc_kernel(TensorIteratorBase& iter) {
214+
AT_DISPATCH_FLOATING_TYPES_AND2(
215+
ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "trunc_xpu", [&]() {
216+
gpu_kernel(iter, TruncFunctor<scalar_t>());
217+
});
218+
}
219+
183220
} // namespace at::native::xpu

src/ATen/native/xpu/sycl/UnaryFractionKernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ TORCH_XPU_API void round_decimals_kernel(
1818

1919
TORCH_XPU_API void frac_kernel(TensorIteratorBase& iter);
2020

21+
TORCH_XPU_API void trunc_kernel(TensorIteratorBase& iter);
22+
2123
} // namespace at::native::xpu

test/xpu/xpu_test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,8 @@
199199
"sign",
200200
"signbit",
201201
"round",
202+
"trunc",
203+
"xlogy",
202204
"nn.functional.embedding_bag",
203205
"bucketize",
204206
"searchsorted",

yaml/native/native_functions.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4435,6 +4435,29 @@
44354435
XPU: logit_out
44364436
tags: pointwise
44374437

4438+
- func: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
4439+
device_check: NoCheck # TensorIterator
4440+
structured_delegate: xlogy.OutTensor
4441+
variants: function, method
4442+
tags: pointwise
4443+
4444+
# xlogy: inplace variant
4445+
- func: xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
4446+
device_check: NoCheck # TensorIterator
4447+
variants: function, method
4448+
structured_delegate: xlogy.OutTensor
4449+
tags: pointwise
4450+
4451+
# xlogy: out variant
4452+
- func: xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
4453+
device_check: NoCheck # TensorIterator
4454+
structured: True
4455+
structured_inherits: TensorIteratorBase
4456+
variants: function
4457+
dispatch:
4458+
XPU: xlogy_out
4459+
tags: pointwise
4460+
44384461
- func: erfinv(Tensor self) -> Tensor
44394462
device_check: NoCheck # TensorIterator
44404463
structured_delegate: erfinv.out
@@ -4598,6 +4621,26 @@
45984621
XPU: floor_out
45994622
tags: pointwise
46004623

4624+
- func: trunc(Tensor self) -> Tensor
4625+
structured_delegate: trunc.out
4626+
device_check: NoCheck # TensorIterator
4627+
variants: function, method
4628+
tags: [core, pointwise]
4629+
4630+
- func: trunc_(Tensor(a!) self) -> Tensor(a!)
4631+
structured_delegate: trunc.out
4632+
device_check: NoCheck # TensorIterator
4633+
variants: function, method
4634+
tags: pointwise
4635+
4636+
- func: trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
4637+
structured: True
4638+
structured_inherits: TensorIteratorBase
4639+
device_check: NoCheck # TensorIterator
4640+
dispatch:
4641+
XPU: trunc_out
4642+
tags: pointwise
4643+
46014644
- func: replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)
46024645
python_module: nn
46034646
structured: True

0 commit comments

Comments
 (0)