diff --git a/src/svm/svc.rs b/src/svm/svc.rs index cc5a0beb..d72ecdac 100644 --- a/src/svm/svc.rs +++ b/src/svm/svc.rs @@ -58,10 +58,11 @@ //! 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; //! //! let knl = Kernels::linear(); -//! let params = &SVCParameters::default().with_c(200.0).with_kernel(knl); -//! let svc = SVC::fit(&x, &y, params).unwrap(); +//! let parameters = &SVCParameters::default().with_c(200.0).with_kernel(knl); +//! let svc = SVC::fit(&x, &y, parameters).unwrap(); //! //! let y_hat = svc.predict(&x).unwrap(); +//! //! ``` //! //! ## References: @@ -84,12 +85,194 @@ use serde::{Deserialize, Serialize}; use crate::api::{PredictorBorrow, SupervisedEstimatorBorrow}; use crate::error::{Failed, FailedError}; -use crate::linalg::basic::arrays::{Array1, Array2, MutArray}; +use crate::linalg::basic::arrays::{Array, Array1, Array2, MutArray}; use crate::numbers::basenum::Number; use crate::numbers::realnum::RealNumber; use crate::rand_custom::get_rng_impl; use crate::svm::Kernel; +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +#[derive(Debug)] +/// Configuration for a multi-class Support Vector Machine (SVM) classifier. +/// This struct holds the indices of the data points relevant to a specific binary +/// classification problem within a multi-class context, and the two classes +/// being discriminated. +struct MultiClassConfig { + /// The indices of the data points from the original dataset that belong to the two `classes`. + indices: Vec, + /// A tuple representing the two classes that this configuration is designed to distinguish. + classes: (TY, TY), +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + SupervisedEstimatorBorrow<'a, X, Y, SVCParameters> + for MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Creates a new, empty `MultiClassSVC` instance. + fn new() -> Self { + Self { + classifiers: Option::None, + } + } + + /// Fits the `MultiClassSVC` model to the provided data and parameters. + /// + /// This method delegates the fitting process to the inherent `MultiClassSVC::fit` method. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array). + /// * `y` - A reference to the target labels (1D array). + /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training. + /// + /// # Returns + /// A `Result` indicating success (`Self`) or failure (`Failed`). + fn fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + ) -> Result { + MultiClassSVC::fit(x, y, parameters) + } +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + PredictorBorrow<'a, X, TX> for MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Predicts the class labels for new data points. + /// + /// This method delegates the prediction process to the inherent `MultiClassSVC::predict` method. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) for which to make predictions. + /// + /// # Returns + /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error. + fn predict(&self, x: &'a X) -> Result, Failed> { + Ok(self.predict(x).unwrap()) + } +} + +/// A multi-class Support Vector Machine (SVM) classifier. +/// +/// This struct implements a multi-class SVM using the "one-vs-one" strategy, +/// where a separate binary SVC classifier is trained for every pair of classes. +/// +/// # Type Parameters +/// * `'a` - Lifetime parameter for borrowed data. +/// * `TX` - The numeric type of the input features (must implement `Number` and `RealNumber`). +/// * `TY` - The numeric type of the target labels (must implement `Number` and `Ord`). +/// * `X` - The type representing the 2D array of input features (e.g., a matrix). +/// * `Y` - The type representing the 1D array of target labels (e.g., a vector). +pub struct MultiClassSVC< + 'a, + TX: Number + RealNumber, + TY: Number + Ord, + X: Array2, + Y: Array1, +> { + /// An optional vector of binary `SVC` classifiers. + classifiers: Option>>, +} + +impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> + MultiClassSVC<'a, TX, TY, X, Y> +{ + /// Fits the `MultiClassSVC` model to the provided data using a one-vs-one strategy. + /// + /// This method identifies all unique classes in the target labels `y` and then + /// trains a binary `SVC` for every unique pair of classes. For each pair, it + /// extracts the relevant data points and their labels, and then trains a + /// specialized `SVC` for that binary classification task. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array). + /// * `y` - A reference to the target labels (1D array). + /// * `parameters` - A reference to the `SVCParameters` controlling the SVM training for each individual binary classifier. + /// + /// + /// # Returns + /// A `Result` indicating success (`MultiClassSVC`) or failure (`Failed`). + pub fn fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + ) -> Result, Failed> { + let unique_classes = y.unique(); + let mut classifiers = Vec::new(); + // Iterate through all unique pairs of classes (one-vs-one strategy) + for i in 0..unique_classes.len() { + for j in i..unique_classes.len() { + if i == j { + continue; + } + let class0 = unique_classes[j]; + let class1 = unique_classes[i]; + + let mut indices = Vec::new(); + // Collect indices of data points belonging to the current pair of classes + for (index, v) in y.iterator(0).enumerate() { + if *v == class0 || *v == class1 { + indices.push(index) + } + } + let classes = (class0, class1); + let multiclass_config = MultiClassConfig { classes, indices }; + // Fit a binary SVC for the current pair of classes + let svc = SVC::multiclass_fit(x, y, parameters, multiclass_config).unwrap(); + classifiers.push(svc); + } + } + Ok(Self { + classifiers: Some(classifiers), + }) + } + + /// Predicts the class labels for new data points using the trained multi-class SVM. + /// + /// This method uses a "voting" scheme (majority vote) among all the binary + /// classifiers to determine the final prediction for each data point. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) for which to make predictions. + /// + /// # Returns + /// A `Result` containing a `Vec` of predicted class labels (`TX`) or a `Failed` error. + /// + pub fn predict(&self, x: &X) -> Result, Failed> { + // Initialize a HashMap for each data point to store votes for each class + let mut polls = vec![HashMap::new(); x.shape().0]; + // Retrieve the trained binary classifiers + let classifiers = self.classifiers.as_ref().unwrap(); + + // Iterate through each binary classifier + for i in 0..classifiers.len() { + let svc = classifiers.get(i).unwrap(); + let predictions = svc.predict(x).unwrap(); // call SVC::predict for each binary classifier + + // For each prediction from the current binary classifier + for (j, prediction) in predictions.iter().enumerate() { + let prediction = prediction.to_i32().unwrap(); + let poll = polls.get_mut(j).unwrap(); // Get the poll for the current data point + // Increment the vote for the predicted class + if let Some(count) = poll.get_mut(&prediction) { + *count += 1 + } else { + poll.insert(prediction, 1); + } + } + } + + // Determine the final prediction for each data point based on majority vote + Ok(polls + .iter() + .map(|v| { + // Find the class with the maximum votes for each data point + TX::from(*v.iter().max_by_key(|(_, class)| *class).unwrap().0).unwrap() + }) + .collect()) + } +} + #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Debug)] /// SVC Parameters @@ -123,7 +306,7 @@ pub struct SVCParameters, Y: Array1> { - classes: Option>, + classes: Option<(TY, TY)>, instances: Option>>, #[cfg_attr(feature = "serde", serde(skip))] parameters: Option<&'a SVCParameters>, @@ -152,7 +335,9 @@ struct Cache, Y: Array1 struct Optimizer<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1> { x: &'a X, y: &'a Y, + indices: Option>, parameters: &'a SVCParameters, + classes: &'a (TY, TY), svmin: usize, svmax: usize, gmin: TX, @@ -180,12 +365,12 @@ impl, Y: Array1> self.tol = tol; self } + /// The kernel function. pub fn with_kernel(mut self, kernel: K) -> Self { self.kernel = Some(Box::new(kernel)); self } - /// Seed for the pseudo random number generator. pub fn with_seed(mut self, seed: Option) -> Self { self.seed = seed; @@ -241,17 +426,98 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array1 + 'a> SVC<'a, TX, TY, X, Y> { - /// Fits SVC to your data. - /// * `x` - _NxM_ matrix with _N_ observations and _M_ features in each observation. - /// * `y` - class labels - /// * `parameters` - optional parameters, use `Default::default()` to set parameters to default values. + /// Fits a binary Support Vector Classifier (SVC) to the provided data. + /// + /// This is the primary `fit` method for a standalone binary SVC. It expects + /// the target labels `y` to contain exactly two unique classes. If more or + /// fewer than two classes are found, it returns an error. It then extracts + /// these two classes and proceeds to optimize and fit the SVC model. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. `y` must contain exactly two unique class labels. + /// * `parameters` - A reference to the `SVCParameters` controlling the training process. + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. + /// - `Err(Failed)`: If the number of unique classes in `y` is not exactly two, or if the underlying optimization fails. pub fn fit( x: &'a X, y: &'a Y, parameters: &'a SVCParameters, ) -> Result, Failed> { - let (n, _) = x.shape(); + let classes = y.unique(); + // Validate that there are exactly two unique classes in the target labels. + if classes.len() != 2 { + return Err(Failed::fit(&format!( + "Incorrect number of classes: {}. A binary SVC requires exactly two classes.", + classes.len() + ))); + } + let classes = (classes[0], classes[1]); + let svc = Self::optimize_and_fit(x, y, parameters, classes, None); + svc + } + + /// Fits a binary Support Vector Classifier (SVC) specifically for multi-class scenarios. + /// + /// This function is intended to be called by a multi-class strategy (e.g., one-vs-one) + /// to train individual binary SVCs. It takes a `MultiClassConfig` which specifies + /// the two classes this SVC should discriminate and the subset of data indices + /// relevant to these classes. It then delegates the actual optimization and fitting + /// to `optimize_and_fit`. + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// * `parameters` - A reference to the `SVCParameters` controlling the training process (e.g., kernel, C-value, tolerance). + /// * `multiclass_config` - A `MultiClassConfig` struct containing: + /// - `classes`: A tuple `(class0, class1)` specifying the two classes this SVC should distinguish. + /// - `indices`: A `Vec` containing the indices of the data points in `x` and `y that belong to either `class0` or `class1`.` + /// + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new, fitted binary SVC instance. + /// - `Err(Failed)`: If the fitting process encounters an error (e.g., invalid parameters). + fn multiclass_fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + multiclass_config: MultiClassConfig, + ) -> Result, Failed> { + let classes = multiclass_config.classes; + let indices = multiclass_config.indices; + let svc = Self::optimize_and_fit(x, y, parameters, classes, Some(indices)); + svc + } + /// Internal function to optimize and fit the Support Vector Classifier. + /// + /// This is the core logic for training a binary SVC. It performs several checks + /// (e.g., kernel presence, data shape consistency) and then initializes an + /// `Optimizer` to find the support vectors, weights (`w`), and bias (`b`). + /// + /// # Arguments + /// * `x` - A reference to the input features (2D array) of the training data. + /// * `y` - A reference to the target labels (1D array) of the training data. + /// * `parameters` - A reference to the `SVCParameters` defining the SVM model's configuration. + /// * `classes` - A tuple `(class0, class1)` representing the two distinct class labels that the SVC will learn to separate. + /// * `indices` - An `Option>`. If `Some`, it contains the specific indices of data points from `x` and `y` that should be used for training this binary classifier. If `None`, all data points in `x` and `y` are considered. + /// # Returns + /// A `Result` which is: + /// - `Ok(SVC<'a, TX, TY, X, Y>)`: A new `SVC` instance populated with the learned model components (support vectors, weights, bias). + /// - `Err(Failed)`: If any of the validation checks fail (e.g., missing kernel, mismatched data shapes), or if the optimization process fails. + fn optimize_and_fit( + x: &'a X, + y: &'a Y, + parameters: &'a SVCParameters, + classes: (TY, TY), + indices: Option>, + ) -> Result, Failed> { + let (n_samples, _) = x.shape(); + + // Validate that a kernel has been defined in the parameters. if parameters.kernel.is_none() { return Err(Failed::because( FailedError::ParametersError, @@ -259,55 +525,39 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2 + 'a, Y: Array )); } - if n != y.shape() { + // Validate that the number of samples in X matches the number of labels in Y. + if n_samples != y.shape() { return Err(Failed::fit( - "Number of rows of X doesn\'t match number of rows of Y", + "Number of rows of X doesn't match number of rows of Y", )); } - let classes = y.unique(); - - if classes.len() != 2 { - return Err(Failed::fit(&format!( - "Incorrect number of classes: {}", - classes.len() - ))); - } - - // Make sure class labels are either 1 or -1 - for e in y.iterator(0) { - let y_v = e.to_i32().unwrap(); - if y_v != -1 && y_v != 1 { - return Err(Failed::because( - FailedError::ParametersError, - "Class labels must be 1 or -1", - )); - } - } - - let optimizer: Optimizer<'_, TX, TY, X, Y> = Optimizer::new(x, y, parameters); + let optimizer: Optimizer<'_, TX, TY, X, Y> = + Optimizer::new(x, y, indices, parameters, &classes); + // Perform the optimization to find the support vectors, weight vector, and bias. + // This is where the core SVM algorithm (e.g., SMO) would run. let (support_vectors, weight, b) = optimizer.optimize(); + // Construct and return the fitted SVC model. Ok(SVC::<'a> { - classes: Some(classes), - instances: Some(support_vectors), - parameters: Some(parameters), - w: Some(weight), - b: Some(b), - phantomdata: PhantomData, + classes: Some(classes), // Store the two classes the SVC was trained on. + instances: Some(support_vectors), // Store the data points that are support vectors. + parameters: Some(parameters), // Reference to the parameters used for fitting. + w: Some(weight), // The learned weight vector (for linear kernels). + b: Some(b), // The learned bias term. + phantomdata: PhantomData, // Placeholder for type parameters not directly stored. }) } - /// Predicts estimated class labels from `x` /// * `x` - _KxM_ data where _K_ is number of observations and _M_ is number of features. pub fn predict(&self, x: &'a X) -> Result, Failed> { let mut y_hat: Vec = self.decision_function(x)?; for i in 0..y_hat.len() { - let cls_idx = match *y_hat.get(i).unwrap() > TX::zero() { - false => TX::from(self.classes.as_ref().unwrap()[0]).unwrap(), - true => TX::from(self.classes.as_ref().unwrap()[1]).unwrap(), + let cls_idx = match *y_hat.get(i) > TX::zero() { + false => TX::from(self.classes.as_ref().unwrap().0).unwrap(), + true => TX::from(self.classes.as_ref().unwrap().1).unwrap(), }; y_hat.set(i, cls_idx); @@ -445,14 +695,18 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 fn new( x: &'a X, y: &'a Y, + indices: Option>, parameters: &'a SVCParameters, + classes: &'a (TY, TY), ) -> Optimizer<'a, TX, TY, X, Y> { let (n, _) = x.shape(); Optimizer { x, y, + indices, parameters, + classes, svmin: 0, svmax: 0, gmin: ::max_value(), @@ -478,7 +732,12 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 for i in self.permutate(n) { x.clear(); x.extend(self.x.get_row(i).iterator(0).take(n).copied()); - self.process(i, &x, *self.y.get(i), &mut cache); + let y = if *self.y.get(i) == self.classes.1 { + 1 + } else { + -1 + } as f64; + self.process(i, &x, y, &mut cache); loop { self.reprocess(tol, &mut cache); self.find_min_max_gradient(); @@ -514,14 +773,16 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 for i in self.permutate(n) { x.clear(); x.extend(self.x.get_row(i).iterator(0).take(n).copied()); - if *self.y.get(i) == TY::one() && cp < few { - if self.process(i, &x, *self.y.get(i), cache) { + let y = if *self.y.get(i) == self.classes.1 { + 1 + } else { + -1 + } as f64; + if y == 1.0 && cp < few { + if self.process(i, &x, y, cache) { cp += 1; } - } else if *self.y.get(i) == TY::from(-1).unwrap() - && cn < few - && self.process(i, &x, *self.y.get(i), cache) - { + } else if y == -1.0 && cn < few && self.process(i, &x, y, cache) { cn += 1; } @@ -531,14 +792,14 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 } } - fn process(&mut self, i: usize, x: &[TX], y: TY, cache: &mut Cache) -> bool { + fn process(&mut self, i: usize, x: &[TX], y: f64, cache: &mut Cache) -> bool { for j in 0..self.sv.len() { if self.sv[j].index == i { return true; } } - let mut g: f64 = y.to_f64().unwrap(); + let mut g = y; let mut cache_values: Vec<((usize, usize), TX)> = Vec::new(); @@ -559,8 +820,8 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 self.find_min_max_gradient(); if self.gmin < self.gmax - && ((y > TY::zero() && g < self.gmin.to_f64().unwrap()) - || (y < TY::zero() && g > self.gmax.to_f64().unwrap())) + && ((y > 0.0 && g < self.gmin.to_f64().unwrap()) + || (y < 0.0 && g > self.gmax.to_f64().unwrap())) { return false; } @@ -590,7 +851,7 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 ), ); - if y > TY::zero() { + if y > 0.0 { self.smo(None, Some(0), TX::zero(), cache); } else { self.smo(Some(0), None, TX::zero(), cache); @@ -647,7 +908,6 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 let gmin = self.gmin; let mut idxs_to_drop: HashSet = HashSet::new(); - self.sv.retain(|v| { if v.alpha == 0f64 && ((TX::from(v.grad).unwrap() >= gmax && TX::zero() >= TX::from(v.cmax).unwrap()) @@ -666,7 +926,11 @@ impl<'a, TX: Number + RealNumber, TY: Number + Ord, X: Array2, Y: Array1 fn permutate(&self, n: usize) -> Vec { let mut rng = get_rng_impl(self.parameters.seed); - let mut range: Vec = (0..n).collect(); + let mut range = if let Some(indices) = self.indices.clone() { + indices + } else { + (0..n).collect::>() + }; range.shuffle(&mut rng); range } @@ -965,12 +1229,12 @@ mod tests { ]; let knl = Kernels::linear(); - let params = SVCParameters::default() + let parameters = SVCParameters::default() .with_c(200.0) .with_kernel(knl) .with_seed(Some(100)); - let y_hat = SVC::fit(&x, &y, ¶ms) + let y_hat = SVC::fit(&x, &y, ¶meters) .and_then(|lr| lr.predict(&x)) .unwrap(); let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); @@ -1070,6 +1334,56 @@ mod tests { assert!(acc >= 0.9, "accuracy ({acc}) is not larger or equal to 0.9"); } + #[cfg_attr( + all(target_arch = "wasm32", not(target_os = "wasi")), + wasm_bindgen_test::wasm_bindgen_test + )] + #[test] + fn svc_multiclass_fit_predict() { + let x = DenseMatrix::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]) + .unwrap(); + + let y: Vec = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]; + + let knl = Kernels::linear(); + let parameters = SVCParameters::default() + .with_c(200.0) + .with_kernel(knl) + .with_seed(Some(100)); + + let y_hat = MultiClassSVC::fit(&x, &y, ¶meters) + .and_then(|lr| lr.predict(&x)) + .unwrap(); + + let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect())); + + assert!( + acc >= 0.9, + "Multiclass accuracy ({acc}) is not larger or equal to 0.9" + ); + } + #[cfg_attr( all(target_arch = "wasm32", not(target_os = "wasi")), wasm_bindgen_test::wasm_bindgen_test @@ -1106,8 +1420,8 @@ mod tests { ]; let knl = Kernels::linear(); - let params = SVCParameters::default().with_kernel(knl); - let svc = SVC::fit(&x, &y, ¶ms).unwrap(); + let parameters = SVCParameters::default().with_kernel(knl); + let svc = SVC::fit(&x, &y, ¶meters).unwrap(); // serialization let deserialized_svc: SVC<'_, f64, i32, _, _> =