From 7259c016daca52d7641bb79c0c62e8fa9809649c Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Wed, 28 Aug 2024 22:59:59 +0800 Subject: [PATCH] Update Math.h --- src/ATen/native/xpu/sycl/Math.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Math.h b/src/ATen/native/xpu/sycl/Math.h index 746f761fc..3a66e5983 100644 --- a/src/ATen/native/xpu/sycl/Math.h +++ b/src/ATen/native/xpu/sycl/Math.h @@ -9,12 +9,12 @@ namespace at::native::xpu { * For licensing information, please refer to the cpu implementation located in * "ATen/native/Math.h". */ -template +template static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // [C++ Standard Reference: Gamma Function] // https://en.cppreference.com/w/cpp/numeric/math/tgamma using accscalar_t = at::acc_type_device; - static const accscalar_t PI_f64 = 3.14159265358979323846; + static const pi_t PI_f64 = 3.14159265358979323846; const accscalar_t PSI_10 = 2.25175258906672110764; const accscalar_t A[] = { 8.33333333333333333333E-2, @@ -46,8 +46,8 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) { // mathematically equivalent since both x and r are in radians and tan() has // a periodicity of pi, in practice the computation of pi * x is a source of // error (when |x| > 1). - accscalar_t q, r; - r = std::modf(static_cast(x), &q); + pi_t q, r; + r = std::modf(static_cast(x), &q); result = static_cast(-PI_f64 / std::tan(PI_f64 * r)); x = 1 - x; }