From 60c6069dc4edd91aaab05b46d35a9030dce04330 Mon Sep 17 00:00:00 2001 From: atomflunder <80397293+atomflunder@users.noreply.github.com> Date: Sat, 8 Jun 2024 18:31:45 +0200 Subject: [PATCH] Add some tests --- src/trueskill/factor_graph.rs | 66 +++++++++++++++++++++++++++++++++++ src/trueskill/gaussian.rs | 34 +++++++++++++++++- src/trueskill/matrix.rs | 30 ++++++++++++++++ src/trueskill/mod.rs | 57 ++++++++++++++++++------------ 4 files changed, 164 insertions(+), 23 deletions(-) diff --git a/src/trueskill/factor_graph.rs b/src/trueskill/factor_graph.rs index adc3ca3..3564391 100644 --- a/src/trueskill/factor_graph.rs +++ b/src/trueskill/factor_graph.rs @@ -274,3 +274,69 @@ impl TruncateFactor { .update_value(self.id, Gaussian::with_pi_tau(pi, tau)) } } + +#[cfg(test)] +mod tests { + use std::f64::INFINITY; + + use super::*; + + #[test] + fn test_delta_inf() { + let mut v1 = Variable::new(); + + v1.set(Gaussian::with_pi_tau(INFINITY, 1.0)); + + assert!(v1.delta(Gaussian::with_pi_tau(0.0, 0.0)) < f64::EPSILON); + } + + #[test] + fn test_sum_factor() { + let mut v1 = Variable::new(); + let mut v2 = Variable::new(); + + v1.set(Gaussian::with_pi_tau(INFINITY, 1.0)); + v2.set(Gaussian::with_pi_tau(0.0, 1.0)); + + let mut sm1 = SumFactor::new( + 0, + Rc::new(RefCell::new(v1.clone())), + vec![Rc::new(RefCell::new(v2.clone()))], + vec![0.0], + ); + + sm1.up(0); + + assert_eq!(sm1.id, 0); + assert_eq!(sm1.coeffs, vec![0.0]); + } + + #[test] + #[should_panic(expected = "no entry found for key")] + fn test_no_update() { + let mut v1 = Variable::new(); + let mut v2 = Variable::new(); + + v1.set(Gaussian::with_pi_tau(INFINITY, 1.0)); + v2.set(Gaussian::with_pi_tau(0.0, 1.0)); + + let mut sm1 = SumFactor::new( + 0, + Rc::new(RefCell::new(v1.clone())), + vec![Rc::new(RefCell::new(v2.clone()))], + vec![0.0], + ); + + sm1.up(0); + + assert_eq!(sm1.id, 0); + assert_eq!(sm1.coeffs, vec![0.0]); + + sm1.update( + &Rc::new(RefCell::new(v1)), + &[Rc::new(RefCell::new(v2))], + &[Gaussian::with_pi_tau(0.0, 1.0)], + &[0.0], + ); + } +} diff --git a/src/trueskill/gaussian.rs b/src/trueskill/gaussian.rs index 8bfb7e5..6024d70 100644 --- a/src/trueskill/gaussian.rs +++ b/src/trueskill/gaussian.rs @@ -8,8 +8,9 @@ pub struct Gaussian { } impl Gaussian { + #[allow(clippy::float_cmp)] pub fn with_mu_sigma(mu: f64, sigma: f64) -> Self { - assert_ne!(sigma, 0.0, "sigma^2 needs to be greater than 0"); + assert_ne!(sigma, 0.0, "Sigma cannot be equal to 0.0"); let pi = sigma.powi(-2); Self { pi, tau: pi * mu } @@ -63,3 +64,34 @@ impl PartialOrd for Gaussian { self.mu().partial_cmp(&other.mu()) } } + +#[cfg(test)] +mod tests { + use std::f64::INFINITY; + + use super::*; + + #[test] + fn test_gaussian_edge_cases() { + let g1 = Gaussian::with_pi_tau(0.0, 0.0); + + assert!(g1.sigma() == INFINITY); + assert!(g1.mu() == 0.0); + } + + #[test] + fn test_gaussian_ordering() { + let g1 = Gaussian::with_mu_sigma(1.0, 1.0); + let g2 = Gaussian::with_mu_sigma(2.0, 1.0); + + assert!(g1 < g2); + } + + #[test] + #[should_panic(expected = "Sigma cannot be equal to 0.0")] + fn test_invalid_sigma() { + let g1 = Gaussian::with_mu_sigma(0.0, 0.0); + + g1.mu(); + } +} diff --git a/src/trueskill/matrix.rs b/src/trueskill/matrix.rs index d6ef084..d83858a 100644 --- a/src/trueskill/matrix.rs +++ b/src/trueskill/matrix.rs @@ -236,3 +236,33 @@ impl std::ops::Add for Matrix { matrix } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matrix_panics() { + use std::panic::catch_unwind; + + let result = catch_unwind(|| Matrix::new(2, 3).determinant()); + assert!(result.is_err()); + + let result = catch_unwind(|| Matrix::new(2, 2).inverse()); + assert!(result.is_err()); + + let result = catch_unwind(|| Matrix::new(2, 2) * Matrix::new(3, 3)); + assert!(result.is_err()); + + let result = catch_unwind(|| Matrix::new(3, 2) + Matrix::new(2, 2)); + assert!(result.is_err()); + + let result = catch_unwind(|| Matrix::new(2, 2) + Matrix::new(2, 3)); + assert!(result.is_err()); + } + + #[test] + fn test_misc() { + assert!(!format!("{:?}", Matrix::new(2, 3)).is_empty()); + } +} diff --git a/src/trueskill/mod.rs b/src/trueskill/mod.rs index ca7b5b1..26ad160 100644 --- a/src/trueskill/mod.rs +++ b/src/trueskill/mod.rs @@ -2109,9 +2109,25 @@ mod tests { &Outcomes::WIN, &TrueSkillConfig::new(), ); + let mp = trueskill_multi_team( + &[ + (&[player_one], MultiTeamOutcome::new(1)), + (&[player_two], MultiTeamOutcome::new(2)), + ], + &TrueSkillConfig::new(), + ); assert_eq!(p1, tp1[0]); assert_eq!(p2, tp2[0]); + + // There is a small difference to be found, since the trueskill_multi_team uses the full algorithm, + // while the 1vs1 and Team vs Team implementations use some shortcuts. + // But they still should yield roughly the same result. + assert!((p1.rating - mp[0][0].rating).abs() < 0.001); + assert!((p2.rating - mp[1][0].rating).abs() < 0.001); + + assert!((p1.uncertainty - mp[0][0].uncertainty).abs() < 0.001); + assert!((p2.uncertainty - mp[1][0].uncertainty).abs() < 0.001); } #[test] @@ -2440,26 +2456,6 @@ mod tests { assert!(v3 == NEG_INFINITY); } - #[test] - fn test_matrix_panics() { - use std::panic::catch_unwind; - - let result = catch_unwind(|| Matrix::new(2, 3).determinant()); - assert!(result.is_err()); - - let result = catch_unwind(|| Matrix::new(2, 2).inverse()); - assert!(result.is_err()); - - let result = catch_unwind(|| Matrix::new(2, 2) * Matrix::new(3, 3)); - assert!(result.is_err()); - - let result = catch_unwind(|| Matrix::new(3, 2) + Matrix::new(2, 2)); - assert!(result.is_err()); - - let result = catch_unwind(|| Matrix::new(2, 2) + Matrix::new(2, 3)); - assert!(result.is_err()); - } - #[test] #[allow(clippy::clone_on_copy)] fn test_misc_stuff() { @@ -2472,8 +2468,6 @@ mod tests { assert!(!format!("{player_one:?}").is_empty()); assert!(!format!("{config:?}").is_empty()); - assert!(!format!("{:?}", Matrix::new(2, 3)).is_empty()); - assert_eq!(player_one, TrueSkillRating::from((25.0, 25.0 / 3.0))); } @@ -2600,4 +2594,23 @@ mod tests { assert!((results[2][0].uncertainty - 4.590_018_525_151_38).abs() < f64::EPSILON); assert!((results[2][1].uncertainty - 1.976_314_792_712_798_2).abs() < f64::EPSILON); } + + #[test] + fn test_multi_teams_empty() { + let res = trueskill_multi_team(&[], &TrueSkillConfig::new()); + + assert!(res.is_empty()); + + let res = trueskill_multi_team( + &[ + (&[TrueSkillRating::new()], MultiTeamOutcome::new(1)), + (&[], MultiTeamOutcome::new(2)), + ], + &TrueSkillConfig::new(), + ); + + assert!(res[1].is_empty()); + assert!((res[0][0].rating - 25.0).abs() < f64::EPSILON); + assert!((res[0][0].uncertainty - 25.0 / 3.0).abs() < f64::EPSILON); + } }