From 3acece4f9f5ce0efcd2492e4331bbac2463b2755 Mon Sep 17 00:00:00 2001
From: Vinzent Steinberg <Vinzent.Steinberg@gmail.com>
Date: Tue, 8 Jan 2019 13:29:59 +0100
Subject: [PATCH] Implement fast atanh

Partially adresses #1.
---
 src/lib.rs  |   2 +
 src/tanh.rs | 191 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 193 insertions(+)
 create mode 100644 src/tanh.rs

diff --git a/src/lib.rs b/src/lib.rs
index cce39c1..4104d88 100755
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -32,9 +32,11 @@
 
 pub use log::{log2, log2_raw};
 pub use atan::{atan_raw, atan, atan2};
+pub use tanh::{tanh, tanh_raw};
 
 mod log;
 mod atan;
+mod tanh;
 
 #[doc(hidden)]
 pub mod float;
diff --git a/src/tanh.rs b/src/tanh.rs
new file mode 100644
index 0000000..f4e82e0
--- /dev/null
+++ b/src/tanh.rs
@@ -0,0 +1,191 @@
+/// Calculate the numerator of the `tanh` approximation.
+fn a(x: f32) -> f32 {
+    let x2 = x * x;
+    (((x2 + 378.) * x2 + 17325.) * x2 + 135135.) * x
+}
+
+/// Calculate the denominator of the `tanh` approximation.
+fn b(x: f32) -> f32 {
+    let x2 = x * x;
+    ((28. * x2 + 3150.) * x2 + 62370.) * x2 + 135135.
+}
+
+/// Compute a fast approximation of the hyperbolic tangent of `x`.
+///
+/// For large |x|, the output may be outside of [-1, 1].
+#[inline]
+pub fn tanh_raw(x: f32) -> f32 {
+    // Implementation based on
+    // https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction
+    a(x) / b(x)
+}
+
+/// Compute a fast approximation of the hyperbolic tangent of `x`.
+///
+/// See `atanh_raw` for a faster version that may return incorrect results for
+/// large `|x|` and `nan`.
+#[inline]
+pub fn tanh(x: f32) -> f32 {
+    if x.is_nan() {
+        return x;
+    }
+
+    let a = a(x);
+    if !a.is_finite() {
+        return if a < 0. { -1. } else { 1. };
+    }
+
+    let result = a / b(x);
+    if result > 1. {
+        return 1.;
+    }
+    if result < -1. {
+        return -1.;
+    }
+    result
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use quickcheck as qc;
+    use std::f32 as f;
+
+    /// Maximal absolute error.
+    const TOL: f32 = 0.0001;
+
+    #[test]
+    fn tanh_abs_err_qc() {
+        fn prop(x: f32) -> qc::TestResult {
+            let e = tanh(x);
+            let t = x.tanh();
+            let abs = (e - t).abs();
+
+            qc::TestResult::from_bool(abs < TOL)
+        }
+        qc::quickcheck(prop as fn(f32) -> qc::TestResult)
+    }
+
+    const PREC: u32 = 1 << 20;
+    #[test]
+    fn tanh_abs_err_exhaustive() {
+        for i in 0..PREC + 1 {
+            for j in -5..6 {
+                let x = (1.0 + i as f32 / PREC as f32) * 2f32.powi(j * 20);
+                {
+                    let e = tanh(x);
+                    let t = x.tanh();
+                    let abs = (e - t).abs();
+
+                    assert!(abs < TOL,
+                            "{:.8}: {:.8}, {:.8}. {:.4}", x, e, t, abs);
+                }
+                {
+                    let e = tanh(-x);
+                    let t = (-x).tanh();
+                    let abs = (e - t).abs();
+
+                    assert!(abs < TOL,
+                            "{:.8}: {:.8}, {:.8}. {:.4}", -x, e, t, abs);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn tanh_edge_cases() {
+        assert!(tanh(f::NAN).is_nan());
+        assert_eq!(tanh(f::NEG_INFINITY), -1.);
+        assert_eq!(tanh(f::INFINITY), 1.);
+    }
+
+    #[test]
+    fn tanh_denormals() {
+        fn prop(x: u8, y: u16) -> bool {
+            let signif = ((x as u32) << 16) | (y as u32);
+            let mut x = ::float::recompose(0, 1, signif);
+
+            for _ in 0..23 {
+                {
+                    let e = tanh(x);
+                    let t = x.tanh();
+                    let abs = (e - t).abs();
+                    if abs >= TOL {
+                        return false
+                    }
+                }
+                {
+                    let e = tanh(-x);
+                    let t = (-x).tanh();
+                    let abs = (e - t).abs();
+                    if abs >= TOL {
+                        return false
+                    }
+                }
+
+                x /= 2.0;
+            }
+            true
+        }
+        qc::quickcheck(prop as fn(u8, u16) -> bool)
+    }
+
+    #[test]
+    fn tanh_raw_denormals() {
+        fn prop(x: u8, y: u16) -> bool {
+            let signif = ((x as u32) << 16) | (y as u32);
+            let mut x = ::float::recompose(0, 1, signif);
+
+            for _ in 0..23 {
+                let e = tanh_raw(x);
+                let t = x.tanh();
+                let abs = (e - t).abs();
+                if abs >= TOL {
+                    return false
+                }
+
+                x /= 2.0;
+            }
+            true
+        }
+        qc::quickcheck(prop as fn(u8, u16) -> bool)
+    }
+}
+
+#[cfg(all(test, feature = "unstable"))]
+mod benches {
+    use test::{Bencher, black_box};
+
+    const TAB: &'static [f32] =
+        &[ 0.85708036,  2.43390621,  2.80163358,  2.55126348,  3.18046186,
+           2.88689427,  0.32215155,  0.07701401,  1.22922506,  0.4580259 ,
+           0.01257442,  4.23107197,  0.89538113,  1.65219582,  0.14632742,
+           1.68663984,  1.88125115,  2.16773942,  1.27461936,  1.03091265];
+
+    #[bench]
+    fn tanh(b: &mut Bencher) {
+        b.iter(|| {
+            for &x in black_box(TAB) {
+                black_box(super::tanh(x));
+            }
+        })
+    }
+
+    #[bench]
+    fn tanh_raw(b: &mut Bencher) {
+        b.iter(|| {
+            for &x in black_box(TAB) {
+                black_box(super::tanh_raw(x));
+            }
+        })
+    }
+
+    #[bench]
+    fn tanh_std(b: &mut Bencher) {
+        b.iter(|| {
+            for &x in black_box(TAB) {
+                black_box(x.tanh());
+            }
+        })
+    }
+}