Skip to content

Commit

Permalink
Keep calm and carry on
Browse files Browse the repository at this point in the history
See GitHub issue statrs-dev#41.

Don't panic when calculating a pdf, cdf or pmf whenever the input is a
value that the distribution cannot attain. The pdf, pmf and cdf are
still defined at those inputs and we should return the correct values.

This commit also adds a few tests to each distribution to test that
pdf(-inf)=0, pdf(inf)=0, cdf(-inf)=0 and cdf(inf)=1. It also uses simple
numerical integration to test for continuous distributions that the
integral of the pdf is approximately equal to the cdf, and for discrete
distributions that the sum of the pmf is equal to the cdf.
  • Loading branch information
michiel-de-muynck committed Jun 2, 2017
1 parent 7c612a7 commit 0b78ee2
Show file tree
Hide file tree
Showing 23 changed files with 546 additions and 437 deletions.
12 changes: 0 additions & 12 deletions src/distribution/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ impl Univariate<u64, f64> for Bernoulli {
/// Calculates the cumulative distribution
/// function for the bernoulli distribution at `x`.
///
/// # Panics
///
/// If `x < 0.0` or `x > 1.0`
///
/// # Formula
///
/// ```ignore
Expand Down Expand Up @@ -277,10 +273,6 @@ impl Discrete<u64, f64> for Bernoulli {
/// Calculates the probability mass function for the
/// bernoulli distribution at `x`.
///
/// # Panics
///
/// If `x > 1`
///
/// # Formula
///
/// ```ignore
Expand All @@ -294,10 +286,6 @@ impl Discrete<u64, f64> for Bernoulli {
/// Calculates the log probability mass function for the
/// bernoulli distribution at `x`.
///
/// # Panics
///
/// If `x > 1`
///
/// # Formula
///
/// ```ignore
Expand Down
57 changes: 23 additions & 34 deletions src/distribution/beta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,6 @@ impl Univariate<f64, f64> for Beta {
/// Calculates the cumulative distribution function for the beta distribution
/// at `x`
///
/// # Panics
///
/// If `x < 0.0` or `x > 1.0`
///
/// # Formula
///
/// ```ignore
Expand All @@ -149,9 +145,9 @@ impl Univariate<f64, f64> for Beta {
/// where `α` is shapeA, `β` is shapeB, and `I_x` is the regularized
/// lower incomplete beta function
fn cdf(&self, x: f64) -> f64 {
assert!(x >= 0.0 && x <= 1.0,
format!("{}", StatsError::ArgIntervalIncl("x", 0.0, 1.0)));
if x == 1.0 {
if x < 0.0 {
0.0
} else if x >= 1.0 {
1.0
} else if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
if x < 0.5 { 0.0 } else { 1.0 }
Expand Down Expand Up @@ -348,10 +344,6 @@ impl Mode<f64> for Beta {
impl Continuous<f64, f64> for Beta {
/// Calculates the probability density function for the beta distribution at `x`.
///
/// # Panics
///
/// If `x < 0.0` or `x > 1.0`
///
/// # Formula
///
/// ```ignore
Expand All @@ -362,9 +354,9 @@ impl Continuous<f64, f64> for Beta {
///
/// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function
fn pdf(&self, x: f64) -> f64 {
assert!(x >= 0.0 && x <= 1.0,
format!("{}", StatsError::ArgIntervalIncl("x", 0.0, 1.0)));
if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
if x < 0.0 || x > 1.0 {
0.0
} else if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
if x == 0.5 { f64::INFINITY } else { 0.0 }
} else if self.shape_a == f64::INFINITY {
if x == 1.0 { f64::INFINITY } else { 0.0 }
Expand All @@ -383,10 +375,6 @@ impl Continuous<f64, f64> for Beta {

/// Calculates the log probability density function for the beta distribution at `x`.
///
/// # Panics
///
/// If `x < 0.0` or `x > 1.0`
///
/// # Formula
///
/// ```ignore
Expand All @@ -397,9 +385,9 @@ impl Continuous<f64, f64> for Beta {
///
/// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function
fn ln_pdf(&self, x: f64) -> f64 {
assert!(x >= 0.0 && x <= 1.0,
format!("{}", StatsError::ArgIntervalIncl("x", 0.0, 1.0)));
if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
if x < 0.0 || x > 1.0 {
f64::NEG_INFINITY
} else if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
if x == 0.5 {
f64::INFINITY
} else {
Expand Down Expand Up @@ -443,6 +431,7 @@ mod test {
use std::f64;
use statistics::*;
use distribution::{Univariate, Continuous, Beta};
use distribution::internal::*;

fn try_create(shape_a: f64, shape_b: f64) -> Beta {
let n = Beta::new(shape_a, shape_b);
Expand Down Expand Up @@ -612,15 +601,13 @@ mod test {
}

#[test]
#[should_panic]
fn test_pdf_input_lt_zero() {
get_value(1.0, 1.0, |x| x.pdf(-1.0));
test_case(1.0, 1.0, 0.0, |x| x.pdf(-1.0));
}

#[test]
#[should_panic]
fn test_pdf_input_gt_one() {
get_value(1.0, 1.0, |x| x.pdf(2.0));
test_case(1.0, 1.0, 0.0, |x| x.pdf(2.0));
}

#[test]
Expand All @@ -646,15 +633,13 @@ mod test {
}

#[test]
#[should_panic]
fn test_ln_pdf_input_lt_zero() {
get_value(1.0, 1.0, |x| x.ln_pdf(-1.0));
test_case(1.0, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(-1.0));
}

#[test]
#[should_panic]
fn test_ln_pdf_input_gt_one() {
get_value(1.0, 1.0, |x| x.ln_pdf(2.0));
test_case(1.0, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(2.0));
}

#[test]
Expand All @@ -680,14 +665,18 @@ mod test {
}

#[test]
#[should_panic]
fn test_cdf_input_lt_zero() {
get_value(1.0, 1.0, |x| x.cdf(-1.0));
test_case(1.0, 1.0, 0.0, |x| x.cdf(-1.0));
}

#[test]
#[should_panic]
fn test_cdf_input_gt_zero() {
get_value(1.0, 1.0, |x| x.cdf(2.0));
}
test_case(1.0, 1.0, 1.0, |x| x.cdf(2.0));
}

#[test]
fn test_continuous() {
test::check_continuous_distribution(&try_create(1.2, 3.4), 0.0, 1.0);
test::check_continuous_distribution(&try_create(4.5, 6.7), 0.0, 1.0);
}
}
47 changes: 21 additions & 26 deletions src/distribution/binomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,6 @@ impl Univariate<u64, f64> for Binomial {
/// Calulcates the cumulative distribution function for the
/// binomial distribution at `x`
///
/// # Panics
///
/// If `x < 0.0` or `x > n`
///
/// # Formula
///
/// ```ignore
Expand All @@ -147,11 +143,11 @@ impl Univariate<u64, f64> for Binomial {
///
/// where `I_(x)(a, b)` is the regularized incomplete beta function
fn cdf(&self, x: f64) -> f64 {
assert!(x >= 0.0 && x <= self.n as f64,
format!("{}", StatsError::ArgIntervalIncl("x", 0.0, self.n as f64)));
if x == self.n as f64 {
1.0
} else {
if x < 0.0 {
0.0
} else if x >= self.n as f64 {
1.0
} else {
let k = x.floor();
beta::beta_reg(self.n as f64 - k, k + 1.0, 1.0 - self.p)
}
Expand Down Expand Up @@ -293,18 +289,15 @@ impl Discrete<u64, f64> for Binomial {
/// Calculates the probability mass function for the binomial
/// distribution at `x`
///
/// # Panics
///
/// If `x > n`
///
/// # Formula
///
/// ```ignore
/// (n choose k) * p^k * (1 - p)^(n - k)
/// ```
fn pmf(&self, x: u64) -> f64 {
assert!(x <= self.n,
format!("{}", StatsError::ArgLte("x", 1.0)));
if x > self.n {
return 0.0;
}
match self.p {
0.0 if x == 0 => 1.0,
0.0 => 0.0,
Expand All @@ -321,18 +314,15 @@ impl Discrete<u64, f64> for Binomial {
/// Calculates the log probability mass function for the binomial
/// distribution at `x`
///
/// # Panics
///
/// If `x > n`
///
/// # Formula
///
/// ```ignore
/// ln((n choose k) * p^k * (1 - p)^(n - k))
/// ```
fn ln_pmf(&self, x: u64) -> f64 {
assert!(x <= self.n,
format!("{}", StatsError::ArgLte("x", 1.0)));
if x > self.n {
return f64::NEG_INFINITY;
}
match self.p {
0.0 if x == 0 => 0.0,
0.0 => f64::NEG_INFINITY,
Expand All @@ -353,6 +343,7 @@ mod test {
use std::f64;
use statistics::*;
use distribution::{Univariate, Discrete, Binomial};
use distribution::internal::*;

fn try_create(p: f64, n: u64) -> Binomial {
let n = Binomial::new(p, n);
Expand Down Expand Up @@ -548,14 +539,18 @@ mod test {
}

#[test]
#[should_panic]
fn test_cdf_lower_bound() {
get_value(0.5, 3, |x| x.cdf(-1.0));
test_case(0.5, 3, 0.0, |x| x.cdf(-1.0));
}

#[test]
#[should_panic]
fn test_cdf_upper_bound() {
get_value(0.5, 3, |x| x.cdf(5.0));
}
test_case(0.5, 3, 1.0, |x| x.cdf(5.0));
}

#[test]
fn test_discrete() {
test::check_discrete_distribution(&try_create(0.3, 5), 5);
test::check_discrete_distribution(&try_create(0.7, 10), 10);
}
}
48 changes: 22 additions & 26 deletions src/distribution/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,6 @@ impl Univariate<u64, f64> for Categorical {
/// Calculates the cumulative distribution function for the categorical
/// distribution at `x`
///
/// # Panics
///
/// If `x < 0.0` or `x > k` where `k` is the number of categories
/// (i.e. the length of the `prob_mass` slice passed to the constructor)
///
/// # Formula
///
/// ```ignore
Expand All @@ -136,12 +131,11 @@ impl Univariate<u64, f64> for Categorical {
///
/// where `p_j` is the probability mass for the `j`th category
fn cdf(&self, x: f64) -> f64 {
assert!(x >= 0.0 && x <= self.cdf.len() as f64,
format!("{}",
StatsError::ArgIntervalIncl("x", 0.0, self.cdf.len() as f64)));
if x == self.cdf.len() as f64 {
1.0
} else {
if x < 0.0 {
0.0
} else if x >= self.cdf.len() as f64 {
1.0
} else {
unsafe { self.cdf.get_unchecked(x as usize) / self.cdf_max() }
}
}
Expand Down Expand Up @@ -269,19 +263,17 @@ impl Discrete<u64, f64> for Categorical {
/// Calculates the probability mass function for the categorical
/// distribution at `x`
///
/// # Panics
///
/// If `x >= k` where `k` is the number of categories
///
/// # Formula
///
/// ```ignore
/// p_x
/// ```
fn pmf(&self, x: u64) -> f64 {
assert!(x < self.norm_pmf.len() as u64,
format!("{}", StatsError::ArgLtArg("x", "k")));
unsafe { *self.norm_pmf.get_unchecked(x as usize) }
if x >= self.norm_pmf.len() as u64 {
0.0
} else {
unsafe { *self.norm_pmf.get_unchecked(x as usize) }
}
}

/// Calculates the log probability mass function for the categorical
Expand Down Expand Up @@ -373,9 +365,11 @@ fn test_binary_index() {
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(test)]
mod test {
use std::f64;
use std::fmt::Debug;
use statistics::*;
use distribution::{Univariate, Discrete, InverseCDF, Categorical};
use distribution::internal::*;

fn try_create(prob_mass: &[f64]) -> Categorical {
let n = Categorical::new(prob_mass);
Expand Down Expand Up @@ -466,9 +460,8 @@ mod test {
}

#[test]
#[should_panic]
fn test_pmf_x_too_high() {
get_value(&[4.0, 2.5, 2.5, 1.0], |x| x.pmf(4));
test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, |x| x.pmf(4));
}

#[test]
Expand All @@ -479,9 +472,8 @@ mod test {
}

#[test]
#[should_panic]
fn test_ln_pmf_x_too_high() {
get_value(&[4.0, 2.5, 2.5, 1.0], |x| x.ln_pmf(4));
test_case(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, |x| x.ln_pmf(4));
}

#[test]
Expand All @@ -494,15 +486,13 @@ mod test {
}

#[test]
#[should_panic]
fn test_cdf_input_low() {
get_value(&[4.0, 2.5, 2.5, 1.0], |x| x.cdf(-1.0));
test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, |x| x.cdf(-1.0));
}

#[test]
#[should_panic]
fn test_cdf_input_high() {
get_value(&[4.0, 2.5, 2.5, 1.0], |x| x.cdf(4.5));
test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, |x| x.cdf(4.5));
}

#[test]
Expand All @@ -526,4 +516,10 @@ mod test {
fn test_inverse_cdf_input_high() {
get_value(&[4.0, 2.5, 2.5, 1.0], |x| x.inverse_cdf(1.0));
}

#[test]
fn test_discrete() {
test::check_discrete_distribution(&try_create(&[1.0, 2.0, 3.0, 4.0]), 4);
test::check_discrete_distribution(&try_create(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5);
}
}
Loading

0 comments on commit 0b78ee2

Please sign in to comment.