-
Notifications
You must be signed in to change notification settings - Fork 86
SVC multiclass #306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: development
Are you sure you want to change the base?
SVC multiclass #306
Conversation
…ry one for all iterations for models with two labels
Had to close the previous pr because I sent it on the wrong branch. |
The issue with Array2 is that the type within the Array is inferred and static. If a Vec<u 32> is passed in, then the application won't be able to transform the labels to: {1, -1} in order for them to be used for binary classification. Originally, you had an assert statement to validate the data passed in had labels: -1 and 1, but now the SVC, with it being multiclass, can accept a wide variety of labels. |
Array2 is an abstraction for a 2D vector, it can be used for any instance supported by Vec, it just need to be implemented. Test Driven Development (TDD) should be followed. Every time you implement something new you have to be sure to support existing behaviour. If there is no test for the operations you are changing, you should add it. Also please add one or more tests when you implement something new. When I try to run the tests in the module I get:
This is a reference implementation as generated by my LLM, you can start from this as it is not fully implemented (ie. it is an example and needs to be implemented using generic types like TX and TY). Please note that you should not modify the existing struct but instead create new structs to handle the multiclass possibility. This implementation suggests to use a 1D Vec but you should check if it correct (if it is right there is no need to use a 2D Vec): To implement a multiclass Support Vector Classification (SVC) in Rust using smartcore, we can adopt the one-vs-one (OvO) strategy, which trains binary classifiers for each pair of classes. Here's a complete implementation: use smartcore::svm::svc::{SVC, SVCParameters};
use smartcore::linalg::{BaseVector, Matrix, MatrixTrait};
use smartcore::metrics::accuracy;
use smartcore::dataset::iris::load_dataset;
// Multiclass SVC using One-vs-One strategy
struct MulticlassSVC {
classifiers: Vec<SVC<f64, DenseMatrix<f64>, Vec<f64>>>,
classes: Vec<u32>,
}
impl MulticlassSVC {
pub fn fit(
x: &DenseMatrix<f64>,
y: &Vec<u32>,
params: &SVCParameters<f64>,
) -> Result<Self, Failed> {
let classes = y.iter().unique().sorted().collect::<Vec<_>>();
let mut classifiers = Vec::new();
// Generate all class pairs
for (i, &class1) in classes.iter().enumerate() {
for &class2 in classes.iter().skip(i + 1) {
// Filter samples for current class pair
let indices: Vec<usize> = y.iter()
.enumerate()
.filter(|(_, &c)| c == class1 || c == class2)
.map(|(i, _)| i)
.collect();
let x_filtered = x.select_rows(&indices);
let y_filtered: Vec<f64> = indices.iter()
.map(|&i| if y[i] == class1 { 1.0 } else { -1.0 })
.collect();
// Train binary classifier
let mut clf = SVC::fit(&x_filtered, &y_filtered, params.clone())?;
classifiers.push((class1, class2, clf));
}
}
Ok(Self { classifiers, classes })
}
pub fn predict(&self, x: &DenseMatrix<f64>) -> Vec<u32> {
let mut votes = vec![HashMap::new(); x.shape().0];
for (class1, class2, clf) in &self.classifiers {
let preds = clf.predict(x).unwrap();
for (i, &p) in preds.iter().enumerate() {
let vote = if p > 0.0 { *class1 } else { *class2 };
*votes[i].entry(vote).or_insert(0) += 1;
}
}
votes.iter()
.map(|v| *v.iter().max_by_key(|(_, &count)| count).unwrap().0)
.collect()
}
}
// Example usage with Iris dataset
fn main() -> Result<(), Failed> {
let iris = load_dataset();
let (x_train, x_test, y_train, y_test) = train_test_split(
&iris.data,
&iris.target,
0.8,
true,
Some(42),
);
let params = SVCParameters::default()
.with_c(200.0)
.with_kernel(smartcore::svm::Kernel::linear());
let clf = MulticlassSVC::fit(&x_train, &y_train, ¶ms)?;
let preds = clf.predict(&x_test);
println!("Accuracy: {}", accuracy(&y_test, &preds));
Ok(())
} Key implementation details:
Advantages over naive implementation:
This implementation follows smartcore's design principles by:
For production use, you'd want to add:
Footnotes |
Appreciate the guide. |
It doesn't look like the DenseMatrix has the select_rows method. |
I assume you want the multiclass struct to take in any generic Array2 object as x. |
Implemented the multiclass feature for SVC using a one for all approach.