Skip to content

Commit

Permalink
erfinv: use existing implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
hjhee committed Jul 23, 2024
1 parent 1e61056 commit 2dcf59d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 78 deletions.
65 changes: 7 additions & 58 deletions src/ATen/native/xpu/sycl/UnarySpecialOpsKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,67 +79,16 @@ void erfc_kernel(TensorIteratorBase& iter) {

template <typename scalar_t>
struct ErfinvFunctor {
using opmath_type = at::opmath_type<scalar_t>;

scalar_t operator()(scalar_t in) const {
scalar_t out;
opmath_type z, num, dem;

auto x = static_cast<opmath_type>(in);
if (std::fabs(x) > 1.0f) {
out = static_cast<scalar_t>(NAN);
return out;
}
if (std::fabs(x) == 1.0f) {
out = static_cast<scalar_t>(
(std::copysign(1.0, static_cast<double>(x))) *
(std::numeric_limits<double>::infinity()));
return out;
}
if (std::fabs(x) <= 0.7f) {
z = x * x;
num = (((a_[3] * z + a_[2]) * z + a_[1]) * z + a_[0]);
dem =
((((b_[3] * z + b_[2]) * z + b_[1]) * z + b_[0]) * z +
static_cast<opmath_type>(1.0));
out = x * num / dem;
} else {
z = static_cast<opmath_type>(
std::sqrt(-std::log((1.0 - std::fabs(x)) / 2.0)));
num = ((c_[3] * z + c_[2]) * z + c_[1]) * z + c_[0];
dem = (d_[1] * z + d_[0]) * z + static_cast<opmath_type>(1.0);
out = static_cast<scalar_t>(
static_cast<opmath_type>(std::copysign(1.0, static_cast<double>(x))) *
num / dem);
}
out = out -
static_cast<scalar_t>(
(std::erf(static_cast<double>(out)) - x) /
((2.0 / std::sqrt(PI_f64_)) * std::exp(-x * x)));
out = out -
static_cast<scalar_t>(
(std::erf(static_cast<double>(out)) - x) /
((2.0 / std::sqrt(PI_f64_)) * std::exp(-x * x)));
return out;
return calc_erfinv(in);
}
};

static constexpr double PI_f64_ = 3.14159265358979323846;
static constexpr std::array<opmath_type, 4> a_ = {
0.886226899,
-1.645349621,
0.914624893,
-0.140543331};
static constexpr std::array<opmath_type, 4> b_ = {
-2.118377725,
1.442710462,
-0.329097515,
0.012229801};
static constexpr std::array<opmath_type, 4> c_ = {
-1.970840454,
-1.624906493,
3.429567803,
1.641345311};
static constexpr std::array<opmath_type, 2> d_ = {3.543889200, 1.637067800};
template <>
struct ErfinvFunctor<c10::Half> {
c10::Half operator()(c10::Half in) const {
return calc_erfinv(float(in));
}
};

void erfinv_kernel(TensorIteratorBase& iter) {
Expand Down
20 changes: 0 additions & 20 deletions test/xpu/run_test_with_skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -1547,26 +1547,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
# Relative difference: 6.156719153309558e-06 (up to 1e-06 allowed)
"test_log1p_complex_xpu_complex64",

# CPU MKL::erfinv vs XPU impl. At most 6.e-06
# Greatest absolute difference: 5.250126961175994e-06 at index (0,) (up to 1e-07 allowed)
# Greatest relative difference: 1.680894105274219e-06 at index (0,) (up to 1e-07 allowed)
"test_reference_numerics_large__refs_erfinv_xpu_float64",
# Greatest absolute difference: 5.250126961175994e-06 at index (0,) (up to 1e-07 allowed)
# Greatest relative difference: 1.680894105274219e-06 at index (0,) (up to 1e-07 allowed)
"test_reference_numerics_large_erfinv_xpu_float64",
# Greatest absolute difference: 4.829411781148707e-06 at index (690, 855) (up to 1e-07 allowed)
# Greatest relative difference: 1.5588752485769885e-06 at index (690, 855) (up to 1e-07 allowed)
"test_reference_numerics_normal__refs_erfinv_xpu_float64",
# Greatest absolute difference: 4.829411781148707e-06 at index (690, 855) (up to 1e-07 allowed)
# Greatest relative difference: 1.5588752485769885e-06 at index (690, 855) (up to 1e-07 allowed)
"test_reference_numerics_normal_erfinv_xpu_float64",
# Greatest absolute difference: 5.250126961175994e-06 at index (96,) (up to 1e-07 allowed)
# Greatest relative difference: 1.680894105274219e-06 at index (96,) (up to 1e-07 allowed)
"test_reference_numerics_small__refs_erfinv_xpu_float64",
# Greatest absolute difference: 5.250126961175994e-06 at index (96,) (up to 1e-07 allowed)
# Greatest relative difference: 1.680894105274219e-06 at index (96,) (up to 1e-07 allowed)
"test_reference_numerics_small_erfinv_xpu_float64",

# Issue: https://github.com/intel/torch-xpu-ops/issues/622
# Mismatched elements: 8 / 943593 (0.0%)
# Greatest absolute difference: inf at index (9, 860) (up to 0.001 allowed)
Expand Down

0 comments on commit 2dcf59d

Please sign in to comment.