Fit trait modification and cross validation proposal (#122)
* change fit signature
cross validation POC

* fmt

* fix merge issues

* concat to from_shape_vec

* with labels tests

* Move linfa-pls to new Lapack bound (#3)

* Move linfa-pls to new Lapack bound

* More cleanups

* Playing around with `cross_validation`

* Make generic over dimension

* Run rustfmt

* Add simple test for multi target cv

* Run rustfmt

* Rename cross validation multi target to `cross_validate_multi`

* Run rustfmt

* docs

* update table of contents

* fix pls segmentation fault

* update contribution guide

* snippet

Co-authored-by: Lorenz Schmidt <[email protected]>
Sauro98 and bytesnake authored Apr 28, 2021
1 parent 6866450 commit a5a479f
Showing 56 changed files with 1,234 additions and 662 deletions.
29 changes: 5 additions & 24 deletions CONTRIBUTE.md
@@ -6,38 +6,19 @@ This document should be used as a reference when contributing to Linfa. It descr

An important part of the Linfa ecosystem is how to organize data for the training and estimation process. A [Dataset](src/dataset/mod.rs) serves this purpose. It is a small wrapper of data and targets types and should be used as argument for the [Fit](src/traits.rs) trait. Its parametrization is generic, with [Records](src/dataset/mod.rs) representing input data (atm only implemented for `ndarray::ArrayBase`) and [Targets](src/dataset/mod.rs) for targets.

You can find traits for different classes of algorithms [here](src/traits.rs). For example, to implement a fittable algorithm, which takes an `Array2` as input data and boolean array as targets:
You can find traits for different classes of algorithms [here](src/traits.rs). For example, to implement a fittable algorithm, which takes an `Array2` as input data and boolean array as targets and could fail with an `Error` struct:
```rust
impl<'a, F: Float> Fit<'a, Array2<F>, Array1<bool>> for SvmParams<F, Pr> {
impl<F: Float> Fit<Array2<F>, Array1<bool>, Error> for SvmParams<F, Pr> {
type Object = Svm<F, Pr>;

fn fit(&self, dataset: &Dataset<Array2<F>, Array1<bool>>) -> Self::Object {
fn fit(&self, dataset: &Dataset<Array2<F>, Array1<bool>>) -> Result<Self::Object, Error> {
...
}
}
```
the type of the dataset is `&Dataset<Kernel<F>, Array1<bool>>`, and lifetime `'a` is the required lifetime for the fitted state. It produces a fitted state, called `Svm<F, Pr>` with probability type `Pr`.
where the type of the input dataset is `&Dataset<Kernel<F>, Array1<bool>>`. It produces a result with a fitted state, called `Svm<F, Pr>` with probability type `Pr`, or an error of type `Error` in case of failure.
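
The changed signature can be illustrated without pulling in linfa at all. The following is a minimal sketch, not linfa's actual trait: `ToyFit`, `ToyDataset`, `ToyError`, `ThresholdParams` and `Threshold` are hypothetical names made up for this example. It only mirrors the shape shown in the snippet above, where the error type is a parameter of the trait, `Object` is the bare fitted model, and `fit` wraps it in a `Result`.

```rust
use std::fmt;

// Hypothetical stand-in for a dataset: records plus boolean targets.
pub struct ToyDataset {
    pub records: Vec<f64>,
    pub targets: Vec<bool>,
}

// Simplified stand-in for the trait shape described above (an assumption,
// not linfa's definition): the error type `E` is a trait parameter.
pub trait ToyFit<E> {
    type Object;

    fn fit(&self, dataset: &ToyDataset) -> Result<Self::Object, E>;
}

// Hypothetical error type for the toy algorithm.
#[derive(Debug, PartialEq)]
pub enum ToyError {
    Empty,
    LengthMismatch,
}

impl fmt::Display for ToyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Empty => write!(f, "dataset is empty"),
            Self::LengthMismatch => write!(f, "records and targets differ in length"),
        }
    }
}

impl std::error::Error for ToyError {}

// Hyperparameters and fitted state of the toy algorithm.
pub struct ThresholdParams;

#[derive(Debug, PartialEq)]
pub struct Threshold {
    pub cut: f64,
}

impl ToyFit<ToyError> for ThresholdParams {
    // The associated type is the fitted model itself, not a `Result`.
    type Object = Threshold;

    fn fit(&self, dataset: &ToyDataset) -> Result<Self::Object, ToyError> {
        if dataset.records.is_empty() {
            return Err(ToyError::Empty);
        }
        if dataset.records.len() != dataset.targets.len() {
            return Err(ToyError::LengthMismatch);
        }
        // Place the cut at the mean of the records.
        let mean = dataset.records.iter().sum::<f64>() / dataset.records.len() as f64;
        Ok(Threshold { cut: mean })
    }
}
```

Because every implementor now returns a `Result`, generic code can call `fit` and propagate the failure with `?` instead of handling a different `Object` type per algorithm.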

The [Predict](src/traits.rs) should be implemented with dataset arguments, as well as arrays. If a dataset is provided, then predict takes its ownership and returns a new dataset with predicted targets. For an array, predict takes a reference and returns predicted targets. In the same context, SVM implemented predict like this:
```rust
impl<F: Float, T: Targets> Predict<Dataset<Array2<F>, T>, Dataset<Array2<F>, Vec<Pr>>>
for Svm<F, Pr>
{
fn predict(&self, data: Dataset<Array2<F>, T>) -> Dataset<Array2<F>, Vec<Pr>> {
...
}
}
```
and
```rust
impl<F: Float, D: Data<Elem = F>> Predict<ArrayBase<D, Ix2>, Vec<Pr>> for Svm<F, Pr> {
fn predict(&self, data: ArrayBase<D, Ix2>) -> Vec<Pr> {
...
}
}
```

For an example of a `Transformer` please look into the [linfa-kernel](linfa-kernel/src/lib.rs) implementation.
The [Predict](src/traits.rs) trait has its own section later in this document, while for an example of a `Transformer` please look into the [linfa-kernel](linfa-kernel/src/lib.rs) implementation.

## Parameters and builder

3 changes: 1 addition & 2 deletions Cargo.toml
@@ -54,14 +54,13 @@ features = ["cblas"]
default-features = false

[dependencies.openblas-src]
version = "0.9.0"
version = "0.10.4"
optional = true
default-features = false
features = ["cblas"]

[dev-dependencies]
ndarray-rand = "0.13"

linfa-datasets = { path = "datasets", features = ["winequality", "iris", "diabetes"] }

[workspace]
1 change: 1 addition & 0 deletions README.md
@@ -43,6 +43,7 @@ Where does `linfa` stand right now? [Are we learning yet?](http://www.arewelearn
| [ica](algorithms/linfa-ica/) | Independent component analysis | Tested | Unsupervised learning | Contains FastICA implementation |
| [pls](algorithms/linfa-pls/) | Partial Least Squares | Tested | Supervised learning | Contains PLS estimators for dimensionality reduction and regression |
| [tsne](algorithms/linfa-tsne/) | Dimensionality reduction| Tested | Unsupervised learning | Contains exact solution and Barnes-Hut approximation t-SNE |
| [preprocessing](algorithms/linfa-preprocessing/) |Normalization & Vectorization| Tested | Pre-processing | Contains data normalization/whitening and count vectorization/tf-idf |

We believe that only a significant community effort can nurture, build, and sustain a machine learning ecosystem in Rust - there is no other way forward.

10 changes: 5 additions & 5 deletions algorithms/linfa-bayes/src/gaussian_nb.rs
@@ -2,7 +2,7 @@ use ndarray::{s, Array1, Array2, ArrayBase, ArrayView1, ArrayView2, Axis, Data,
use ndarray_stats::QuantileExt;
use std::collections::HashMap;

use crate::error::Result;
use crate::error::{BayesError, Result};
use linfa::dataset::{AsTargets, DatasetBase, Labels};
use linfa::traits::{Fit, IncrementalFit, PredictRef};
use linfa::Float;
@@ -40,13 +40,13 @@ impl GaussianNbParams {
}
}

impl<F, D, L> Fit<'_, ArrayBase<D, Ix2>, L> for GaussianNbParams
impl<F, D, L> Fit<ArrayBase<D, Ix2>, L, BayesError> for GaussianNbParams
where
F: Float,
D: Data<Elem = F>,
L: AsTargets<Elem = usize> + Labels<Elem = usize>,
{
type Object = Result<GaussianNb<F>>;
type Object = GaussianNb<F>;

/// Fit the model
///
@@ -77,7 +77,7 @@ where
/// # Ok(())
/// # }
/// ```
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, L>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, L>) -> Result<Self::Object> {
// We extract the unique classes in sorted order
let mut unique_classes = dataset.targets.labels();
unique_classes.sort_unstable();
@@ -303,7 +303,7 @@ where
///
/// __Panics__ if the input is empty or if pairwise orderings are undefined
/// (this occurs in presence of NaN values)
fn predict_ref<'a>(&'a self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
fn predict_ref(&self, x: &ArrayBase<D, Ix2>) -> Array1<usize> {
let joint_log_likelihood = self.joint_log_likelihood(x.view());

// We store the classes and likelihood info in an vec and matrix
2 changes: 1 addition & 1 deletion algorithms/linfa-clustering/Cargo.toml
@@ -34,8 +34,8 @@ ndarray-rand = "0.13"
ndarray-stats = "0.4"
num-traits = "0.2"
rand_isaac = "0.3"
thiserror = "1"
partitions = "0.2.4"

linfa = { version = "0.3.1", path = "../..", features = ["ndarray-linalg"] }

[dev-dependencies]
10 changes: 5 additions & 5 deletions algorithms/linfa-clustering/src/appx_dbscan/hyperparameters.rs
@@ -99,21 +99,21 @@ impl<F: Float> AppxDbscanHyperParams<F> {
}

fn build(tolerance: F, min_points: usize, slack: F) -> Self {
if tolerance <= F::cast(0.) {
if tolerance <= F::zero() {
panic!("`tolerance` must be greater than 0!");
}
// There is always at least one neighbor to a point (itself)
if min_points <= 1 {
panic!("`min_points` must be greater than 1!");
}

if slack <= F::cast(0.) {
if slack <= F::zero() {
panic!("`slack` must be greater than 0!");
}
Self {
tolerance: tolerance,
min_points: min_points,
slack: slack,
tolerance,
min_points,
slack,
appx_tolerance: tolerance * (F::one() + slack),
}
}
14 changes: 7 additions & 7 deletions algorithms/linfa-clustering/src/gaussian_mixture/algorithm.rs
@@ -215,10 +215,10 @@ impl<F: Float> GaussianMixtureModel<F> {
reg_covar: F,
) -> Result<(Array1<F>, Array2<F>, Array3<F>)> {
let nk = resp.sum_axis(Axis(0));
if nk.min().unwrap() < &(F::cast(10.) * F::epsilon()) {
if nk.min()? < &(F::cast(10.) * F::epsilon()) {
return Err(GmmError::EmptyCluster(format!(
"Cluster #{} has no more point. Consider decreasing number of clusters or change initialization.",
nk.argmin().unwrap() + 1
nk.argmin()? + 1
)));
}

@@ -400,12 +400,12 @@ impl<F: Float> GaussianMixtureModel<F> {
}
}

impl<'a, F: Float, R: Rng + SeedableRng + Clone, D: Data<Elem = F>, T> Fit<'a, ArrayBase<D, Ix2>, T>
for GmmHyperParams<F, R>
impl<F: Float, R: Rng + SeedableRng + Clone, D: Data<Elem = F>, T>
Fit<ArrayBase<D, Ix2>, T, GmmError> for GmmHyperParams<F, R>
{
type Object = Result<GaussianMixtureModel<F>>;
type Object = GaussianMixtureModel<F>;

fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
self.validate()?;
let observations = dataset.records().view();
let mut gmm = GaussianMixtureModel::<F>::new(self, dataset, self.rng())?;
@@ -488,7 +488,7 @@ mod tests {
}
impl MultivariateNormal {
pub fn new(mean: &ArrayView1<f64>, covariance: &ArrayView2<f64>) -> LAResult<Self> {
let lower = covariance.cholesky(UPLO::Lower).unwrap();
let lower = covariance.cholesky(UPLO::Lower)?;
Ok(MultivariateNormal {
mean: mean.to_owned(),
covariance: covariance.to_owned(),
61 changes: 20 additions & 41 deletions algorithms/linfa-clustering/src/gaussian_mixture/errors.rs
@@ -1,58 +1,37 @@
use crate::k_means::KMeansError;
use ndarray_linalg::error::LinalgError;
use std::error::Error;
use std::fmt::{self, Display};

use thiserror::Error;
pub type Result<T> = std::result::Result<T, GmmError>;

/// An error when modeling a GMM algorithm
#[derive(Debug)]
#[derive(Error, Debug)]
pub enum GmmError {
/// When any of the hyperparameters are set the wrong value
#[error("Invalid value encountered: {0}")]
InvalidValue(String),
/// Errors encountered during linear algebra operations
LinalgError(LinalgError),
#[error(
"Linalg Error: \
Fitting the mixture model failed because some components have \
ill-defined empirical covariance (for instance caused by singleton \
or collapsed samples). Try to decrease the number of components, \
or increase reg_covar. Error: {0}"
)]
LinalgError(#[from] LinalgError),
/// When a cluster has no more data point while fitting GMM
#[error("Fitting failed: {0}")]
EmptyCluster(String),
/// When lower bound computation fails
#[error("Fitting failed: {0}")]
LowerBoundError(String),
/// When fitting EM algorithm does not converge
#[error("Fitting failed: {0}")]
NotConverged(String),
/// When initial KMeans fails
KMeansError(String),
}

impl Display for GmmError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message),
Self::LinalgError(error) => write!(
f,
"Linalg Error: \
Fitting the mixture model failed because some components have \
ill-defined empirical covariance (for instance caused by singleton \
or collapsed samples). Try to decrease the number of components, \
or increase reg_covar. Error: {}",
error
),
Self::EmptyCluster(message) => write!(f, "Fitting failed: {}", message),
Self::LowerBoundError(message) => write!(f, "Fitting failed: {}", message),
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message),
Self::KMeansError(message) => write!(f, "Initial KMeans failed: {}", message),
}
}
}

impl Error for GmmError {}

impl From<LinalgError> for GmmError {
fn from(error: LinalgError) -> GmmError {
GmmError::LinalgError(error)
}
}

impl From<KMeansError> for GmmError {
fn from(error: KMeansError) -> GmmError {
GmmError::KMeansError(error.to_string())
}
#[error("Initial KMeans failed: {0}")]
KMeansError(#[from] KMeansError),
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
#[error(transparent)]
MinMaxError(#[from] ndarray_stats::errors::MinMaxError),
}
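
The derive-based version replaces the hand-written `Display`, `Error` and `From` implementations removed in this file. As a rough illustration of what `#[error("...")]` and `#[from]` stand for, here is a standard-library-only sketch. `DemoError` and `smallest` are hypothetical names invented for this example, and `ParseIntError` merely plays the role that `LinalgError` or `MinMaxError` play in the diff.

```rust
use std::fmt;
use std::num::ParseIntError;

// Hypothetical error enum, written by hand instead of derived.
#[derive(Debug)]
pub enum DemoError {
    Parse(ParseIntError),
    Empty(String),
}

// Roughly what an `#[error("...")]` attribute on each variant provides.
impl fmt::Display for DemoError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Parse(error) => write!(f, "Parsing failed: {}", error),
            Self::Empty(message) => write!(f, "Fitting failed: {}", message),
        }
    }
}

impl std::error::Error for DemoError {}

// Roughly what `#[from]` on the wrapped field provides.
impl From<ParseIntError> for DemoError {
    fn from(error: ParseIntError) -> Self {
        DemoError::Parse(error)
    }
}

// With the `From` impl in place, `?` converts the inner error automatically,
// so no `unwrap()` is needed on the fallible call.
pub fn smallest(values: &[&str]) -> Result<i64, DemoError> {
    let mut parsed = Vec::with_capacity(values.len());
    for value in values {
        parsed.push(value.trim().parse::<i64>()?);
    }
    parsed
        .into_iter()
        .min()
        .ok_or_else(|| DemoError::Empty("no values given".to_string()))
}
```

The same mechanism is what lets calls such as `nk.min()?` in this commit replace `nk.min().unwrap()`: once the enum has a variant that converts from the inner error, the failure is returned to the caller instead of panicking.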
8 changes: 4 additions & 4 deletions algorithms/linfa-clustering/src/k_means/algorithm.rs
@@ -215,17 +215,17 @@ impl<F: Float> KMeans<F> {
}
}

impl<'a, F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T> Fit<'a, ArrayBase<D, Ix2>, T>
for KMeansHyperParams<F, R>
impl<F: Float, R: Rng + Clone + SeedableRng, D: Data<Elem = F>, T>
Fit<ArrayBase<D, Ix2>, T, KMeansError> for KMeansHyperParams<F, R>
{
type Object = Result<KMeans<F>>;
type Object = KMeans<F>;

/// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
/// `fit` identifies `n_clusters` centroids based on the training data distribution.
///
/// An instance of `KMeans` is returned.
///
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Self::Object {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let mut rng = self.rng();
let observations = dataset.records().view();
let n_samples = dataset.nsamples();
22 changes: 7 additions & 15 deletions algorithms/linfa-clustering/src/k_means/errors.rs
@@ -1,27 +1,19 @@
use std::error::Error;
use std::fmt::{self, Display};
use thiserror::Error;

pub type Result<T> = std::result::Result<T, KMeansError>;

/// An error when modeling a KMeans algorithm
#[derive(Debug)]
#[derive(Error, Debug)]
pub enum KMeansError {
/// When any of the hyperparameters are set the wrong value
#[error("Invalid value encountered: {0}")]
InvalidValue(String),
/// When inertia computation fails
#[error("Fitting failed: {0}")]
InertiaError(String),
/// When fitting algorithm does not converge
#[error("Fitting failed: {0}")]
NotConverged(String),
#[error(transparent)]
LinfaError(#[from] linfa::error::Error),
}

impl Display for KMeansError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::InvalidValue(message) => write!(f, "Invalid value encountered: {}", message),
Self::InertiaError(message) => write!(f, "Fitting failed: {}", message),
Self::NotConverged(message) => write!(f, "Fitting failed: {}", message),
}
}
}

impl Error for KMeansError {}
26 changes: 26 additions & 0 deletions algorithms/linfa-elasticnet/examples/elasticnet_cv.rs
@@ -0,0 +1,26 @@
use linfa::prelude::*;
use linfa_elasticnet::{ElasticNet, Result};

fn main() -> Result<()> {
// load Diabetes dataset (mutable to allow fast k-folding)
let mut dataset = linfa_datasets::diabetes();

// parameters to compare
let ratios = vec![0.1, 0.2, 0.5, 0.7, 1.0];

// create a model for each parameter
let models = ratios
.iter()
.map(|ratio| ElasticNet::params().penalty(0.3).l1_ratio(*ratio))
.collect::<Vec<_>>();

// get the mean r2 validation score across all folds for each model
let r2_values =
dataset.cross_validate(5, &models, |prediction, truth| prediction.r2(&truth))?;

for (ratio, r2) in ratios.iter().zip(r2_values.iter()) {
println!("L1 ratio: {}, r2 score: {}", ratio, r2);
}

Ok(())
}
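
For readers unfamiliar with k-fold cross validation, the idea behind the example above can be sketched in plain Rust. This is not linfa's implementation: `k_fold_indices`, `select` and `cross_validate` here are hypothetical helpers written for illustration, and the choice to leave the `n % k` leftover samples out of every validation fold is an assumption of this sketch only.

```rust
/// Split `n` sample indices into `k` contiguous folds and return, for each
/// fold, the pair (training indices, validation indices).
pub fn k_fold_indices(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    assert!(k > 0 && k <= n, "need 0 < k <= n");
    // Assumption of this sketch: the `n % k` leftover samples are never
    // validated on, they only ever appear in the training part.
    let fold_size = n / k;
    (0..k)
        .map(|fold| {
            let start = fold * fold_size;
            let end = start + fold_size;
            let valid: Vec<usize> = (start..end).collect();
            let train: Vec<usize> = (0..n).filter(|i| *i < start || *i >= end).collect();
            (train, valid)
        })
        .collect()
}

/// Copy the entries of `data` at the given indices.
fn select(indices: &[usize], data: &[f64]) -> Vec<f64> {
    indices.iter().map(|&i| data[i]).collect()
}

/// Fit a model on the training part of each fold, score it on the held-out
/// part, and return the mean score across the `k` folds.
pub fn cross_validate<M>(
    xs: &[f64],
    ys: &[f64],
    k: usize,
    fit: impl Fn(&[f64], &[f64]) -> M,
    score: impl Fn(&M, &[f64], &[f64]) -> f64,
) -> f64 {
    assert_eq!(xs.len(), ys.len(), "records and targets must align");
    let total: f64 = k_fold_indices(xs.len(), k)
        .iter()
        .map(|(train, valid)| {
            let model = fit(&select(train, xs), &select(train, ys));
            score(&model, &select(valid, xs), &select(valid, ys))
        })
        .sum();
    total / k as f64
}
```

Running this once per candidate model and comparing the returned means is the comparison the example performs for each L1 ratio, with r2 as the score.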
6 changes: 3 additions & 3 deletions algorithms/linfa-elasticnet/src/algorithm.rs
@@ -10,13 +10,13 @@ use linfa::{

use super::{ElasticNet, ElasticNetParams, Error, Result};

impl<'a, F, D, T> Fit<'a, ArrayBase<D, Ix2>, T> for ElasticNetParams<F>
impl<F, D, T> Fit<ArrayBase<D, Ix2>, T, crate::error::Error> for ElasticNetParams<F>
where
F: Float + Lapack,
D: Data<Elem = F>,
T: AsTargets<Elem = F>,
{
type Object = Result<ElasticNet<F>>;
type Object = ElasticNet<F>;

/// Fit an elastic net model given a feature matrix `x` and a target
/// variable `y`.
@@ -28,7 +28,7 @@ where
/// Returns a `FittedElasticNet` object which contains the fitted
/// parameters and can be used to `predict` values of the target variable
/// for new feature values.
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<ElasticNet<F>> {
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
self.validate_params()?;
let target = dataset.try_single_target()?;

3 changes: 2 additions & 1 deletion algorithms/linfa-ica/Cargo.toml
@@ -30,8 +30,9 @@ ndarray-rand = "0.13"
ndarray-stats = "0.4"
num-traits = "0.2"
rand_isaac = "0.3"
thiserror = "1"

linfa = { version = "0.3.1", path = "../.." }
linfa = { version = "0.3.1", path = "../..", features = ["ndarray-linalg"] }

[dev-dependencies]
ndarray-npy = { version = "0.7", default-features = false }