Skip to content

Commit

Permalink
add nan_to_num.out
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel committed Jul 11, 2024
1 parent 0253fb9 commit 2dfffba
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 2 deletions.
24 changes: 24 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,4 +564,28 @@ Tensor& XPUNativeFunctions::ceil_out(const Tensor& self, Tensor& out) {
return out;
}

Tensor& XPUNativeFunctions::nan_to_num_out(
    const Tensor& self,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf,
    Tensor& result) {
  // Out-variant of nan_to_num: writes `self` into `result` with NaN, +inf and
  // -inf replaced by the given substitutes (or dtype-dependent defaults
  // chosen by the kernel when an optional is empty).
  TORCH_CHECK(
      self.scalar_type() == result.scalar_type(),
      "nan_to_num: dtype of out: ",
      result.scalar_type(),
      " should be same as input: ",
      self.scalar_type());

  const bool integral_input =
      c10::isIntegralType(self.scalar_type(), /*includeBool=*/true);
  if (!integral_input) {
    // Floating/complex path: run the elementwise XPU kernel.
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::nan_to_num_kernel(iter, nan, pos_inf, neg_inf);
    return result;
  }

  // Integral tensors cannot hold NaN or inf, so the op reduces to a copy.
  at::native::resize_output(result, self.sizes());
  result.copy_(self);
  return result;
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"nanmedian",
"nanmedian.dim_values",
"nansum",
"nan_to_num.out",
"nextafter.out",
"norm.out",
"ormqr",
Expand Down
97 changes: 96 additions & 1 deletion src/ATen/native/xpu/sycl/UnaryKernels.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#include <ATen/ATen.h>

#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <c10/core/ScalarType.h>
Expand Down Expand Up @@ -129,4 +129,99 @@ void exp_kernel(TensorIteratorBase& iter) {
}
}

// Maps NaN and the two infinities to their replacement values; every other
// value (finite numbers included) passes through unchanged.
template <typename scalar_t>
static inline scalar_t _nan_to_num_replace(
    scalar_t a,
    scalar_t nan_replacement,
    scalar_t pos_inf_replacement,
    scalar_t neg_inf_replacement) {
  if (at::_isnan(a)) {
    return nan_replacement;
  }
  if (a == std::numeric_limits<scalar_t>::infinity()) {
    return pos_inf_replacement;
  }
  if (a == -std::numeric_limits<scalar_t>::infinity()) {
    return neg_inf_replacement;
  }
  return a;
}

// Elementwise functor for complex dtypes: applies the NaN/inf replacement
// independently to the real and imaginary components. `value_t` is the
// underlying real component type of the complex `scalar_t`.
template <typename scalar_t, typename value_t>
struct NanToNumComplexFunctor {
  scalar_t operator()(scalar_t a) const {
    const value_t real_part = _nan_to_num_replace(
        a.real(), nan_replacement_, pos_inf_replacement_, neg_inf_replacement_);
    const value_t imag_part = _nan_to_num_replace(
        a.imag(), nan_replacement_, pos_inf_replacement_, neg_inf_replacement_);
    return scalar_t(real_part, imag_part);
  }
  NanToNumComplexFunctor(
      value_t nan_replacement,
      value_t pos_inf_replacement,
      value_t neg_inf_replacement)
      : nan_replacement_(nan_replacement),
        pos_inf_replacement_(pos_inf_replacement),
        neg_inf_replacement_(neg_inf_replacement) {}

 private:
  value_t nan_replacement_;
  value_t pos_inf_replacement_;
  value_t neg_inf_replacement_;
};

// Elementwise functor for real (non-complex) floating dtypes: forwards each
// value through the shared NaN/inf replacement helper.
template <typename scalar_t>
struct NanToNumFunctor {
  scalar_t operator()(scalar_t a) const {
    return _nan_to_num_replace(
        a, nan_replacement_, pos_inf_replacement_, neg_inf_replacement_);
  }
  NanToNumFunctor(
      scalar_t nan_replacement,
      scalar_t pos_inf_replacement,
      scalar_t neg_inf_replacement)
      : nan_replacement_(nan_replacement),
        pos_inf_replacement_(pos_inf_replacement),
        neg_inf_replacement_(neg_inf_replacement) {}

 private:
  scalar_t nan_replacement_;
  scalar_t pos_inf_replacement_;
  scalar_t neg_inf_replacement_;
};

void nan_to_num_kernel(
    TensorIteratorBase& iter,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  // Dispatches the elementwise NaN/inf replacement over the iterator's dtype.
  // Empty optionals fall back to: NaN -> 0, +inf -> dtype max,
  // -inf -> dtype lowest (matching the semantics of torch.nan_to_num).
  if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "nan_to_num_xpu", [&]() {
      // Replacements are computed in the real component type of the complex.
      using value_t = scalar_t::value_type;
      const value_t nan_v = static_cast<value_t>(nan.value_or(0.));
      const value_t pos_v = pos_inf.has_value()
          ? static_cast<value_t>(pos_inf.value())
          : std::numeric_limits<value_t>::max();
      const value_t neg_v = neg_inf.has_value()
          ? static_cast<value_t>(neg_inf.value())
          : std::numeric_limits<value_t>::lowest();
      gpu_kernel(
          iter, NanToNumComplexFunctor<scalar_t, value_t>(nan_v, pos_v, neg_v));
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, iter.dtype(), "nan_to_num_xpu", [&]() {
          const scalar_t nan_v = static_cast<scalar_t>(nan.value_or(0.));
          const scalar_t pos_v = pos_inf.has_value()
              ? static_cast<scalar_t>(pos_inf.value())
              : std::numeric_limits<scalar_t>::max();
          const scalar_t neg_v = neg_inf.has_value()
              ? static_cast<scalar_t>(neg_inf.value())
              : std::numeric_limits<scalar_t>::lowest();
          gpu_kernel(iter, NanToNumFunctor<scalar_t>(nan_v, pos_v, neg_v));
        });
  }
}

} // namespace at::native::xpu
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ void bitwise_not_kernel(TensorIteratorBase& iter);

void exp_kernel(TensorIteratorBase& iter);

// Elementwise nan_to_num for XPU: replaces NaN, +inf and -inf entries of the
// iterator's input with the given values; an empty optional selects the
// default (0 for NaN, dtype max for +inf, dtype lowest for -inf).
void nan_to_num_kernel(
TensorIteratorBase& iter,
std::optional<double> nan,
std::optional<double> pos_inf,
std::optional<double> neg_inf);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
"renorm",
"lerp",
"conj_physical",
"nan_to_num",
]


Expand Down
1 change: 1 addition & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,4 @@ supported:
- ceil
- ceil_
- ceil.out
- nan_to_num.out

0 comments on commit 2dfffba

Please sign in to comment.