Skip to content

Commit

Permalink
Update Math.h
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Aug 28, 2024
1 parent 6df96d2 commit 2b095f5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/ATen/native/xpu/sycl/Math.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ namespace at::native::xpu {
* For licensing information, please refer to the cpu implementation located in
* "ATen/native/Math.h".
*/
template <typename scalar_t>
template <typename scalar_t, typename pi_t = double>
static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// [C++ Standard Reference: Gamma Function]
// https://en.cppreference.com/w/cpp/numeric/math/tgamma
using accscalar_t = at::acc_type_device<scalar_t, kXPU>;
static const accscalar_t PI_f64 = 3.14159265358979323846;
static const pi_t PI_f64 = 3.14159265358979323846;
const accscalar_t PSI_10 = 2.25175258906672110764;
const accscalar_t A[] = {
8.33333333333333333333E-2,
Expand Down Expand Up @@ -46,8 +46,8 @@ static inline C10_HOST_DEVICE scalar_t calc_digamma(scalar_t in) {
// mathematically equivalent since both x and r are in radians and tan() has
// a periodicity of pi, in practice the computation of pi * x is a source of
// error (when |x| > 1).
accscalar_t q, r;
r = std::modf(static_cast<accscalar_t>(x), &q);
pi_t q, r;
r = std::modf(static_cast<pi_t>(x), &q);
result = static_cast<accscalar_t>(-PI_f64 / std::tan(PI_f64 * r));
x = 1 - x;
}
Expand Down

0 comments on commit 2b095f5

Please sign in to comment.