Skip to content

Commit daba88d

Browse files
committed
added tests
1 parent c53d1b6 commit daba88d

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/svm/svc.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,55 @@ mod tests {
10451045
use crate::metrics::accuracy;
10461046
use crate::svm::Kernels;
10471047

1048+
#[cfg_attr(
1049+
all(target_arch = "wasm32", not(target_os = "wasi")),
1050+
wasm_bindgen_test::wasm_bindgen_test
1051+
)]
1052+
#[test]
1053+
fn svc_multiclass_fit_predict() {
1054+
let x = DenseMatrix::from_2d_array(&[
1055+
&[5.1, 3.5, 1.4, 0.2],
1056+
&[4.9, 3.0, 1.4, 0.2],
1057+
&[4.7, 3.2, 1.3, 0.2],
1058+
&[4.6, 3.1, 1.5, 0.2],
1059+
&[5.0, 3.6, 1.4, 0.2],
1060+
&[5.4, 3.9, 1.7, 0.4],
1061+
&[4.6, 3.4, 1.4, 0.3],
1062+
&[5.0, 3.4, 1.5, 0.2],
1063+
&[4.4, 2.9, 1.4, 0.2],
1064+
&[4.9, 3.1, 1.5, 0.1],
1065+
&[7.0, 3.2, 4.7, 1.4],
1066+
&[6.4, 3.2, 4.5, 1.5],
1067+
&[6.9, 3.1, 4.9, 1.5],
1068+
&[5.5, 2.3, 4.0, 1.3],
1069+
&[6.5, 2.8, 4.6, 1.5],
1070+
&[5.7, 2.8, 4.5, 1.3],
1071+
&[6.3, 3.3, 4.7, 1.6],
1072+
&[4.9, 2.4, 3.3, 1.0],
1073+
&[6.6, 2.9, 4.6, 1.3],
1074+
&[5.2, 2.7, 3.9, 1.4],
1075+
])
1076+
.unwrap();
1077+
1078+
let y: Vec<i32> = vec![0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2];
1079+
1080+
let knl = Kernels::linear();
1081+
let parameters = SVCParameters::default()
1082+
.with_c(200.0)
1083+
.with_kernel(knl)
1084+
.with_seed(Some(100));
1085+
1086+
let y_hat = MultiClassSVC::fit(&x, &y, &parameters)
1087+
.and_then(|lr| lr.predict(&x))
1088+
.unwrap();
1089+
1090+
let acc = accuracy(&y, &(y_hat.iter().map(|e| e.to_i32().unwrap()).collect()));
1091+
1092+
assert!(
1093+
acc >= 0.9,
1094+
"Multiclass accuracy ({acc}) is not larger or equal to 0.9"
1095+
);
1096+
}
10481097
#[cfg_attr(
10491098
all(target_arch = "wasm32", not(target_os = "wasi")),
10501099
wasm_bindgen_test::wasm_bindgen_test

0 commit comments

Comments
 (0)