Skip to content

Commit

Permalink
gamma: Align data type in computation with the declaration of the helper (#837)
Browse files Browse the repository at this point in the history

Resolve the FP64 issue in the gamma operators on ARC.
  • Loading branch information
xytintel authored Aug 30, 2024
1 parent 2bd8ce9 commit d604c1d
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 22 deletions.
41 changes: 22 additions & 19 deletions src/ATen/native/xpu/sycl/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ namespace at::native::xpu {
* For licensing information, please refer to the cpu implementation located in
* "ATen/native/Math.h".
*/
template <typename scalar_t>
template <typename scalar_t, typename pi_t = double>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function]
// https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
static const double PI_f64 = 3.14159265358979323846;
static const pi_t PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
8.33333333333333333333E-2,
Expand All @@ -27,15 +27,15 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
};

accscalar_t x = static_cast<accscalar_t>(in);
if (x == 0) {
if (x == accscalar_t(0)) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is ±0, ±∞ is returned
return std::copysign(static_cast<scalar_t>(INFINITY), -x);
}

bool x_is_integer = x == std::trunc(x);
accscalar_t result = 0;
if (x < 0) {
if (x < accscalar_t(0)) {
if (x_is_integer) {
// As per C++ standard for gamma related functions and SciPy,
// If the argument is a negative integer, NaN is returned
Expand All @@ -46,23 +46,23 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// mathematically equivalent since both x and r are in radians and tan() has
// a periodicity of pi, in practice the computation of pi * x is a source of
// error (when |x| > 1).
double q, r;
r = std::modf(static_cast<double>(x), &q);
pi_t q, r;
r = std::modf(static_cast<pi_t>(x), &q);
result = static_cast<accscalar_t>(-PI_f64 / std::tan(PI_f64 * r));
x = 1 - x;
}

while (x < 10) {
while (x < accscalar_t(10)) {
result -= 1 / x;
x += 1;
}
if (x == 10) {
if (x == accscalar_t(10)) {
return static_cast<scalar_t>(result + PSI_10);
}

accscalar_t y = 0;
if (x < 1.0e17) {
accscalar_t z = 1 / (x * x);
if (x < accscalar_t(1.0e17)) {
accscalar_t z = accscalar_t(1) / (x * x);

accscalar_t polevl_result = 0;
for (int i = 0; i <= 6; i++) {
Expand All @@ -82,20 +82,23 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) {
accscalar_t x = static_cast<accscalar_t>(in);
accscalar_t sign = +1;
accscalar_t result = 0;
if (x < 0.5f) {
if (x < accscalar_t(0.5)) {
sign = -1;
accscalar_t sin_pi_x = std::sin(PI * x);
result -= (PI * PI) / (sin_pi_x * sin_pi_x);
x = 1 - x;
x = accscalar_t(1) - x;
}
for (int i = 0; i < 6; ++i) {
result += 1 / (x * x);
x += 1;
result += accscalar_t(1) / (x * x);
x += accscalar_t(1);
}
const accscalar_t one = static_cast<scalar_t>(1);
const accscalar_t ixx = 1 / (x * x);
result += (1 + 1 / (2 * x) +
ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) /
const accscalar_t one = accscalar_t(1);
const accscalar_t ixx = accscalar_t(1) / (x * x);
result +=
(accscalar_t(1) + accscalar_t(1) / (accscalar_t(2) * x) +
ixx *
(one / accscalar_t(6) -
ixx * (one / accscalar_t(30) - ixx * (one / accscalar_t(42))))) /
x;
return static_cast<scalar_t>(sign * result);
}
Expand All @@ -122,7 +125,7 @@ chbevl(scalar_t _x, const scalar_t array[], size_t len) {
b0 = _x * b1 - b2 + array[i];
}

return (0.5 * (b0 - b2));
return (scalar_t(0.5) * (b0 - b2));
}

/*
Expand Down
17 changes: 14 additions & 3 deletions src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@

namespace at::native::xpu {

template <typename scalar_t>
template <typename scalar_t, bool USE_FP64_PI>
struct DigammaFunctor {
scalar_t operator()(scalar_t a) const {
return calc_digamma(a);
if constexpr (USE_FP64_PI) {
return calc_digamma<scalar_t, double>(a);
} else {
using pi_t = at::acc_type_device<scalar_t, kXPU>;
return calc_digamma<scalar_t, pi_t>(a);
}
}
};

Expand All @@ -24,7 +29,13 @@ void digamma_kernel(TensorIteratorBase& iter) {
at::ScalarType::BFloat16,
iter.common_dtype(),
"digamma_xpu",
[&]() { gpu_kernel(iter, DigammaFunctor<scalar_t>()); });
[&]() {
if (syclHasFloat64()) {
gpu_kernel(iter, DigammaFunctor<scalar_t, true>());
} else {
gpu_kernel(iter, DigammaFunctor<scalar_t, false>());
}
});
}

template <typename scalar_t>
Expand Down
6 changes: 6 additions & 0 deletions src/comm/DeviceProperties.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,5 +190,11 @@ uint32_t syclNativeVectorWidth(
"Invalid data type to fetch native vector width!");
}

// Returns the `has_fp64` flag stored in the device properties of `dev_id`.
// When no index is supplied, the device index of the current queue is used.
// Callers use this to decide whether double-precision code paths may be
// selected (see the digamma kernel dispatch).
static inline bool syclHasFloat64(
    at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
  return at::xpu::getDeviceProperties(dev_id)->has_fp64;
}

} // namespace sycl
} // namespace xpu

0 comments on commit d604c1d

Please sign in to comment.