diff --git a/src/ATen/native/xpu/sycl/Math.h b/src/ATen/native/xpu/sycl/Math.h index 71e49d902..3a66e5983 100644 --- a/src/ATen/native/xpu/sycl/Math.h +++ b/src/ATen/native/xpu/sycl/Math.h @@ -9,12 +9,12 @@ namespace at::native::xpu { * For licensing information, please refer to the cpu implementation located in * "ATen/native/Math.h". */ -template +template static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // [C++ Standard Reference: Gamma Function] // https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type_device; - static const double PI_f64 = 3.14159265358979323846; + static const pi_t PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; const accscalar_t A[] = { 8.33333333333333333333E-2, @@ -27,7 +27,7 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { }; accscalar_t x = static_cast(in); - if (x == 0) { + if (x == accscalar_t(0)) { // As per C++ standard for gamma related functions and SciPy, // If the argument is ±0, ±∞ is returned return std::copysign(static_cast(INFINITY), -x); @@ -35,7 +35,7 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { bool x_is_integer = x == std::trunc(x); accscalar_t result = 0; - if (x < 0) { + if (x < accscalar_t(0)) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, // If the argument is a negative integer, NaN is returned @@ -46,23 +46,23 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // mathematically equivalent since both x and r are in radians and tan() has // a periodicity of pi, in practice the computation of pi * x is a source of // error (when |x| > 1). - double q, r; - r = std::modf(static_cast(x), &q); + pi_t q, r; + r = std::modf(static_cast(x), &q); result = static_cast(-PI_f64 / std::tan(PI_f64 * r)); x = 1 - x; } - while (x < 10) { + while (x < accscalar_t(10)) { result -= 1 / x; x += 1; } - if (x == 10) { + if (x == accscalar_t(10)) { return static_cast(result + PSI_10); } accscalar_t y = 0; - if (x < 1.0e17) { - accscalar_t z = 1 / (x * x); + if (x < accscalar_t(1.0e17)) { + accscalar_t z = accscalar_t(1) / (x * x); accscalar_t polevl_result = 0; for (int i = 0; i <= 6; i++) { @@ -82,20 +82,23 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) { accscalar_t x = static_cast(in); accscalar_t sign = +1; accscalar_t result = 0; - if (x < 0.5f) { + if (x < accscalar_t(0.5)) { sign = -1; accscalar_t sin_pi_x = std::sin(PI * x); result -= (PI * PI) / (sin_pi_x * sin_pi_x); - x = 1 - x; + x = accscalar_t(1) - x; } for (int i = 0; i < 6; ++i) { - result += 1 / (x * x); - x += 1; + result += accscalar_t(1) / (x * x); + x += accscalar_t(1); } - const accscalar_t one = static_cast(1); - const accscalar_t ixx = 1 / (x * x); - result += (1 + 1 / (2 * x) + - ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) / + const accscalar_t one = accscalar_t(1); + const accscalar_t ixx = accscalar_t(1) / (x * x); + result += + (accscalar_t(1) + accscalar_t(1) / (accscalar_t(2) * x) + + ixx * + (one / accscalar_t(6) - + ixx * (one / accscalar_t(30) - ixx * (one / accscalar_t(42))))) / x; return static_cast(sign * result); } @@ -122,7 +125,7 @@ chbevl(scalar_t _x, const scalar_t array[], size_t len) { b0 = _x * b1 - b2 + array[i]; } - return (0.5 * (b0 - b2)); + return (scalar_t(0.5) * (b0 - b2)); } /* diff --git a/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp b/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp index 7b23e22ed..4a8a4a55e 100644 --- a/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp +++ b/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp @@ -11,10 +11,15 @@ namespace at::native::xpu { -template +template struct DigammaFunctor { scalar_t operator()(scalar_t a) const { - return calc_digamma(a); + if constexpr (USE_FP64_PI) { + return calc_digamma(a); + } else { + using pi_t = at::acc_type_device; + return calc_digamma(a); + } } }; @@ -24,7 +29,13 @@ void digamma_kernel(TensorIteratorBase& iter) { at::ScalarType::BFloat16, iter.common_dtype(), "digamma_xpu", - [&]() { gpu_kernel(iter, DigammaFunctor()); }); + [&]() { + if (syclHasFloat64()) { + gpu_kernel(iter, DigammaFunctor()); + } else { + gpu_kernel(iter, DigammaFunctor()); + } + }); } template diff --git a/src/comm/DeviceProperties.h b/src/comm/DeviceProperties.h index 7dedeed7c..b98281357 100644 --- a/src/comm/DeviceProperties.h +++ b/src/comm/DeviceProperties.h @@ -190,5 +190,11 @@ uint32_t syclNativeVectorWidth( "Invalid data type to fetch native vector width!"); } +static inline bool syclHasFloat64( + at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + auto* dev_prop = at::xpu::getDeviceProperties(dev_id); + return dev_prop->has_fp64; +} + } // namespace sycl } // namespace xpu