Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ml): Add Random Forest interface and tests #46

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/ml/blocks-ml-class/src/algorithms.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Supervised-learning algorithm implementations, one submodule per algorithm.
pub mod linear_regression;
pub mod logistic_regression;
pub mod decision_tree;
// Ensemble of decision trees; see `random_forest::RandomForest`.
pub mod random_forest;
106 changes: 106 additions & 0 deletions crates/ml/blocks-ml-class/src/algorithms/random_forest.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use crate::algorithms::decision_tree::{DecisionTree, DecisionTreeConfig, DecisionTreeError};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use rand::seq::SliceRandom;
use rayon::prelude::*;
use std::collections::HashMap;
use thiserror::Error;

/// Errors produced when configuring, training, or querying a random forest.
///
/// `Display` messages come from the `thiserror` `#[error]` attributes below.
#[derive(Debug, Error)]
pub enum RandomForestError {
    /// Training was attempted on a dataset with zero rows.
    #[error("Empty training dataset")]
    EmptyTrainingSet,
    /// Prediction was attempted on a dataset with zero rows.
    #[error("Empty test dataset")]
    EmptyTestSet,
    /// Test data has a different number of feature columns than the training data.
    #[error("Feature dimensions mismatch")]
    DimensionMismatch,
    /// The label vector's length does not equal the number of training rows.
    #[error("Labels length mismatch with training data")]
    LabelsMismatch,
    /// The configuration requested zero trees.
    #[error("Invalid number of trees")]
    InvalidTreeCount,
    /// The configured bootstrap ratio is outside `(0.0, 1.0]`.
    #[error("Invalid bootstrap ratio")]
    InvalidBootstrapRatio,
    /// An error bubbled up from an individual decision tree (auto-converted via `#[from]`).
    #[error("Decision tree error: {0}")]
    TreeError(#[from] DecisionTreeError),
}

/// Configuration for Random Forest.
#[derive(Debug, Clone)]
pub struct RandomForestConfig {
    /// Number of trees in the forest; must be at least 1.
    pub n_trees: usize,
    /// Configuration applied to each individual decision tree.
    pub tree_config: DecisionTreeConfig,
    /// Fraction of the training rows sampled (bootstrap) for each tree;
    /// must lie in `(0.0, 1.0]`.
    pub bootstrap_ratio: f64,
    /// Number of parallel threads to use (`None` for all available).
    pub n_jobs: Option<usize>,
}

impl Default for RandomForestConfig {
fn default() -> Self {
Self {
n_trees: 100,
tree_config: DecisionTreeConfig::default(),
bootstrap_ratio: 0.7,
n_jobs: None,
}
}
}

/// Random Forest implementation
/// Random Forest implementation.
///
/// Constructed via [`RandomForest::new`]; `trees` stays empty and the cached
/// statistics stay `None` until `fit` succeeds.
#[derive(Debug)]
pub struct RandomForest {
    // Validated configuration captured at construction time.
    config: RandomForestConfig,
    // Fitted trees; empty until `fit` is called.
    trees: Vec<DecisionTree>,
    // Per-feature importances computed during `fit` (None before fitting).
    feature_importances: Option<Array1<f64>>,
    // Out-of-bag accuracy computed during `fit` (None before fitting).
    oob_score: Option<f64>,
}

impl RandomForest {
    /// Creates a new `RandomForest` instance with the given configuration.
    ///
    /// # Errors
    ///
    /// Returns [`RandomForestError::InvalidTreeCount`] if `n_trees` is zero,
    /// and [`RandomForestError::InvalidBootstrapRatio`] if `bootstrap_ratio`
    /// is not in the half-open interval `(0.0, 1.0]` (NaN is rejected).
    pub fn new(config: RandomForestConfig) -> Result<Self, RandomForestError> {
        if config.n_trees == 0 {
            return Err(RandomForestError::InvalidTreeCount);
        }
        // Written as a negated conjunction so NaN is rejected: NaN fails every
        // comparison, so the original `r <= 0.0 || r > 1.0` form would have
        // silently accepted a NaN ratio.
        if !(config.bootstrap_ratio > 0.0 && config.bootstrap_ratio <= 1.0) {
            return Err(RandomForestError::InvalidBootstrapRatio);
        }

        Ok(Self {
            config,
            trees: Vec::new(),
            feature_importances: None,
            oob_score: None,
        })
    }

    /// Fits the random forest to the training data.
    ///
    /// Not yet implemented: panics via `unimplemented!()`. The tests in this
    /// PR define the intended contract (bootstrap sampling, OOB score,
    /// feature importances, optional parallel training).
    pub fn fit(&mut self, x: ArrayView2<f64>, y: ArrayView1<f64>) -> Result<(), RandomForestError> {
        unimplemented!()
    }

    /// Predicts class labels for new data points.
    ///
    /// Not yet implemented: panics via `unimplemented!()`.
    pub fn predict(&self, x: ArrayView2<f64>) -> Result<Array1<f64>, RandomForestError> {
        unimplemented!()
    }

    /// Predicts class probabilities for new data points.
    ///
    /// Not yet implemented: panics via `unimplemented!()`.
    pub fn predict_proba(&self, x: ArrayView2<f64>) -> Result<Array2<f64>, RandomForestError> {
        unimplemented!()
    }

    /// Returns feature importances if the forest is fitted, `None` otherwise.
    pub fn feature_importances(&self) -> Option<&Array1<f64>> {
        self.feature_importances.as_ref()
    }

    /// Returns the out-of-bag score if available, `None` before fitting.
    pub fn oob_score(&self) -> Option<f64> {
        self.oob_score
    }

    /// Returns the number of fitted trees (0 before `fit`, even though the
    /// configuration may request more).
    pub fn n_trees(&self) -> usize {
        self.trees.len()
    }
}
258 changes: 258 additions & 0 deletions crates/ml/blocks-ml-class/src/algorithms/random_forest_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// Test suite for the RandomForest interface. NOTE(review): this lives in a
// separate `random_forest_test.rs` file; `use super::*` only resolves if the
// file is pulled in via a `mod`/`include!` from random_forest.rs — confirm
// the crate actually includes it.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::{arr1, arr2};

    // A freshly constructed forest has no trees and no cached statistics.
    #[test]
    fn test_random_forest_new() {
        let config = RandomForestConfig::default();
        let forest = RandomForest::new(config).unwrap();
        assert_eq!(forest.n_trees(), 0);
        assert!(forest.feature_importances().is_none());
        assert!(forest.oob_score().is_none());
    }

    // Constructor must reject zero trees and out-of-range bootstrap ratios.
    #[test]
    fn test_random_forest_new_invalid_config() {
        // Test invalid tree count
        let config = RandomForestConfig {
            n_trees: 0,
            ..Default::default()
        };
        assert!(matches!(
            RandomForest::new(config),
            Err(RandomForestError::InvalidTreeCount)
        ));

        // Test invalid bootstrap ratio
        let config = RandomForestConfig {
            bootstrap_ratio: 0.0,
            ..Default::default()
        };
        assert!(matches!(
            RandomForest::new(config),
            Err(RandomForestError::InvalidBootstrapRatio)
        ));

        let config = RandomForestConfig {
            bootstrap_ratio: 1.1,
            ..Default::default()
        };
        assert!(matches!(
            RandomForest::new(config),
            Err(RandomForestError::InvalidBootstrapRatio)
        ));
    }

    // Fitting on a zero-row matrix must fail with EmptyTrainingSet.
    #[test]
    fn test_random_forest_fit_empty_dataset() {
        let mut forest = RandomForest::new(RandomForestConfig::default()).unwrap();
        let x = Array2::<f64>::zeros((0, 2));
        let y = Array1::<f64>::zeros(0);
        assert!(matches!(
            forest.fit(x.view(), y.view()),
            Err(RandomForestError::EmptyTrainingSet)
        ));
    }

    // Label vector shorter than the number of rows must fail with LabelsMismatch.
    #[test]
    fn test_random_forest_fit_labels_mismatch() {
        let mut forest = RandomForest::new(RandomForestConfig::default()).unwrap();
        let x = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let y = arr1(&[1.0]);
        assert!(matches!(
            forest.fit(x.view(), y.view()),
            Err(RandomForestError::LabelsMismatch)
        ));
    }

    // Predicting before any fit must fail (forest has no trees).
    #[test]
    fn test_random_forest_predict_without_fit() {
        let forest = RandomForest::new(RandomForestConfig::default()).unwrap();
        let x = arr2(&[[1.0, 2.0]]);
        assert!(matches!(
            forest.predict(x.view()),
            Err(RandomForestError::EmptyTrainingSet)
        ));
    }

    // Test data with a different column count than training data must fail.
    #[test]
    fn test_random_forest_predict_dimension_mismatch() {
        let mut forest = RandomForest::new(RandomForestConfig::default()).unwrap();
        let x_train = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let y_train = arr1(&[0.0, 1.0]);
        forest.fit(x_train.view(), y_train.view()).unwrap();

        let x_test = arr2(&[[1.0], [2.0]]);
        assert!(matches!(
            forest.predict(x_test.view()),
            Err(RandomForestError::DimensionMismatch)
        ));
    }

    // End-to-end contract on a 2-class problem: label predictions, probability
    // rows summing to 1, normalized feature importances, and an OOB score in [0, 1].
    #[test]
    fn test_random_forest_binary_classification() {
        let mut forest = RandomForest::new(RandomForestConfig {
            n_trees: 10,
            bootstrap_ratio: 0.7,
            ..Default::default()
        })
        .unwrap();

        // Simple binary classification dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [0.1, 0.1],
            [0.9, 0.9],
        ]);
        let y = arr1(&[0.0, 0.0, 1.0, 1.0, 0.0, 1.0]);

        forest.fit(x.view(), y.view()).unwrap();

        // Test predictions
        let predictions = forest.predict(x.view()).unwrap();
        assert_eq!(predictions.len(), x.nrows());
        assert!(predictions.iter().all(|&p| p == 0.0 || p == 1.0));

        // Test probabilities
        let probas = forest.predict_proba(x.view()).unwrap();
        assert_eq!(probas.shape(), &[x.nrows(), 2]);
        assert!(probas.iter().all(|&p| p >= 0.0 && p <= 1.0));

        // Each row should sum to approximately 1
        for row in probas.rows() {
            assert_relative_eq!(row.sum(), 1.0, epsilon = 1e-10);
        }

        // Test feature importances
        let importances = forest.feature_importances().unwrap();
        assert_eq!(importances.len(), 2);
        assert!(importances.iter().all(|&x| x >= 0.0 && x <= 1.0));
        assert_relative_eq!(importances.sum(), 1.0, epsilon = 1e-10);

        // Test OOB score
        assert!(forest.oob_score().is_some());
        assert!(forest.oob_score().unwrap() >= 0.0 && forest.oob_score().unwrap() <= 1.0);
    }

    // Same contract for a 3-class problem; probability matrix widens to 3 columns.
    #[test]
    fn test_random_forest_multiclass() {
        let mut forest = RandomForest::new(RandomForestConfig {
            n_trees: 10,
            bootstrap_ratio: 0.7,
            ..Default::default()
        })
        .unwrap();

        // Multiclass dataset
        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.5, 0.6],
        ]);
        let y = arr1(&[0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        forest.fit(x.view(), y.view()).unwrap();

        // Test predictions
        let predictions = forest.predict(x.view()).unwrap();
        assert_eq!(predictions.len(), x.nrows());
        assert!(predictions.iter().all(|&p| p == 0.0 || p == 1.0 || p == 2.0));

        // Test probabilities
        let probas = forest.predict_proba(x.view()).unwrap();
        assert_eq!(probas.shape(), &[x.nrows(), 3]);
        assert!(probas.iter().all(|&p| p >= 0.0 && p <= 1.0));

        // Each row should sum to approximately 1
        for row in probas.rows() {
            assert_relative_eq!(row.sum(), 1.0, epsilon = 1e-10);
        }
    }

    // Single-threaded and multi-threaded training must both succeed and
    // produce forests with the same number of trees.
    #[test]
    fn test_random_forest_parallel_training() {
        let mut forest_single = RandomForest::new(RandomForestConfig {
            n_trees: 10,
            n_jobs: Some(1),
            ..Default::default()
        })
        .unwrap();

        let mut forest_parallel = RandomForest::new(RandomForestConfig {
            n_trees: 10,
            n_jobs: Some(4),
            ..Default::default()
        })
        .unwrap();

        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
        ]);
        let y = arr1(&[0.0, 0.0, 1.0, 1.0]);

        // Both should train successfully
        forest_single.fit(x.view(), y.view()).unwrap();
        forest_parallel.fit(x.view(), y.view()).unwrap();

        assert_eq!(forest_single.n_trees(), forest_parallel.n_trees());
    }

    // With a sub-1.0 bootstrap ratio, some rows are left out of each tree's
    // sample, so an OOB score must be computable after fitting.
    #[test]
    fn test_random_forest_bootstrap_sampling() {
        let mut forest = RandomForest::new(RandomForestConfig {
            n_trees: 5,
            bootstrap_ratio: 0.5, // Small ratio to ensure different samples
            ..Default::default()
        })
        .unwrap();

        let x = arr2(&[
            [0.0, 0.0],
            [0.0, 1.0],
            [1.0, 0.0],
            [1.0, 1.0],
            [0.5, 0.5],
            [0.5, 0.6],
        ]);
        let y = arr1(&[0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);

        forest.fit(x.view(), y.view()).unwrap();

        // Each tree should have seen different samples
        assert!(forest.oob_score().is_some());
    }

    // The feature that actually separates the classes should be ranked as
    // more important than the uninformative one.
    #[test]
    fn test_random_forest_feature_importance_consistency() {
        let mut forest = RandomForest::new(RandomForestConfig {
            n_trees: 10,
            ..Default::default()
        })
        .unwrap();

        // Dataset where first feature is more important
        let x = arr2(&[
            [0.0, 0.5],
            [0.1, 0.4],
            [0.9, 0.6],
            [1.0, 0.5],
        ]);
        let y = arr1(&[0.0, 0.0, 1.0, 1.0]);

        forest.fit(x.view(), y.view()).unwrap();

        let importances = forest.feature_importances().unwrap();
        assert!(importances[0] > importances[1]); // First feature should be more important
    }
}
6 changes: 6 additions & 0 deletions crates/ml/blocks-ml-class/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Crate root: expose the algorithms module and re-export each algorithm
// submodule so callers can `use blocks_ml_class::random_forest;` directly.
pub mod algorithms;

pub use algorithms::linear_regression;
pub use algorithms::logistic_regression;
pub use algorithms::decision_tree;
pub use algorithms::random_forest;