-
Notifications
You must be signed in to change notification settings - Fork 53
more accurate sqrt function #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -281,40 +281,87 @@ impl<T: Float> Complex<T> { | |||||||||
/// | ||||||||||
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`. | ||||||||||
#[inline] | ||||||||||
pub fn sqrt(self) -> Self { | ||||||||||
if self.im.is_zero() { | ||||||||||
if self.re.is_sign_positive() { | ||||||||||
// simple positive real √r, and copy `im` for its sign | ||||||||||
Self::new(self.re.sqrt(), self.im) | ||||||||||
pub fn sqrt(mut self) -> Self { | ||||||||||
// complex sqrt algorithm based on the algorithm from | ||||||||||
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks | ||||||||||
// to increase accuracy. Compared to a naive implementationt that | ||||||||||
// reuses the complex exp/ln implementations this algorithm has better | ||||||||||
// accuarcy since both (real) sqrt and (real) hypot are garunteed to | ||||||||||
// round perfectly. It's also faster since this implementation requires | ||||||||||
// less transcendental functions and those it does use (sqrt/hypto) are | ||||||||||
// faster comparted to exp/sin/cos. | ||||||||||
// | ||||||||||
// The musl libc implementation was referenced while implementing the | ||||||||||
// algorithm here: | ||||||||||
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c | ||||||||||
|
||||||||||
// TODO: rounding for very tiny subnormal numbers isn't perfect yet so | ||||||||||
// the assert shown fails in the very worst case this leads to about | ||||||||||
// 10% accuracy loss (see example below). As the magnitude increase the | ||||||||||
// error quickly drops to basically zero. | ||||||||||
// | ||||||||||
// glibc handles that (but other implementations like musl and numpy do | ||||||||||
// not) by upscaling very small values. That upscaling (and particularly | ||||||||||
// it's reversal) are weird and hard to understand (and rely on mantissa | ||||||||||
// bit size which we can't get out of the trait). In general the glibc | ||||||||||
// implementation is ever so subtley different and I wouldn't want to | ||||||||||
// introduce bugs by trying to adapt the underflow handling. | ||||||||||
// | ||||||||||
// assert_eq!( | ||||||||||
// Complex64::new(5.212e-324, 5.212e-324).sqrt(), | ||||||||||
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162) | ||||||||||
// ); | ||||||||||
|
||||||||||
// specical cases for correct nan/inf handling | ||||||||||
// see https://en.cppreference.com/w/c/numeric/complex/csqrt | ||||||||||
|
||||||||||
if self.re.is_zero() && self.im.is_zero() { | ||||||||||
// 0 +/- 0 i | ||||||||||
return Self::new(T::zero(), self.im); | ||||||||||
} | ||||||||||
if self.im.is_infinite() { | ||||||||||
// inf +/- inf i | ||||||||||
return Self::new(T::infinity(), self.im); | ||||||||||
} | ||||||||||
if self.re.is_nan() { | ||||||||||
// nan + nan i | ||||||||||
return Self::new(self.re, T::nan()); | ||||||||||
} | ||||||||||
if self.re.is_infinite() { | ||||||||||
// √(inf +/- NaN i) = inf +/- NaN i | ||||||||||
// √(inf +/- x i) = inf +/- 0 i | ||||||||||
// √(-inf +/- NaN i) = NaN +/- inf i | ||||||||||
// √(-inf +/- x i) = 0 +/- inf i | ||||||||||
|
||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add a variable to make this clearer:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that is indeed more readable, I also added a comments. good point |
||||||||||
// if im is inf (or nan) this is nan, otherwise it's zero | ||||||||||
#[allow(clippy::eq_op)] | ||||||||||
let zero_or_nan = self.im - self.im; | ||||||||||
if self.re.is_sign_negative() { | ||||||||||
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im)); | ||||||||||
} else { | ||||||||||
// √(r e^(iπ)) = √r e^(iπ/2) = i√r | ||||||||||
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r | ||||||||||
let re = T::zero(); | ||||||||||
let im = (-self.re).sqrt(); | ||||||||||
if self.im.is_sign_positive() { | ||||||||||
Self::new(re, im) | ||||||||||
} else { | ||||||||||
Self::new(re, -im) | ||||||||||
} | ||||||||||
} | ||||||||||
} else if self.re.is_zero() { | ||||||||||
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2) | ||||||||||
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2) | ||||||||||
let one = T::one(); | ||||||||||
let two = one + one; | ||||||||||
let x = (self.im.abs() / two).sqrt(); | ||||||||||
if self.im.is_sign_positive() { | ||||||||||
Self::new(x, x) | ||||||||||
} else { | ||||||||||
Self::new(x, -x) | ||||||||||
return Self::new(self.re, zero_or_nan.copysign(self.im)); | ||||||||||
} | ||||||||||
} | ||||||||||
let two = T::one() + T::one(); | ||||||||||
let four = two + two; | ||||||||||
let overflow = T::max_value() / (T::one() + T::sqrt(two)); | ||||||||||
let max_magnitude = self.re.abs().max(self.im.abs()); | ||||||||||
let scale = max_magnitude >= overflow; | ||||||||||
if scale { | ||||||||||
self = self / four; | ||||||||||
} | ||||||||||
if self.re.is_sign_negative() { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could also use a citation and link in a comment for the algorithm you mentioned. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a citation to the algorithm and the musl libc implementation as well as provide some additional background in a comement |
||||||||||
let tmp = ((-self.re + self.norm()) / two).sqrt(); | ||||||||||
self.re = self.im.abs() / (two * tmp); | ||||||||||
self.im = tmp.copysign(self.im); | ||||||||||
} else { | ||||||||||
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2) | ||||||||||
let one = T::one(); | ||||||||||
let two = one + one; | ||||||||||
let (r, theta) = self.to_polar(); | ||||||||||
Self::from_polar(r.sqrt(), theta / two) | ||||||||||
self.re = ((self.re + self.norm()) / two).sqrt(); | ||||||||||
self.im = self.im / (two * self.re); | ||||||||||
} | ||||||||||
if scale { | ||||||||||
self = self * two; | ||||||||||
} | ||||||||||
self | ||||||||||
} | ||||||||||
|
||||||||||
/// Computes the principal value of the cube root of `self`. | ||||||||||
|
@@ -2065,6 +2112,164 @@ pub(crate) mod test { | |||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
#[test] | ||||||||||
fn test_sqrt_nan() { | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::INFINITY, f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::NAN), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NAN, f64::INFINITY).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::NEG_INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(-0.0, 0.0).sqrt(), | ||||||||||
Complex64::new(0.0, 0.0), | ||||||||||
)); | ||||||||||
for x in (-100..100).map(f64::from) { | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(x, f64::INFINITY).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, f64::INFINITY), | ||||||||||
)); | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NAN, x).sqrt(), | ||||||||||
Complex64::new(f64::NAN, f64::NAN), | ||||||||||
)); | ||||||||||
// √(inf + x i) = inf + 0 i | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::INFINITY, x).sqrt(), | ||||||||||
Complex64::new(f64::INFINITY, 0.0.copysign(x)), | ||||||||||
)); | ||||||||||
// √(-inf + x i) = 0 + inf i | ||||||||||
assert!(close_naninf( | ||||||||||
Complex64::new(f64::NEG_INFINITY, x).sqrt(), | ||||||||||
Complex64::new(0.0, f64::INFINITY.copysign(x)), | ||||||||||
)); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
|
||||||||||
fn test_sqrt_rounding() { | ||||||||||
fn naive_sqrt(c: Complex64) -> Complex64 { | ||||||||||
let (r, theta) = c.to_polar(); | ||||||||||
Complex64::from_polar(r.sqrt(), theta / 2.0) | ||||||||||
} | ||||||||||
|
||||||||||
fn ulp_l1(a: Complex64, b: Complex64) -> u64 { | ||||||||||
let re_ulp = a.re.to_bits().abs_diff(b.re.to_bits()); | ||||||||||
let im_ulp = a.im.to_bits().abs_diff(b.im.to_bits()); | ||||||||||
re_ulp + im_ulp | ||||||||||
} | ||||||||||
fn close_to_ulp(a: Complex64, b: Complex64, ulp: usize) -> bool { | ||||||||||
ulp_l1(a, b) <= ulp as u64 | ||||||||||
} | ||||||||||
|
||||||||||
#[track_caller] | ||||||||||
fn check_sqrt(re: f64, im: f64, exact_sqrt_re: f64, exact_sqrt_im: f64) { | ||||||||||
let sqrt = Complex::new(re, im).sqrt(); | ||||||||||
assert_eq!(sqrt, Complex::new(exact_sqrt_re, exact_sqrt_im)); | ||||||||||
let naive_sqrt = naive_sqrt(Complex::new(re, im)); | ||||||||||
assert_ne!(naive_sqrt, sqrt, "invalid testcase {re} {im}"); | ||||||||||
let roundtrip = sqrt * sqrt; | ||||||||||
let naive_roundtrip = naive_sqrt * naive_sqrt; | ||||||||||
assert!( | ||||||||||
ulp_l1(roundtrip, Complex::new(re, im)) | ||||||||||
<= ulp_l1(naive_roundtrip, Complex::new(re, im)), | ||||||||||
"{} {} {}", | ||||||||||
Complex::new(re, im), | ||||||||||
roundtrip, | ||||||||||
naive_roundtrip | ||||||||||
) | ||||||||||
} | ||||||||||
|
||||||||||
#[track_caller] | ||||||||||
fn check_sqrt_roundtrip(re: f64, im: f64, ulp: usize) { | ||||||||||
let sqrt = Complex::new(re, im).sqrt(); | ||||||||||
let roundtrip = sqrt * sqrt; | ||||||||||
assert!( | ||||||||||
close_to_ulp(roundtrip, Complex::new(re, im), ulp), | ||||||||||
"roundtrip failed for {re} + j{im}: {roundtrip}" | ||||||||||
); | ||||||||||
let naive_sqrt = naive_sqrt(Complex::new(re, im)); | ||||||||||
let naive_roundtrip = naive_sqrt * naive_sqrt; | ||||||||||
assert!( | ||||||||||
!close_to_ulp(naive_roundtrip, Complex::new(re, im), ulp), | ||||||||||
"invalid testcase {re} + j{im} {naive_roundtrip} {roundtrip}" | ||||||||||
); | ||||||||||
} | ||||||||||
|
||||||||||
// some hand-collected testcases that roundtrip perfectly with a | ||||||||||
// sophisticated sqrt implementation but not a naive one .This can | ||||||||||
// look a bit cherry picked (and it is) but during all my cherry | ||||||||||
// picking i didn't find a single case which had worse rounding | ||||||||||
check_sqrt_roundtrip(-1e200, 1e100, 0); | ||||||||||
// with naive implementation there is an error in both re and im part | ||||||||||
// but with the implementation here only on the re part | ||||||||||
check_sqrt_roundtrip(1.0 / 3.0, 1.0 / 3.0, 1); | ||||||||||
check_sqrt_roundtrip(-1.0 / 3.0, 1.0 / 3.0, 1); | ||||||||||
check_sqrt_roundtrip(-0.2, 0.1, 1); | ||||||||||
check_sqrt_roundtrip(-0.45, 0.1, 1); | ||||||||||
check_sqrt_roundtrip(-std::f64::consts::TAU, std::f64::consts::PI, 1); | ||||||||||
// both algorithms don't have the strongest showing here (8 ulp vs 9) but | ||||||||||
// 0.0999999999999999-0.45i instead of 0.10000000000000012-0.45000000000000007i | ||||||||||
// seems much better since the error is only in the re (and not im) | ||||||||||
check_sqrt_roundtrip(0.1, -0.45, 8); | ||||||||||
|
||||||||||
// reference values were computed with numpy but are identical | ||||||||||
// with musl and glibc, showing that we round correctly both | ||||||||||
// in reasonable ranges and extremes cases. All of these tests | ||||||||||
// fail with a naive sqrt implementation based on phase shift (this | ||||||||||
// is checked as part of the tests). | ||||||||||
// | ||||||||||
// The testcases were generated by the following python script: | ||||||||||
// | ||||||||||
// import numpy as np | ||||||||||
// vals = [ | ||||||||||
// (0.1, 0.1), | ||||||||||
// (0.1, 1 / 3), | ||||||||||
// (1 / 3, 0.1), | ||||||||||
// (1 / 3, 1 / 3), | ||||||||||
// (1.1, 1e-100), | ||||||||||
// (1e-100, 0.1), | ||||||||||
// (1e-100, 1.1), | ||||||||||
// (1e-100, 1e-100), | ||||||||||
// ] | ||||||||||
// for re, im in vals: | ||||||||||
// reference = np.sqrt(re + im * 1j) | ||||||||||
// print(f"check_sqrt({re}, {im}, {reference.real}, {reference.imag});") | ||||||||||
|
||||||||||
check_sqrt(0.1, 0.1, 0.34743442276011566, 0.14391204994250742); | ||||||||||
check_sqrt( | ||||||||||
0.1, | ||||||||||
0.3333333333333333, | ||||||||||
0.4732917794361556, | ||||||||||
0.3521435907152684, | ||||||||||
); | ||||||||||
check_sqrt( | ||||||||||
0.3333333333333333, | ||||||||||
0.1, | ||||||||||
0.5836709476652998, | ||||||||||
0.08566470577300687, | ||||||||||
); | ||||||||||
check_sqrt( | ||||||||||
0.3333333333333333, | ||||||||||
0.3333333333333333, | ||||||||||
0.6343255686650054, | ||||||||||
0.26274625350107117, | ||||||||||
); | ||||||||||
check_sqrt(1.1, 1e-100, 1.0488088481701516, 4.767312946227961e-101); | ||||||||||
check_sqrt(1e-100, 1e-100, 1.09868411346781e-50, 4.550898605622274e-51); | ||||||||||
check_sqrt(0.1, -0.45, 0.5296117553758811, -0.4248395125601222); | ||||||||||
} | ||||||||||
|
||||||||||
#[test] | ||||||||||
fn test_cbrt() { | ||||||||||
assert!(close(_0_0i.cbrt(), _0_0i)); | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a source for all these special cases? e.g.
https://en.cppreference.com/w/c/numeric/complex/csqrt
(and make sure all those are covered)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added more test to
test_nan
to make sure all of these are covered by theses and added a comment