Skip to content

Commit 802ba8f

Browse files
committed
rand_distr: add missing value_stability tests
1 parent a43d7f1 commit 802ba8f

File tree

9 files changed

+225
-16
lines changed

9 files changed

+225
-16
lines changed

rand_distr/src/binomial.rs

+18
Original file line numberDiff line numberDiff line change
@@ -326,4 +326,22 @@ mod test {
326326
fn test_binomial_invalid_lambda_neg() {
327327
Binomial::new(20, -10.0).unwrap();
328328
}
329+
330+
#[test]
331+
fn value_stability() {
332+
fn test_samples(n: u64, p: f64, expected: &[u64]) {
333+
let distr = Binomial::new(n, p).unwrap();
334+
let mut rng = crate::test::rng(353);
335+
let mut buf = [0; 4];
336+
for x in &mut buf {
337+
*x = rng.sample(&distr);
338+
}
339+
assert_eq!(buf, expected);
340+
}
341+
342+
// We have multiple code paths: np < 10, p > 0.5
343+
test_samples(2, 0.7, &[1, 1, 2, 1]);
344+
test_samples(20, 0.3, &[7, 7, 5, 7]);
345+
test_samples(2000, 0.6, &[1194, 1208, 1192, 1210]);
346+
}
329347
}

rand_distr/src/cauchy.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ where Standard: Distribution<N>
7272

7373
#[cfg(test)]
7474
mod test {
75-
use crate::Distribution;
76-
use super::Cauchy;
75+
use super::*;
7776

7877
fn median(mut numbers: &mut [f64]) -> f64 {
7978
sort(&mut numbers);
@@ -117,4 +116,22 @@ mod test {
117116
fn test_cauchy_invalid_scale_neg() {
118117
Cauchy::new(0.0, -10.0).unwrap();
119118
}
119+
120+
#[test]
121+
fn value_stability() {
122+
fn test_samples<N: Float + core::fmt::Debug>(m: N, s: N, expected: &[N])
123+
where Standard: Distribution<N> {
124+
let distr = Cauchy::new(m, s).unwrap();
125+
let mut rng = crate::test::rng(353);
126+
let mut buf = [m; 4];
127+
for x in &mut buf {
128+
*x = rng.sample(&distr);
129+
}
130+
assert_eq!(buf, expected);
131+
}
132+
133+
test_samples(100f64, 10.0, &[77.93369152808678, 90.1606912098641,
134+
125.31516221323625, 86.10217834773925]);
135+
test_samples(20f32, 10.0, &[27.175842, -2.0663052, 11.013268, 10.160688]);
136+
}
120137
}

rand_distr/src/dirichlet.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distributi
107107

108108
#[cfg(test)]
109109
mod test {
110-
use super::Dirichlet;
111-
use crate::Distribution;
110+
use super::*;
112111

113112
#[test]
114113
fn test_dirichlet() {
@@ -151,4 +150,14 @@ mod test {
151150
fn test_dirichlet_invalid_alpha() {
152151
Dirichlet::new_with_size(0.0f64, 2).unwrap();
153152
}
153+
154+
#[test]
155+
fn value_stability() {
156+
let mut rng = crate::test::rng(223);
157+
assert_eq!(rng.sample(Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap()),
158+
vec![0.12941567177708177, 0.4702121891675036, 0.4003721390554146]);
159+
assert_eq!(rng.sample(Dirichlet::new_with_size(8.0, 5).unwrap()),
160+
vec![0.17684200044809556, 0.29915953935953055,
161+
0.1832858056608014, 0.1425623503573967, 0.19815030417417595]);
162+
}
154163
}

rand_distr/src/exponential.rs

+25-2
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ where Exp1: Distribution<N>
121121

122122
#[cfg(test)]
123123
mod test {
124-
use crate::Distribution;
125-
use super::Exp;
124+
use super::*;
126125

127126
#[test]
128127
fn test_exp() {
@@ -142,4 +141,28 @@ mod test {
142141
fn test_exp_invalid_lambda_neg() {
143142
Exp::new(-10.0).unwrap();
144143
}
144+
145+
#[test]
146+
fn value_stability() {
147+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
148+
(distr: D, zero: N, expected: &[N])
149+
{
150+
let mut rng = crate::test::rng(223);
151+
let mut buf = [zero; 4];
152+
for x in &mut buf {
153+
*x = rng.sample(&distr);
154+
}
155+
assert_eq!(buf, expected);
156+
}
157+
158+
test_samples(Exp1, 0f32, &[1.079617, 1.8325565, 0.04601166, 0.34471703]);
159+
test_samples(Exp1, 0f64, &[1.0796170642388276, 1.8325565304274,
160+
0.04601166186842716, 0.3447170217100157]);
161+
162+
test_samples(Exp::new(2.0).unwrap(), 0f32,
163+
&[0.5398085, 0.91627824, 0.02300583, 0.17235851]);
164+
test_samples(Exp::new(1.0).unwrap(), 0f64, &[
165+
1.0796170642388276, 1.8325565304274,
166+
0.04601166186842716, 0.3447170217100157]);
167+
}
145168
}

rand_distr/src/gamma.rs

+57-2
Original file line numberDiff line numberDiff line change
@@ -417,8 +417,7 @@ where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distributi
417417

418418
#[cfg(test)]
419419
mod test {
420-
use crate::Distribution;
421-
use super::{Beta, ChiSquared, StudentT, FisherF};
420+
use super::*;
422421

423422
#[test]
424423
fn test_chi_squared_one() {
@@ -482,4 +481,60 @@ mod test {
482481
fn test_beta_invalid_dof() {
483482
Beta::new(0., 0.).unwrap();
484483
}
484+
485+
#[test]
486+
fn value_stability() {
487+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
488+
(distr: D, zero: N, expected: &[N])
489+
{
490+
let mut rng = crate::test::rng(223);
491+
let mut buf = [zero; 4];
492+
for x in &mut buf {
493+
*x = rng.sample(&distr);
494+
}
495+
assert_eq!(buf, expected);
496+
}
497+
498+
// Gamma has 3 cases: shape == 1, shape < 1, shape > 1
499+
test_samples(Gamma::new(1.0, 5.0).unwrap(), 0f32,
500+
&[5.398085, 9.162783, 0.2300583, 1.7235851]);
501+
test_samples(Gamma::new(0.8, 5.0).unwrap(), 0f32,
502+
&[0.5051203, 0.9048302, 3.095812, 1.8566116]);
503+
test_samples(Gamma::new(1.1, 5.0).unwrap(), 0f64, &[
504+
7.783878094584059, 1.4939528171618057,
505+
8.638017638857592, 3.0949337228829004]);
506+
507+
// ChiSquared has 2 cases: k == 1, k != 1
508+
test_samples(ChiSquared::new(1.0).unwrap(), 0f64, &[
509+
0.4893526200348249, 1.635249736808788,
510+
0.5013580219361969, 0.1457735613733489]);
511+
test_samples(ChiSquared::new(0.1).unwrap(), 0f64, &[
512+
0.014824404726978617, 0.021602123937134326,
513+
0.0000003431429746851693, 0.00000002291755769542258]);
514+
test_samples(ChiSquared::new(10.0).unwrap(), 0f32,
515+
&[12.693656, 6.812016, 11.082001, 12.436167]);
516+
517+
// FisherF has same special cases as ChiSquared on each param
518+
test_samples(FisherF::new(1.0, 13.5).unwrap(), 0f32,
519+
&[0.32283646, 0.048049655, 0.0788893, 1.817178]);
520+
test_samples(FisherF::new(1.0, 1.0).unwrap(), 0f32,
521+
&[0.29925257, 3.4392934, 9.567652, 0.020074]);
522+
test_samples(FisherF::new(0.7, 13.5).unwrap(), 0f64, &[
523+
3.3196593155045124, 0.3409169916262829,
524+
0.03377989856426519, 0.00004041672861036937]);
525+
526+
// StudentT has same special cases as ChiSquared
527+
test_samples(StudentT::new(1.0).unwrap(), 0f32,
528+
&[0.54703987, -1.8545331, 3.093162, -0.14168274]);
529+
test_samples(StudentT::new(1.1).unwrap(), 0f64, &[
530+
0.7729195887949754, 1.2606210611616204,
531+
-1.7553606501113175, -2.377641221169782]);
532+
533+
// Beta has same special cases as Gamma on each param
534+
test_samples(Beta::new(1.0, 0.8).unwrap(), 0f32,
535+
&[0.6444564, 0.357635, 0.4110078, 0.7347192]);
536+
test_samples(Beta::new(0.7, 1.2).unwrap(), 0f64, &[
537+
0.6433129944095513, 0.5373371199711573,
538+
0.10313293199269491, 0.002472280249144378]);
539+
}
485540
}

rand_distr/src/normal.rs

+33-2
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ where StandardNormal: Distribution<N>
185185

186186
#[cfg(test)]
187187
mod tests {
188-
use crate::Distribution;
189-
use super::{Normal, LogNormal};
188+
use super::*;
190189

191190
#[test]
192191
fn test_normal() {
@@ -216,4 +215,36 @@ mod tests {
216215
fn test_log_normal_invalid_sd() {
217216
LogNormal::new(10.0, -1.0).unwrap();
218217
}
218+
219+
#[test]
220+
fn value_stability() {
221+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
222+
(distr: D, zero: N, expected: &[N])
223+
{
224+
let mut rng = crate::test::rng(213);
225+
let mut buf = [zero; 4];
226+
for x in &mut buf {
227+
*x = rng.sample(&distr);
228+
}
229+
assert_eq!(buf, expected);
230+
}
231+
232+
test_samples(StandardNormal, 0f32,
233+
&[-0.11844189, 0.781378, 0.06563994, -1.1932899]);
234+
test_samples(StandardNormal, 0f64, &[
235+
-0.11844188827977231, 0.7813779637772346,
236+
0.06563993969580051, -1.1932899004186373]);
237+
238+
test_samples(Normal::new(0.0, 1.0).unwrap(), 0f32,
239+
&[-0.11844189, 0.781378, 0.06563994, -1.1932899]);
240+
test_samples(Normal::new(2.0, 0.5).unwrap(), 0f64, &[
241+
1.940779055860114, 2.3906889818886174,
242+
2.0328199698479, 1.4033550497906813]);
243+
244+
test_samples(LogNormal::new(0.0, 1.0).unwrap(), 0f32,
245+
&[0.88830346, 2.1844804, 1.0678421, 0.30322206]);
246+
test_samples(LogNormal::new(2.0, 0.5).unwrap(), 0f64, &[
247+
6.964174338639032, 10.921015733601452,
248+
7.6355881556915906, 4.068828213584092]);
249+
}
219250
}

rand_distr/src/pareto.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ where OpenClosed01: Distribution<N>
6666

6767
#[cfg(test)]
6868
mod tests {
69-
use crate::Distribution;
70-
use super::Pareto;
69+
use super::*;
7170

7271
#[test]
7372
#[should_panic]
@@ -86,4 +85,24 @@ mod tests {
8685
assert!(r >= scale);
8786
}
8887
}
88+
89+
#[test]
90+
fn value_stability() {
91+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
92+
(distr: D, zero: N, expected: &[N])
93+
{
94+
let mut rng = crate::test::rng(213);
95+
let mut buf = [zero; 4];
96+
for x in &mut buf {
97+
*x = rng.sample(&distr);
98+
}
99+
assert_eq!(buf, expected);
100+
}
101+
102+
test_samples(Pareto::new(1.0, 1.0).unwrap(), 0f32,
103+
&[1.0423688, 2.1235929, 4.132709, 1.4679428]);
104+
test_samples(Pareto::new(2.0, 0.5).unwrap(), 0f64, &[
105+
9.019295276219136, 4.3097126018270595,
106+
6.837815045397157, 105.8826669383772]);
107+
}
89108
}

rand_distr/src/poisson.rs

+20-2
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ where Standard: Distribution<N>
134134

135135
#[cfg(test)]
136136
mod test {
137-
use crate::Distribution;
138-
use super::Poisson;
137+
use super::*;
139138

140139
#[test]
141140
fn test_poisson_10() {
@@ -230,4 +229,23 @@ mod test {
230229
fn test_poisson_invalid_lambda_neg() {
231230
Poisson::new(-10.0).unwrap();
232231
}
232+
233+
#[test]
234+
fn value_stability() {
235+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
236+
(distr: D, zero: N, expected: &[N])
237+
{
238+
let mut rng = crate::test::rng(223);
239+
let mut buf = [zero; 4];
240+
for x in &mut buf {
241+
*x = rng.sample(&distr);
242+
}
243+
assert_eq!(buf, expected);
244+
}
245+
246+
// Special cases: < 12, >= 12
247+
test_samples(Poisson::new(7.0).unwrap(), 0f32, &[5.0, 11.0, 6.0, 5.0]);
248+
test_samples(Poisson::new(7.0).unwrap(), 0f64, &[9.0, 5.0, 7.0, 6.0]);
249+
test_samples(Poisson::new(27.0).unwrap(), 0f32, &[28.0, 32.0, 36.0, 36.0]);
250+
}
233251
}

rand_distr/src/weibull.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ where OpenClosed01: Distribution<N>
6363

6464
#[cfg(test)]
6565
mod tests {
66-
use crate::Distribution;
67-
use super::Weibull;
66+
use super::*;
6867

6968
#[test]
7069
#[should_panic]
@@ -83,4 +82,24 @@ mod tests {
8382
assert!(r >= 0.);
8483
}
8584
}
85+
86+
#[test]
87+
fn value_stability() {
88+
fn test_samples<N: Float + core::fmt::Debug, D: Distribution<N>>
89+
(distr: D, zero: N, expected: &[N])
90+
{
91+
let mut rng = crate::test::rng(213);
92+
let mut buf = [zero; 4];
93+
for x in &mut buf {
94+
*x = rng.sample(&distr);
95+
}
96+
assert_eq!(buf, expected);
97+
}
98+
99+
test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32,
100+
&[0.041495778, 0.7531094, 1.4189332, 0.38386202]);
101+
test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[
102+
1.1343478702739669, 0.29470010050655226,
103+
0.7556151370284702, 7.877212340241561]);
104+
}
86105
}

0 commit comments

Comments
 (0)