-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement fast tanh #6
base: master
Are you sure you want to change the base?
Conversation
Codecov Report
@@ Coverage Diff @@
## master #6 +/- ##
==========================================
+ Coverage 94.35% 94.72% +0.36%
==========================================
Files 5 6 +1
Lines 248 341 +93
==========================================
+ Hits 234 323 +89
- Misses 14 18 +4
Continue to review full report at Codecov.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the pull request!
My main comment is I think this approximation is quite a few operations, and we could possibly get away with something a bit simpler/faster.
Also, I've landed some large(ish) changes on master (#7), which require some adjustments (criterion for benchmarks now), and bring in the ieee754 library for convenient access to some basic float operations.
((28. * x2 + 3150.) * x2 + 62370.) * x2 + 135135. | ||
} | ||
|
||
/// Compute a fast approximation of the hyperbolic tangent of `x`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Compute a fast approximation of the hyperbolic tangent of `x`. | |
/// Compute a fast approximation of the hyperbolic tangent of `x` for -4 < `x` < 4. |
|
||
/// Compute a fast approximation of the hyperbolic tangent of `x`. | ||
/// | ||
/// For large |x|, the output may be outside of [-1, 1]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd prefer this to just make no guarantees about the behaviour at all, e.g.
/// This will return unspecified nonsense if `x` is doesn't
/// satisfy those constraints. Use `tanh` if correct handling is
/// required (at the expense of some speed).
/// See `atanh_raw` for a faster version that may return incorrect results for | ||
/// large `|x|` and `nan`. | ||
#[inline] | ||
pub fn tanh(x: f32) -> f32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if this could be something like:
pub fn tanh(x: f32) -> f32 {
if x < -4.97 {
-1.
} else if x > 4.97 {
1.
} else {
// if x is NaN, it will propagate through the arithmetic
tanh_raw(x)
}
}
This is likely to be easier to vectorize, and does fewer operations. If you rebase/merge this PR onto the latest master (and so can use Ieee754::copy_sign
), this could even be:
pub fn tanh(x: f32) -> f32 {
if x.abs() > 4.97 {
// the true value |tanh(x)| > 0.9999 when |x| > 4.97, so
// rounding to ±1 is close enough
1_f32.copy_sign(x)
} else {
// |tanh_raw(x)| < 1 when |x| <= 4.97, so no post-processing is needed,
// and x being NaN is handled by propagating through the arithmetic
tanh_raw(x)
}
}
With this adjustment, a
and b
no longer need to be separate functions and can be inlined straight into tanh_raw
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't you think clipping like this might be problematic, because it results in discontinuities?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a potential problem. An alternative would be to find when the approximation is exactly +/-1, and clip there instead of 4.97 (I think it should be symmetric?), so that the tanh approximation is continuous (although it's derivative won't be).
pub fn tanh_raw(x: f32) -> f32 { | ||
// Implementation based on | ||
// https://varietyofsound.wordpress.com/2011/02/14/efficient-tanh-computation-using-lamberts-continued-fraction | ||
a(x) / b(x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of interest, did you consider other approaches? E.g.
-
using a lower-degree continued fraction approximation instead of (7, 6), such as cutting it off at the level with the 5 (this seems to have maximum relative and absolute errors of about 0.02 if it's used on
[-2.3, 2.3]
and clipped to +/-1 outside that):(x2 + 15.) * x / (6. * x2 + 15.)
-
optimize the parameters of the approximation (just truncating series like the Taylor series of continued fractions won't be the most accurate approximation); for the form
(x2 + a) * x / (b * x2 + a)
on some interval[-limit, limit]
(like the above), I geta = 21.350693, b = 7.8355837, limit = 2.933833
as the best, with relative and absolute errors of approximately 0.0057, which is about as accurate as other functions in fast-math. (I used theapprox.py
script referenced at the end of this comment.) (I suspect this expensive form could benefit from optimizing its coefficients too.) -
Use the
tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
form, with an approximateexp
such as that described in https://stackoverflow.com/a/50379934/1256624, which could look something like the following in Rust (haven't tested):/// Computes an approximation to (exp(x), exp(-x)) #[inline] fn pm_exp(x: f32) -> (f32, f32) { const A: f32 = (1 << 23) as f32 / LN_2; const B: u32 = 127 << 23; let r = (A * x) as i32 as u32; (f32::from_bits(B.wrapping_add(r)), f32::from_bits(B.wrapping_sub(r))) } pub fn tanh_raw(x: f32) -> f32 { let (plus, minus) = pm_exp(x); (plus - minus) / (plus + minus) }
It could also use the
exp
now in the library, but it's more expensive (uses a quadratic approximation that requires pulling more info out of the floats), and I'm lead to believe that the above typically benefits from some cancellation of errors that doesn't occur for an isolatedexp
.
My suspicion is that 2. will be a good balance of speed and accuracy, but 3. could surprise me. Do you know any details about the above?
approx.py
Run like python approx.py
, may require Python 3.
import numpy as np
from scipy import optimize
def rel_error(approx, true):
# if the true value is 0, the approximate one must be too
return np.where(true != 0,
np.abs((approx - true) / true),
np.where(approx == 0, 0, np.inf))
def abs_error(approx, true):
return np.abs(approx - true)
f32 = np.float32
def approx(coeffs, points):
a, b, limit = coeffs
# approximate with (x^3 + a x) / (b x ^ 2 + a) on the interval
# [-limit, limit].
#
# Why this form? tanh is odd, so we should have odd / even (and
# so the unlisted coeffs must be zero), tanh(x) ~= x for small x
# (so we can share a in the top and bottom so it approximates a x
# / a = x when x is small).
points2 = points * points
poly = (points2 + a) * points / (b * points2 + a)
return np.where(np.abs(points) <= limit, poly, np.sign(points))
def evaluation(coeffs, points):
a = approx(coeffs, points)
t = np.tanh(points)
rel = rel_error(a, t).max()
abs = abs_error(a, t).max()
return (rel, abs)
start = np.array([15, 6, 2.3])
opt_points = np.linspace(-5, 5, 100001)
# optimize on the relative error
result = optimize.fmin(lambda c: evaluation(f32(c), f32(opt_points))[0], start, maxiter=10000)
final_values = approx(result, opt_points)
assert np.all((final_values >= -1) & (final_values <= 1)), "not allowed to overshoot"
rel, abs = evaluation(result, opt_points)
print("a = %s, b = %s, limit = %s" % tuple(f32(result)))
print("on [-5, 5]: rel error = %.6f, abs error = %.6f" % (rel, abs))
for x in np.arange(-5, 5.01, 0.5):
print("%7.4f: %f (%f)" % (x, approx(result, x), np.tanh(x)))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #6 (comment).
Partially adresses huonw#1.
Co-Authored-By: vks <[email protected]>
I tried the implementations you suggested: Current implementation:
Suggested clipping:
It looks like there might be a small improvement, but the results are weird (with unexpected changes for Suggest clipping with optimized lower-order approximation:
This seems to result in good performance improvements (54% for
This is a bit slower than the truncated continued fraction. |
Should I switch to the implementation optimized for 0.0057 error tolerance? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think switching would be good, given how much of a performance improvement it is.
I think this library potentially needs to be restructured to give more control about errors, but for now, I think 0.0057 error tolerance is fine. Do you have something for which you might use this? If so, is that error tolerance acceptable?
/// See `atanh_raw` for a faster version that may return incorrect results for | ||
/// large `|x|` and `nan`. | ||
#[inline] | ||
pub fn tanh(x: f32) -> f32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a potential problem. An alternative would be to find when the approximation is exactly +/-1, and clip there instead of 4.97 (I think it should be symmetric?), so that the tanh approximation is continuous (although it's derivative won't be).
Partially adresses #1.