-
-
Notifications
You must be signed in to change notification settings - Fork 255
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fit trait modification and cross validation proposal (#122)
* change fit signature cross valdation POC * fmt * fix merge issues * concat to from_shape_vec * with labels tests * Move linfa-pls to new Lapack bound (#3) * Move linfa-pls to new Lapack bound * More cleanups * Playing around with `cross_validation` * Make generic over dimension * Run rustfmt * Add simple test for multi target cv * Run rustfmt * Rename cross validation multi target to `cross_validate_multi` * Run rustfmt * docs * update table of contents * fix pls segmentation fault * update contribution guide * snippet Co-authored-by: Lorenz Schmidt <[email protected]>
- Loading branch information
Showing
56 changed files
with
1,234 additions
and
662 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 20 additions & 41 deletions
61
algorithms/linfa-clustering/src/gaussian_mixture/errors.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,37 @@ | ||
use crate::k_means::KMeansError; | ||
use ndarray_linalg::error::LinalgError; | ||
use std::error::Error; | ||
use std::fmt::{self, Display}; | ||
|
||
use thiserror::Error; | ||
pub type Result<T> = std::result::Result<T, GmmError>; | ||
|
||
/// An error when modeling a GMM algorithm | ||
#[derive(Debug)] | ||
#[derive(Error, Debug)] | ||
pub enum GmmError { | ||
/// When any of the hyperparameters are set the wrong value | ||
#[error("Invalid value encountered: {0}")] | ||
InvalidValue(String), | ||
/// Errors encountered during linear algebra operations | ||
LinalgError(LinalgError), | ||
#[error( | ||
"Linalg Error: \ | ||
Fitting the mixture model failed because some components have \ | ||
ill-defined empirical covariance (for instance caused by singleton \ | ||
or collapsed samples). Try to decrease the number of components, \ | ||
or increase reg_covar. Error: {0}" | ||
)] | ||
LinalgError(#[from] LinalgError), | ||
/// When a cluster has no more data point while fitting GMM | ||
#[error("Fitting failed: {0}")] | ||
EmptyCluster(String), | ||
/// When lower bound computation fails | ||
#[error("Fitting failed: {0}")] | ||
LowerBoundError(String), | ||
/// When fitting EM algorithm does not converge | ||
#[error("Fitting failed: {0}")] | ||
NotConverged(String), | ||
/// When initial KMeans fails | ||
KMeansError(String), | ||
} | ||
|
||
impl Display for GmmError { | ||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
match self { | ||
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message), | ||
Self::LinalgError(error) => write!( | ||
f, | ||
"Linalg Error: \ | ||
Fitting the mixture model failed because some components have \ | ||
ill-defined empirical covariance (for instance caused by singleton \ | ||
or collapsed samples). Try to decrease the number of components, \ | ||
or increase reg_covar. Error: {}", | ||
error | ||
), | ||
Self::EmptyCluster(message) => write!(f, "Fitting failed: {}", message), | ||
Self::LowerBoundError(message) => write!(f, "Fitting failed: {}", message), | ||
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message), | ||
Self::KMeansError(message) => write!(f, "Initial KMeans failed: {}", message), | ||
} | ||
} | ||
} | ||
|
||
impl Error for GmmError {} | ||
|
||
impl From<LinalgError> for GmmError { | ||
fn from(error: LinalgError) -> GmmError { | ||
GmmError::LinalgError(error) | ||
} | ||
} | ||
|
||
impl From<KMeansError> for GmmError { | ||
fn from(error: KMeansError) -> GmmError { | ||
GmmError::KMeansError(error.to_string()) | ||
} | ||
#[error("Initial KMeans failed: {0}")] | ||
KMeansError(#[from] KMeansError), | ||
#[error(transparent)] | ||
LinfaError(#[from] linfa::error::Error), | ||
#[error(transparent)] | ||
MinMaxError(#[from] ndarray_stats::errors::MinMaxError), | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,19 @@ | ||
use std::error::Error; | ||
use std::fmt::{self, Display}; | ||
use thiserror::Error; | ||
|
||
pub type Result<T> = std::result::Result<T, KMeansError>; | ||
|
||
/// An error when modeling a KMeans algorithm | ||
#[derive(Debug)] | ||
#[derive(Error, Debug)] | ||
pub enum KMeansError { | ||
/// When any of the hyperparameters are set the wrong value | ||
#[error("Invalid value encountered: {0}")] | ||
InvalidValue(String), | ||
/// When inertia computation fails | ||
#[error("Fitting failed: {0}")] | ||
InertiaError(String), | ||
/// When fitting algorithm does not converge | ||
#[error("Fitting failed: {0}")] | ||
NotConverged(String), | ||
#[error(transparent)] | ||
LinfaError(#[from] linfa::error::Error), | ||
} | ||
|
||
impl Display for KMeansError { | ||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
match self { | ||
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message), | ||
Self::InertiaError(message) => write!(f, "Fitting failed: {}", message), | ||
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message), | ||
} | ||
} | ||
} | ||
|
||
impl Error for KMeansError {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
use linfa::prelude::*; | ||
use linfa_elasticnet::{ElasticNet, Result}; | ||
|
||
fn main() -> Result<()> { | ||
// load Diabetes dataset (mutable to allow fast k-folding) | ||
let mut dataset = linfa_datasets::diabetes(); | ||
|
||
// parameters to compare | ||
let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0]; | ||
|
||
// create a model for each parameter | ||
let models = ratios | ||
.iter() | ||
.map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio)) | ||
.collect::<Vec<_>>(); | ||
|
||
// get the mean r2 validation score across all folds for each model | ||
let r2_values = | ||
dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?; | ||
|
||
for (ratio, r2) in ratios.iter().zip(r2_values.iter()) { | ||
println!("L1 ratio: {}, r2 score: {}", ratio, r2); | ||
} | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.