Skip to content

Commit

Permalink
rand_distr: Fix dirichlet sample method for small alpha.
Browse files Browse the repository at this point in the history
Generating Dirichlet samples using the method based on samples from
the gamma distribution can result in samples being nan if all the
values in alpha are sufficiently small.  The fix is to instead use
the method based on the marginal distributions being the beta
distribution (i.e. the "stick breaking" method) when all values in
alpha are small.
  • Loading branch information
WarrenWeckesser committed Jan 2, 2022
1 parent 19404d6 commit 745ace8
Showing 1 changed file with 120 additions and 11 deletions.
131 changes: 120 additions & 11 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

//! The dirichlet distribution.
#![cfg(feature = "alloc")]
use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use num_traits::{Float, NumCast};
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use core::fmt;
use alloc::{boxed::Box, vec, vec::Vec};
Expand Down Expand Up @@ -123,16 +123,56 @@ where
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![F::zero(); n];
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum = sum + (*s);
}
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s = (*s)*invacc;
if self.alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) {
// All the values in alpha are less than or equal to 0.1.
//
// When all the alpha parameters are sufficiently small, there
// is a nontrivial probability that the samples from the gamma
// distributions used in the other method will all be 0, which
// results in the dirichlet samples being nan. So instead of
// using that method, use the "stick breaking" method based on the
// marginal beta distributions.
//
// Form the right-to-left cumulative sum of alpha, excluding the
// first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
// after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
// will hold [a1+a2+a3, a2+a3, a3].
let mut alpha_sum_rl: Vec<F> = self
.alpha
.iter()
.skip(1)
.rev()
// scan does the cumulative sum
.scan(F::zero(), |sum, x| {
*sum = *sum + *x;
Some(*sum)
})
.collect();
alpha_sum_rl.reverse();
let mut acc = F::one();
for ((s, &a), &b) in samples
.iter_mut()
.zip(self.alpha.iter())
.zip(alpha_sum_rl.iter())
{
let beta = Beta::new(a, b).unwrap();
let beta_sample = beta.sample(rng);
*s = acc * beta_sample;
acc = acc * (F::one() - beta_sample);
}
samples[n - 1] = acc;
} else {
let mut sum = F::zero();
for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum = sum + (*s);
}
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s = (*s) * invacc;
}
}
samples
}
Expand All @@ -142,6 +182,33 @@ where
mod test {
use super::*;

//
// Check that the per-component means of `n` samples drawn from the
// Dirichlet distribution with parameters `alpha` agree with the expected
// means `alpha[i] / sum(alpha)`, using the tolerance `rtol`. The random
// number generator is seeded with `seed`.
//
// This is a crude statistical test, but it will catch egregious
// mistakes. It will also fail if any samples contain nan.
//
// NOTE(review): `rtol` is handed straight to `assert_almost_eq!`; whether
// it is applied as a relative or an absolute tolerance depends on that
// macro -- confirm.
//
fn check_dirichlet_means(alpha: &[f64], n: i32, rtol: f64, seed: u64) {
    let d = Dirichlet::new(alpha).unwrap();
    let mut rng = crate::test::rng(seed);

    // Accumulate the component-wise sum over all of the samples.
    let mut sums = vec![0.0; d.alpha.len()];
    for _ in 0..n {
        let samples = d.sample(&mut rng);
        for (sum, x) in sums.iter_mut().zip(samples.iter()) {
            *sum += *x;
        }
    }

    // `i32` converts to `f64` without loss, so no `as` cast is needed.
    let count = f64::from(n);
    let alpha_sum: f64 = d.alpha.iter().sum();
    let sample_means = sums.iter().map(|s| s / count);
    let expected_means = d.alpha.iter().map(|a| a / alpha_sum);
    for (sample_mean, expected_mean) in sample_means.zip(expected_means) {
        assert_almost_eq!(sample_mean, expected_mean, rtol);
    }
}

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
Expand Down Expand Up @@ -172,6 +239,48 @@ mod test {
.collect();
}

#[test]
fn test_dirichlet_means() {
    // For each of several alpha vectors, compare the component means of
    // 20000 samples against the expected means.
    const DRAWS: i32 = 20000;
    const TOLERANCE: f64 = 2e-2;
    const SEED: u64 = 1317624576693539401;

    let cases: Vec<Vec<f64>> = vec![
        vec![0.5, 0.25],
        vec![123.0, 75.0],
        vec![2.0, 2.5, 5.0, 7.0],
        vec![0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5],
    ];
    cases
        .iter()
        .for_each(|alpha| check_dirichlet_means(alpha, DRAWS, TOLERANCE, SEED));
}

#[test]
fn test_dirichlet_means_very_small_alpha() {
    // Every component of alpha is 0.001. The component means of 10000
    // samples are checked against the expected means with a tolerance of
    // 1e-2. The sampling method based on gamma variates would fail this
    // test, with about 10% of the samples containing nan.
    const DRAWS: i32 = 10000;
    const TOLERANCE: f64 = 1e-2;
    const SEED: u64 = 1317624576693539401;

    let tiny_alpha: Vec<f64> = vec![0.001; 3];
    check_dirichlet_means(&tiny_alpha, DRAWS, TOLERANCE, SEED);
}

#[test]
fn test_dirichlet_means_small_alpha() {
    // All components of alpha are below 0.1. The component means of
    // 150000 samples are checked against the expected means with a
    // tolerance of 1e-3.
    const DRAWS: i32 = 150000;
    const TOLERANCE: f64 = 1e-3;
    const SEED: u64 = 1317624576693539401;

    let small_alpha: Vec<f64> = vec![0.05, 0.025, 0.075, 0.05];
    check_dirichlet_means(&small_alpha, DRAWS, TOLERANCE, SEED);
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Expand Down

0 comments on commit 745ace8

Please sign in to comment.