diff --git a/CHANGELOG.md b/CHANGELOG.md index fca9d4807f..c4c99d348b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. - Rename `rand::distributions` to `rand::distr` (#1470) - The `serde1` feature has been renamed `serde` (#1477) - Mark `WeightError`, `PoissonError`, `BinomialError` as `#[non_exhaustive]` (#1480). +- Refactor the inverse `Binomial` algorithm to support values of n > i32::MAX (#1486). ## [0.9.0-alpha.1] - 2024-03-18 - Add the `Slice::num_choices` method to the Slice distribution (#1402) diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 885d8b21c3..58575398c4 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -123,6 +123,7 @@ impl Distribution for Binomial { let result; let q = 1. - p; + let np = (self.n as f64) * p; // For small n * min(p, 1 - p), the BINV algorithm based on the inverse // transformation of the binomial distribution is efficient. Otherwise, @@ -136,19 +137,65 @@ impl Distribution for Binomial { // Ranlib uses 30, and GSL uses 14. const BINV_THRESHOLD: f64 = 10.; + // This threshold is when powi outperforms the .exp() .ln() method. + // However it's constrained by i32::MAX from powi and performs worse above this threshold. + // This value can likely be more finely optimized, but should be done across multiple hardware and in a more controlled setting. + // It's also such an edge case that very few people are likely to benefit from it. + const SMALL_NP_THRESHOLD: f64 = 1e-10; + // Same value as in GSL. // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. 
const BINV_MAX_X: u64 = 110; - if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) { + let mut r: f64; + if self.n == 1 { + // Use the BINV algorithm for special case n = 1 (simplify r calculations). + let s: f64 = p / q; + + result = 'outer: loop { + r = q; + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; + } + r *= (((2 - x) as f64) * s) / (x as f64); + } + break x; + } + } else if np < SMALL_NP_THRESHOLD && self.n <= (i32::MAX as u64) { + // For very small n*p the powi is superior. + // Use the BINV algorithm. + let s: f64 = p / q; + + result = 'outer: loop { + r = q.powi(self.n as i32); + let mut u: f64 = rng.random(); + let mut x = 0; + + while u > r { + u -= r; + x += 1; + if x > BINV_MAX_X { + continue 'outer; + } + r *= (((self.n - x + 1) as f64) * s) / (x as f64); + } + break x; + } + } else if np < BINV_THRESHOLD { + // For everything else r = (q.ln() * (self.n as f64)).exp() is superior. // Use the BINV algorithm. - let s = p / q; - let a = ((self.n + 1) as f64) * s; + let s: f64 = p / q; result = 'outer: loop { - let mut r = q.powi(self.n as i32); + r = (q.ln() * (self.n as f64)).exp(); let mut u: f64 = rng.random(); let mut x = 0; @@ -158,7 +205,7 @@ impl Distribution for Binomial { if x > BINV_MAX_X { continue 'outer; } - r *= a / (x as f64) - s; + r *= (((self.n - x + 1) as f64) * s) / (x as f64); } break x; }