From 7f8179822bce260067cde23f3a2a84ea2b109209 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 12:28:33 +0800 Subject: [PATCH 1/6] Update Math.h --- src/ATen/native/xpu/sycl/Math.h | 39 ++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Math.h b/src/ATen/native/xpu/sycl/Math.h index 71e49d902..746f761fc 100644 --- a/src/ATen/native/xpu/sycl/Math.h +++ b/src/ATen/native/xpu/sycl/Math.h @@ -14,7 +14,7 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // [C++ Standard Reference: Gamma Function] // https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type_device; - static const double PI_f64 = 3.14159265358979323846; + static const accscalar_t PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; const accscalar_t A[] = { 8.33333333333333333333E-2, @@ -27,7 +27,7 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { }; accscalar_t x = static_cast(in); - if (x == 0) { + if (x == accscalar_t(0)) { // As per C++ standard for gamma related functions and SciPy, // If the argument is ±0, ±∞ is returned return std::copysign(static_cast(INFINITY), -x); @@ -35,7 +35,7 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { bool x_is_integer = x == std::trunc(x); accscalar_t result = 0; - if (x < 0) { + if (x < accscalar_t(0)) { if (x_is_integer) { // As per C++ standard for gamma related functions and SciPy, // If the argument is a negative integer, NaN is returned @@ -46,23 +46,23 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // mathematically equivalent since both x and r are in radians and tan() has // a periodicity of pi, in practice the computation of pi * x is a source of // error (when |x| > 1). - double q, r; - r = std::modf(static_cast(x), &q); + accscalar_t q, r; + r = std::modf(static_cast(x), &q); result = static_cast(-PI_f64 / std::tan(PI_f64 * r)); x = 1 - x; } - while (x < 10) { + while (x < accscalar_t(10)) { result -= 1 / x; x += 1; } - if (x == 10) { + if (x == accscalar_t(10)) { return static_cast(result + PSI_10); } accscalar_t y = 0; - if (x < 1.0e17) { - accscalar_t z = 1 / (x * x); + if (x < accscalar_t(1.0e17)) { + accscalar_t z = accscalar_t(1) / (x * x); accscalar_t polevl_result = 0; for (int i = 0; i <= 6; i++) { @@ -82,20 +82,23 @@ static inline C10_HOST_DEVICE scalar_t calc_trigamma(scalar_t in) { accscalar_t x = static_cast(in); accscalar_t sign = +1; accscalar_t result = 0; - if (x < 0.5f) { + if (x < accscalar_t(0.5)) { sign = -1; accscalar_t sin_pi_x = std::sin(PI * x); result -= (PI * PI) / (sin_pi_x * sin_pi_x); - x = 1 - x; + x = accscalar_t(1) - x; } for (int i = 0; i < 6; ++i) { - result += 1 / (x * x); - x += 1; + result += accscalar_t(1) / (x * x); + x += accscalar_t(1); } - const accscalar_t one = static_cast(1); - const accscalar_t ixx = 1 / (x * x); - result += (1 + 1 / (2 * x) + - ixx * (one / 6 - ixx * (one / 30 - ixx * (one / 42)))) / + const accscalar_t one = accscalar_t(1); + const accscalar_t ixx = accscalar_t(1) / (x * x); + result += + (accscalar_t(1) + accscalar_t(1) / (accscalar_t(2) * x) + + ixx * + (one / accscalar_t(6) - + ixx * (one / accscalar_t(30) - ixx * (one / accscalar_t(42))))) / x; return static_cast(sign * result); } @@ -122,7 +125,7 @@ chbevl(scalar_t _x, const scalar_t array[], size_t len) { b0 = _x * b1 - b2 + array[i]; } - return (0.5 * (b0 - b2)); + return (scalar_t(0.5) * (b0 - b2)); } /* From 6df96d2d8f054ac5ad28cd3638aa4c03b002e32f Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 22:16:00 +0800 Subject: [PATCH 2/6] Update DeviceProperties.h --- src/comm/DeviceProperties.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/comm/DeviceProperties.h b/src/comm/DeviceProperties.h index 0f4c084c8..01beffa74 100644 --- a/src/comm/DeviceProperties.h +++ b/src/comm/DeviceProperties.h @@ -185,5 +185,11 @@ uint32_t syclNativeVectorWidth( "Invalid data type to fetch native vector width!"); } +static inline bool syclHasFloat64( + at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) { + auto* dev_prop = at::xpu::getDeviceProperties(dev_id); + return dev_prop->has_fp64; +} + } // namespace sycl } // namespace xpu From 2b095f5863532674c19f28799c03ddd9d0cca549 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 22:16:54 +0800 Subject: [PATCH 3/6] Update Math.h --- src/ATen/native/xpu/sycl/Math.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Math.h b/src/ATen/native/xpu/sycl/Math.h index 746f761fc..3a66e5983 100644 --- a/src/ATen/native/xpu/sycl/Math.h +++ b/src/ATen/native/xpu/sycl/Math.h @@ -9,12 +9,12 @@ namespace at::native::xpu { * For licensing information, please refer to the cpu implementation located in * "ATen/native/Math.h". */ -template +template static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // [C++ Standard Reference: Gamma Function] // https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type_device; - static const accscalar_t PI_f64 = 3.14159265358979323846; + static const pi_t PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; const accscalar_t A[] = { 8.33333333333333333333E-2, @@ -46,8 +46,8 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // mathematically equivalent since both x and r are in radians and tan() has // a periodicity of pi, in practice the computation of pi * x is a source of // error (when |x| > 1). - accscalar_t q, r; - r = std::modf(static_cast(x), &q); + pi_t q, r; + r = std::modf(static_cast(x), &q); result = static_cast(-PI_f64 / std::tan(PI_f64 * r)); x = 1 - x; } From 02e2029dfcd82eb1e5d533a9318f70587d4a4c3e Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 22:17:45 +0800 Subject: [PATCH 4/6] Update UnaryGammaKernels.cpp --- src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp b/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp index 7b23e22ed..4a8a4a55e 100644 --- a/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp +++ b/src/ATen/native/xpu/sycl/UnaryGammaKernels.cpp @@ -11,10 +11,15 @@ namespace at::native::xpu { -template +template struct DigammaFunctor { scalar_t operator()(scalar_t a) const { - return calc_digamma(a); + if constexpr (USE_FP64_PI) { + return calc_digamma(a); + } else { + using pi_t = at::acc_type_device; + return calc_digamma(a); + } } }; @@ -24,7 +29,13 @@ void digamma_kernel(TensorIteratorBase& iter) { at::ScalarType::BFloat16, iter.common_dtype(), "digamma_xpu", - [&]() { gpu_kernel(iter, DigammaFunctor()); }); + [&]() { + if (syclHasFloat64()) { + gpu_kernel(iter, DigammaFunctor()); + } else { + gpu_kernel(iter, DigammaFunctor()); + } + }); } template From e8d6108545af7028b3c418f49506058c2c2b061f Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 23:15:38 +0800 Subject: [PATCH 5/6] Update skip_list_arc.py --- test/xpu/skip_list_arc.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/xpu/skip_list_arc.py b/test/xpu/skip_list_arc.py index c3f445b96..e239efb25 100644 --- a/test/xpu/skip_list_arc.py +++ b/test/xpu/skip_list_arc.py @@ -8,4 +8,7 @@ "test_tensor_creation_ops_xpu.py": ( "test_float_to_int_conversion_finite_xpu_int64", ), + "test_ops_xpu.py": ( + "test_compare_cpu_digamma_xpu_float32", + ), } From ec572e2a84d5bc472101c239291669b16b1d3a55 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 29 Aug 2024 13:20:50 +0800 Subject: [PATCH 6/6] Update skip_list_arc.py --- test/xpu/skip_list_arc.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/xpu/skip_list_arc.py b/test/xpu/skip_list_arc.py index e239efb25..c3f445b96 100644 --- a/test/xpu/skip_list_arc.py +++ b/test/xpu/skip_list_arc.py @@ -8,7 +8,4 @@ "test_tensor_creation_ops_xpu.py": ( "test_float_to_int_conversion_finite_xpu_int64", ), - "test_ops_xpu.py": ( - "test_compare_cpu_digamma_xpu_float32", - ), }