Skip to content

more accurate sqrt function #129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 235 additions & 30 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
///
/// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
#[inline]
pub fn sqrt(self) -> Self {
if self.im.is_zero() {
if self.re.is_sign_positive() {
// simple positive real √r, and copy `im` for its sign
Self::new(self.re.sqrt(), self.im)
pub fn sqrt(mut self) -> Self {
// complex sqrt algorithm based on the algorithm from
// dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
// to increase accuracy. Compared to a naive implementationt that
// reuses the complex exp/ln implementations this algorithm has better
// accuarcy since both (real) sqrt and (real) hypot are garunteed to
// round perfectly. It's also faster since this implementation requires
// less transcendental functions and those it does use (sqrt/hypto) are
// faster comparted to exp/sin/cos.
//
// The musl libc implementation was referenced while implementing the
// algorithm here:
// https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c

// TODO: rounding for very tiny subnormal numbers isn't perfect yet so
// the assert shown fails in the very worst case this leads to about
// 10% accuracy loss (see example below). As the magnitude increase the
// error quickly drops to basically zero.
//
// glibc handles that (but other implementations like musl and numpy do
// not) by upscaling very small values. That upscaling (and particularly
// it's reversal) are weird and hard to understand (and rely on mantissa
// bit size which we can't get out of the trait). In general the glibc
// implementation is ever so subtley different and I wouldn't want to
// introduce bugs by trying to adapt the underflow handling.
//
// assert_eq!(
// Complex64::new(5.212e-324, 5.212e-324).sqrt(),
// Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
// );

// specical cases for correct nan/inf handling
// see https://en.cppreference.com/w/c/numeric/complex/csqrt

if self.re.is_zero() && self.im.is_zero() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a source for all these special cases? e.g.
https://en.cppreference.com/w/c/numeric/complex/csqrt
(and make sure all those are covered)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more test to test_nan to make sure all of these are covered by theses and added a comment

// 0 +/- 0 i
return Self::new(T::zero(), self.im);
}
if self.im.is_infinite() {
// inf +/- inf i
return Self::new(T::infinity(), self.im);
}
if self.re.is_nan() {
// nan + nan i
return Self::new(self.re, T::nan());
}
if self.re.is_infinite() {
// √(inf +/- NaN i) = inf +/- NaN i
// √(inf +/- x i) = inf +/- 0 i
// √(-inf +/- NaN i) = NaN +/- inf i
// √(-inf +/- x i) = 0 +/- inf i

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a variable to make this clearer:

Suggested change
#[allow(clippy::eq_op)]
let zero_or_nan = self.im - self.im;

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that is indeed more readable, I also added a comments. good point

// if im is inf (or nan) this is nan, otherwise it's zero
#[allow(clippy::eq_op)]
let zero_or_nan = self.im - self.im;
if self.re.is_sign_negative() {
return Self::new(zero_or_nan.abs(), self.re.copysign(self.im));
} else {
// √(r e^(iπ)) = √r e^(iπ/2) = i√r
// √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
let re = T::zero();
let im = (-self.re).sqrt();
if self.im.is_sign_positive() {
Self::new(re, im)
} else {
Self::new(re, -im)
}
}
} else if self.re.is_zero() {
// √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
// √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
let one = T::one();
let two = one + one;
let x = (self.im.abs() / two).sqrt();
if self.im.is_sign_positive() {
Self::new(x, x)
} else {
Self::new(x, -x)
return Self::new(self.re, zero_or_nan.copysign(self.im));
}
}
let two = T::one() + T::one();
let four = two + two;
let overflow = T::max_value() / (T::one() + T::sqrt(two));
let max_magnitude = self.re.abs().max(self.im.abs());
let scale = max_magnitude >= overflow;
if scale {
self = self / four;
}
if self.re.is_sign_negative() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also use a citation and link in a comment for the algorithm you mentioned.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a citation to the algorithm and the musl libc implementation as well as provide some additional background in a comement

let tmp = ((-self.re + self.norm()) / two).sqrt();
self.re = self.im.abs() / (two * tmp);
self.im = tmp.copysign(self.im);
} else {
// formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
let one = T::one();
let two = one + one;
let (r, theta) = self.to_polar();
Self::from_polar(r.sqrt(), theta / two)
self.re = ((self.re + self.norm()) / two).sqrt();
self.im = self.im / (two * self.re);
}
if scale {
self = self * two;
}
self
}

/// Computes the principal value of the cube root of `self`.
Expand Down Expand Up @@ -2065,6 +2112,164 @@ pub(crate) mod test {
}
}

#[test]
fn test_sqrt_nan() {
assert!(close_naninf(
Complex64::new(f64::INFINITY, f64::NAN).sqrt(),
Complex64::new(f64::INFINITY, f64::NAN),
));
assert!(close_naninf(
Complex64::new(f64::NAN, f64::INFINITY).sqrt(),
Complex64::new(f64::INFINITY, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, -f64::NAN).sqrt(),
Complex64::new(f64::NAN, f64::NEG_INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, f64::NAN).sqrt(),
Complex64::new(f64::NAN, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(-0.0, 0.0).sqrt(),
Complex64::new(0.0, 0.0),
));
for x in (-100..100).map(f64::from) {
assert!(close_naninf(
Complex64::new(x, f64::INFINITY).sqrt(),
Complex64::new(f64::INFINITY, f64::INFINITY),
));
assert!(close_naninf(
Complex64::new(f64::NAN, x).sqrt(),
Complex64::new(f64::NAN, f64::NAN),
));
// √(inf + x i) = inf + 0 i
assert!(close_naninf(
Complex64::new(f64::INFINITY, x).sqrt(),
Complex64::new(f64::INFINITY, 0.0.copysign(x)),
));
// √(-inf + x i) = 0 + inf i
assert!(close_naninf(
Complex64::new(f64::NEG_INFINITY, x).sqrt(),
Complex64::new(0.0, f64::INFINITY.copysign(x)),
));
}
}


fn test_sqrt_rounding() {
fn naive_sqrt(c: Complex64) -> Complex64 {
let (r, theta) = c.to_polar();
Complex64::from_polar(r.sqrt(), theta / 2.0)
}

fn ulp_l1(a: Complex64, b: Complex64) -> u64 {
let re_ulp = a.re.to_bits().abs_diff(b.re.to_bits());
let im_ulp = a.im.to_bits().abs_diff(b.im.to_bits());
re_ulp + im_ulp
}
fn close_to_ulp(a: Complex64, b: Complex64, ulp: usize) -> bool {
ulp_l1(a, b) <= ulp as u64
}

#[track_caller]
fn check_sqrt(re: f64, im: f64, exact_sqrt_re: f64, exact_sqrt_im: f64) {
let sqrt = Complex::new(re, im).sqrt();
assert_eq!(sqrt, Complex::new(exact_sqrt_re, exact_sqrt_im));
let naive_sqrt = naive_sqrt(Complex::new(re, im));
assert_ne!(naive_sqrt, sqrt, "invalid testcase {re} {im}");
let roundtrip = sqrt * sqrt;
let naive_roundtrip = naive_sqrt * naive_sqrt;
assert!(
ulp_l1(roundtrip, Complex::new(re, im))
<= ulp_l1(naive_roundtrip, Complex::new(re, im)),
"{} {} {}",
Complex::new(re, im),
roundtrip,
naive_roundtrip
)
}

#[track_caller]
fn check_sqrt_roundtrip(re: f64, im: f64, ulp: usize) {
let sqrt = Complex::new(re, im).sqrt();
let roundtrip = sqrt * sqrt;
assert!(
close_to_ulp(roundtrip, Complex::new(re, im), ulp),
"roundtrip failed for {re} + j{im}: {roundtrip}"
);
let naive_sqrt = naive_sqrt(Complex::new(re, im));
let naive_roundtrip = naive_sqrt * naive_sqrt;
assert!(
!close_to_ulp(naive_roundtrip, Complex::new(re, im), ulp),
"invalid testcase {re} + j{im} {naive_roundtrip} {roundtrip}"
);
}

// some hand-collected testcases that roundtrip perfectly with a
// sophisticated sqrt implementation but not a naive one .This can
// look a bit cherry picked (and it is) but during all my cherry
// picking i didn't find a single case which had worse rounding
check_sqrt_roundtrip(-1e200, 1e100, 0);
// with naive implementation there is an error in both re and im part
// but with the implementation here only on the re part
check_sqrt_roundtrip(1.0 / 3.0, 1.0 / 3.0, 1);
check_sqrt_roundtrip(-1.0 / 3.0, 1.0 / 3.0, 1);
check_sqrt_roundtrip(-0.2, 0.1, 1);
check_sqrt_roundtrip(-0.45, 0.1, 1);
check_sqrt_roundtrip(-std::f64::consts::TAU, std::f64::consts::PI, 1);
// both algorithms don't have the strongest showing here (8 ulp vs 9) but
// 0.0999999999999999-0.45i instead of 0.10000000000000012-0.45000000000000007i
// seems much better since the error is only in the re (and not im)
check_sqrt_roundtrip(0.1, -0.45, 8);

// reference values were computed with numpy but are identical
// with musl and glibc, showing that we round correctly both
// in reasonable ranges and extremes cases. All of these tests
// fail with a naive sqrt implementation based on phase shift (this
// is checked as part of the tests).
//
// The testcases were generated by the following python script:
//
// import numpy as np
// vals = [
// (0.1, 0.1),
// (0.1, 1 / 3),
// (1 / 3, 0.1),
// (1 / 3, 1 / 3),
// (1.1, 1e-100),
// (1e-100, 0.1),
// (1e-100, 1.1),
// (1e-100, 1e-100),
// ]
// for re, im in vals:
// reference = np.sqrt(re + im * 1j)
// print(f"check_sqrt({re}, {im}, {reference.real}, {reference.imag});")

check_sqrt(0.1, 0.1, 0.34743442276011566, 0.14391204994250742);
check_sqrt(
0.1,
0.3333333333333333,
0.4732917794361556,
0.3521435907152684,
);
check_sqrt(
0.3333333333333333,
0.1,
0.5836709476652998,
0.08566470577300687,
);
check_sqrt(
0.3333333333333333,
0.3333333333333333,
0.6343255686650054,
0.26274625350107117,
);
check_sqrt(1.1, 1e-100, 1.0488088481701516, 4.767312946227961e-101);
check_sqrt(1e-100, 1e-100, 1.09868411346781e-50, 4.550898605622274e-51);
check_sqrt(0.1, -0.45, 0.5296117553758811, -0.4248395125601222);
}

#[test]
fn test_cbrt() {
assert!(close(_0_0i.cbrt(), _0_0i));
Expand Down