Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rand_distr: Fix dirichlet sample method for small alpha. #1209

Merged
merged 9 commits into from
May 1, 2023
131 changes: 120 additions & 11 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

//! The dirichlet distribution.
#![cfg(feature = "alloc")]
use num_traits::Float;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use num_traits::{Float, NumCast};
use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use rand::Rng;
use core::fmt;
use alloc::{boxed::Box, vec, vec::Vec};
Expand Down Expand Up @@ -123,16 +123,56 @@ where
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
let n = self.alpha.len();
let mut samples = vec![F::zero(); n];
let mut sum = F::zero();

for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, F::one()).unwrap();
*s = g.sample(rng);
sum = sum + (*s);
}
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s = (*s)*invacc;
if self.alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) {
// All the values in alpha are less than 0.1.
//
// When all the alpha parameters are sufficiently small, there
// is a nontrivial probability that the samples from the gamma
// distributions used in the other method will all be 0, which
// results in the dirichlet samples being nan. So instead of
// use that method, use the "stick breaking" method based on the
// marginal beta distributions.
//
// Form the right-to-left cumulative sum of alpha, exluding the
// first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
// after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
// will hold [a1+a2+a3, a2+a3, a3].
let mut alpha_sum_rl: Vec<F> = self
.alpha
.iter()
.skip(1)
.rev()
// scan does the cumulative sum
.scan(F::zero(), |sum, x| {
*sum = *sum + *x;
Some(*sum)
})
.collect();
alpha_sum_rl.reverse();
vks marked this conversation as resolved.
Show resolved Hide resolved
let mut acc = F::one();
for ((s, &a), &b) in samples
.iter_mut()
.zip(self.alpha.iter())
.zip(alpha_sum_rl.iter())
{
let beta = Beta::new(a, b).unwrap();
vks marked this conversation as resolved.
Show resolved Hide resolved
let beta_sample = beta.sample(rng);
*s = acc * beta_sample;
acc = acc * (F::one() - beta_sample);
}
samples[n - 1] = acc;
} else {
let mut sum = F::zero();
for (s, &a) in samples.iter_mut().zip(self.alpha.iter()) {
let g = Gamma::new(a, F::one()).unwrap();
vks marked this conversation as resolved.
Show resolved Hide resolved
*s = g.sample(rng);
sum = sum + (*s);
}
let invacc = F::one() / sum;
for s in samples.iter_mut() {
*s = (*s) * invacc;
}
}
samples
}
Expand All @@ -142,6 +182,33 @@ where
mod test {
use super::*;

//
// Check that the means of the components of n samples from
// the Dirichlet distribution agree with the expected means
// with a relative tolerance of rtol.
//
// This is a crude statistical test, but it will catch egregious
// mistakes. It will also also fail if any samples contain nan.
//
vks marked this conversation as resolved.
Show resolved Hide resolved
fn check_dirichlet_means(alpha: &Vec<f64>, n: i32, rtol: f64, seed: u64) {
let d = Dirichlet::new(&alpha).unwrap();
let alpha_len = d.alpha.len();
let mut rng = crate::test::rng(seed);
let mut sums = vec![0.0; alpha_len];
for _ in 0..n {
let samples = d.sample(&mut rng);
for i in 0..alpha_len {
sums[i] += samples[i];
}
}
let sample_mean: Vec<f64> = sums.iter().map(|x| x / n as f64).collect();
let alpha_sum: f64 = d.alpha.iter().sum();
let expected_mean: Vec<f64> = d.alpha.iter().map(|x| x / alpha_sum).collect();
for i in 0..alpha_len {
assert_almost_eq!(sample_mean[i], expected_mean[i], rtol);
}
}

#[test]
fn test_dirichlet() {
let d = Dirichlet::new(&[1.0, 2.0, 3.0]).unwrap();
Expand Down Expand Up @@ -172,6 +239,48 @@ mod test {
.collect();
}

#[test]
fn test_dirichlet_means() {
// Check the means of 20000 samples for several different alphas.
let alpha_set = vec![
vec![0.5, 0.25],
vec![123.0, 75.0],
vec![2.0, 2.5, 5.0, 7.0],
vec![0.1, 8.0, 1.0, 2.0, 2.0, 0.85, 0.05, 12.5],
];
let n = 20000;
let rtol = 2e-2;
let seed = 1317624576693539401;
for alpha in alpha_set {
check_dirichlet_means(&alpha, n, rtol, seed);
}
}

#[test]
fn test_dirichlet_means_very_small_alpha() {
// With values of alpha that are all 0.001, check that the means of the
// components of 10000 samples are within 1% of the expected means.
// With the sampling method based on gamma variates, this test would
// fail, with about 10% of the samples containing nan.
let alpha = vec![0.001, 0.001, 0.001];
let n = 10000;
let rtol = 1e-2;
let seed = 1317624576693539401;
check_dirichlet_means(&alpha, n, rtol, seed);
}

#[test]
fn test_dirichlet_means_small_alpha() {
// With values of alpha that are all less than 0.1, check that the
// means of the components of 150000 samples are within 0.1% of the
// expected means.
let alpha = vec![0.05, 0.025, 0.075, 0.05];
let n = 150000;
let rtol = 1e-3;
let seed = 1317624576693539401;
check_dirichlet_means(&alpha, n, rtol, seed);
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Expand Down