Skip to content

Commit

Permalink
Merge pull request #8 from vaaaaanquish/add_multi_classification
Browse files Browse the repository at this point in the history
Add multi classification examples
  • Loading branch information
vaaaaanquish authored Jan 16, 2021
2 parents a7b3551 + 573565b commit 361a380
Show file tree
Hide file tree
Showing 13 changed files with 521 additions and 93 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ Cargo.lock
lightgbm-sys/target

# example
examples/target
examples/binary_classification/target/
examples/multiclass_classification/target/
examples/regression/target/
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "lightgbm"
version = "0.1.1"
version = "0.1.2"
authors = ["vaaaaanquish <[email protected]>"]
license = "MIT"
repository = "https://github.com/vaaaaanquish/LightGBM"
Expand All @@ -11,3 +11,5 @@ exclude = [".gitignore", ".gitmodules", "examples", "lightgbm-sys"]
[dependencies]
lightgbm-sys = "0.1.0"
libc = "0.2.81"
derive_builder = "0.5.1"
serde_json = "1.0.59"
11 changes: 11 additions & 0 deletions examples/binary_classification/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "lightgbm-example-binary-classification"
version = "0.1.0"
authors = ["vaaaaanquish <[email protected]>"]
publish = false

[dependencies]
lightgbm = { path = "../../" }
csv = "1.1.5"
itertools = "0.9.0"
serde_json = "1.0.59"
55 changes: 55 additions & 0 deletions examples/binary_classification/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
extern crate lightgbm;
extern crate csv;
extern crate serde_json;
extern crate itertools;


use itertools::zip;
use lightgbm::{Dataset, Booster};
use serde_json::json;


fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
let mut labels: Vec<f32> = Vec::new();
let mut features: Vec<Vec<f64>> = Vec::new();
for result in rdr.unwrap().records() {
let record = result.unwrap();
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
labels.push(label);
features.push(feature);
}
(features, labels)
}


fn main() -> std::io::Result<()> {
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.train");
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.test");
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();

let params = json!{
{
"num_iterations": 100,
"objective": "binary",
"metric": "auc"
}
};

let booster = Booster::train(train_dataset, &params).unwrap();
let result = booster.predict(test_features).unwrap();


let mut tp = 0;
for (label, pred) in zip(&test_labels, &result[0]){
if label == &(1 as f32) && pred > &(0.5 as f64) {
tp = tp + 1;
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
tp = tp + 1;
}
println!("{}, {}", label, pred)
}
println!("{} / {}", &tp, result[0].len());
Ok(())
}
11 changes: 11 additions & 0 deletions examples/multiclass_classification/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[package]
name = "lightgbm-example-multiclass-classification"
version = "0.1.0"
authors = ["vaaaaanquish <[email protected]>"]
publish = false

[dependencies]
lightgbm = { path = "../../" }
csv = "1.1.5"
itertools = "0.9.0"
serde_json = "1.0.59"
72 changes: 72 additions & 0 deletions examples/multiclass_classification/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
extern crate lightgbm;
extern crate csv;
extern crate serde_json;
extern crate itertools;


use itertools::zip;
use lightgbm::{Dataset, Booster};
use serde_json::json;


fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
let mut labels: Vec<f32> = Vec::new();
let mut features: Vec<Vec<f64>> = Vec::new();
for result in rdr.unwrap().records() {
let record = result.unwrap();
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
labels.push(label);
features.push(feature);
}
(features, labels)
}

fn argmax<T: PartialOrd>(xs: &[T]) -> usize {
if xs.len() == 1 {
0
} else {
let mut maxval = &xs[0];
let mut max_ixs: Vec<usize> = vec![0];
for (i, x) in xs.iter().enumerate().skip(1) {
if x > maxval {
maxval = x;
max_ixs = vec![i];
} else if x == maxval {
max_ixs.push(i);
}
}
max_ixs[0]
}
}

fn main() -> std::io::Result<()> {
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.train");
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/multiclass_classification/multiclass.test");
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();

let params = json!{
{
"num_iterations": 100,
"objective": "multiclass",
"metric": "multi_logloss",
"num_class": 5,
}
};

let booster = Booster::train(train_dataset, &params).unwrap();
let result = booster.predict(test_features).unwrap();


let mut tp = 0;
for (label, pred) in zip(&test_labels, &result){
let argmax_pred = argmax(&pred);
if *label == argmax_pred as f32 {
tp = tp + 1;
}
println!("{}, {}, {:?}", label, argmax_pred, &pred);
}
println!("{} / {}", &tp, result.len());
Ok(())
}
5 changes: 3 additions & 2 deletions examples/Cargo.toml → examples/regression/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
[package]
name = "lightgbm-example"
name = "lightgbm-example-regression"
version = "0.1.0"
authors = ["vaaaaanquish <[email protected]>"]
publish = false

[dependencies]
lightgbm = "0.1.1"
lightgbm = { path = "../../" }
csv = "1.1.5"
itertools = "0.9.0"
serde_json = "1.0.59"
55 changes: 55 additions & 0 deletions examples/regression/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
extern crate lightgbm;
extern crate csv;
extern crate serde_json;
extern crate itertools;


use itertools::zip;
use lightgbm::{Dataset, Booster};
use serde_json::json;


fn load_file(file_path: &str) -> (Vec<Vec<f64>>, Vec<f32>) {
let rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path(file_path);
let mut labels: Vec<f32> = Vec::new();
let mut features: Vec<Vec<f64>> = Vec::new();
for result in rdr.unwrap().records() {
let record = result.unwrap();
let label = record[0].parse::<f32>().unwrap();
let feature: Vec<f64> = record.iter().map(|x| x.parse::<f64>().unwrap()).collect::<Vec<f64>>()[1..].to_vec();
labels.push(label);
features.push(feature);
}
(features, labels)
}


fn main() -> std::io::Result<()> {
let (train_features, train_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.train");
let (test_features, test_labels) = load_file("../../lightgbm-sys/lightgbm/examples/regression/regression.test");
let train_dataset = Dataset::from_mat(train_features, train_labels).unwrap();

let params = json!{
{
"num_iterations": 100,
"objective": "regression",
"metric": "l2"
}
};

let booster = Booster::train(train_dataset, &params).unwrap();
let result = booster.predict(test_features).unwrap();


let mut tp = 0;
for (label, pred) in zip(&test_labels, &result[0]){
if label == &(1 as f32) && pred > &(0.5 as f64) {
tp = tp + 1;
} else if label == &(0 as f32) && pred <= &(0.5 as f64) {
tp = tp + 1;
}
println!("{}, {}", label, pred)
}
println!("{} / {}", &tp, result[0].len());
Ok(())
}
56 changes: 0 additions & 56 deletions examples/src/main.rs

This file was deleted.

Loading

0 comments on commit 361a380

Please sign in to comment.