From e677b5cd6963ceed96040fe4c2e3b4057a1e0187 Mon Sep 17 00:00:00 2001 From: vaaaaanquish <6syun9@gmail.com> Date: Sat, 16 Jan 2021 02:26:39 +0900 Subject: [PATCH 1/6] add comment --- Cargo.toml | 1 + .../{ => binary_classification}/Cargo.toml | 4 +- .../{ => binary_classification}/src/main.rs | 16 +-- src/booster.rs | 68 +++++++--- src/dataset.rs | 118 ++++++++++++++++-- src/error.rs | 58 +++++++++ src/lib.rs | 8 ++ 7 files changed, 232 insertions(+), 41 deletions(-) rename examples/{ => binary_classification}/Cargo.toml (63%) rename examples/{ => binary_classification}/src/main.rs (67%) diff --git a/Cargo.toml b/Cargo.toml index 39223c0..7256741 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ exclude = [".gitignore", ".gitmodules", "examples", "lightgbm-sys"] [dependencies] lightgbm-sys = "0.1.0" libc = "0.2.81" +derive_builder = "0.5.1" diff --git a/examples/Cargo.toml b/examples/binary_classification/Cargo.toml similarity index 63% rename from examples/Cargo.toml rename to examples/binary_classification/Cargo.toml index b303bc3..8b9b8e2 100644 --- a/examples/Cargo.toml +++ b/examples/binary_classification/Cargo.toml @@ -1,10 +1,10 @@ [package] -name = "lightgbm-example" +name = "lightgbm-example-binary-classification" version = "0.1.0" authors = ["vaaaaanquish <6syun9@gmail.com>"] publish = false [dependencies] -lightgbm = "0.1.1" +lightgbm = { path = "../../" } csv = "1.1.5" itertools = "0.9.0" diff --git a/examples/src/main.rs b/examples/binary_classification/src/main.rs similarity index 67% rename from examples/src/main.rs rename to examples/binary_classification/src/main.rs index fed42c3..9463749 100644 --- a/examples/src/main.rs +++ b/examples/binary_classification/src/main.rs @@ -6,17 +6,7 @@ use itertools::zip; use lightgbm::{Dataset, Booster}; fn main() -> std::io::Result<()> { - // let feature = vec![vec![1.0, 0.1, 0.2, 0.1], - // vec![0.7, 0.4, 0.5, 0.1], - // vec![0.9, 0.8, 0.5, 0.1], - // vec![0.2, 0.2, 0.8, 0.7], - // vec![0.1, 0.7, 1.0, 0.9]]; - // let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; - // let train_dataset = Dataset::from_mat(feature, label).unwrap(); - - // let train_dataset = Dataset::from_file("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()).unwrap(); - - let mut train_rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.train")?; + let mut train_rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.train")?; let mut train_labels: Vec = Vec::new(); let mut train_feature: Vec> = Vec::new(); for result in train_rdr.records() { @@ -28,7 +18,7 @@ fn main() -> std::io::Result<()> { } let train_dataset = Dataset::from_mat(train_feature, train_labels).unwrap(); - let mut rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../lightgbm-sys/lightgbm/examples/binary_classification/binary.test")?; + let mut rdr = csv::ReaderBuilder::new().has_headers(false).delimiter(b'\t').from_path("../../lightgbm-sys/lightgbm/examples/binary_classification/binary.test")?; let mut test_labels: Vec = Vec::new(); let mut test_feature: Vec> = Vec::new(); for result in rdr.records() { @@ -39,7 +29,7 @@ fn main() -> std::io::Result<()> { test_feature.push(feature); } - let booster = Booster::train(train_dataset).unwrap(); + let booster = Booster::train(train_dataset, "objective=binary metric=auc".to_string()).unwrap(); let result = booster.predict(test_feature).unwrap(); let mut tp = 0; diff --git a/src/booster.rs b/src/booster.rs index b8994c7..1dbc4c1 100644 --- a/src/booster.rs +++ b/src/booster.rs @@ -1,42 +1,45 @@ -use lightgbm_sys; - use libc::{c_char, c_double, c_void, c_long}; use std::ffi::CString; use std; -use super::{LGBMResult, Dataset}; + +use lightgbm_sys; + + +use super::{LGBMResult, Dataset, LGBMError}; +/// Core model in LightGBM, containing functions for training, evaluating and predicting. pub struct Booster { pub(super) handle: lightgbm_sys::BoosterHandle } + impl Booster { fn new(handle: lightgbm_sys::BoosterHandle) -> LGBMResult { Ok(Booster{handle}) } - pub fn train(dataset: Dataset) -> LGBMResult { - let params = CString::new("objective=binary metric=auc").unwrap(); + /// Create a new Booster model with given Dataset and parameters. + pub fn train(dataset: Dataset, params: String) -> LGBMResult { + let params = CString::new(params).unwrap(); let mut handle = std::ptr::null_mut(); - unsafe { + lgbm_call!( lightgbm_sys::LGBM_BoosterCreate( dataset.handle, params.as_ptr() as *const c_char, &mut handle - ); - } + ) + )?; - // train let mut is_finished: i32 = 0; - unsafe{ - for _ in 1..100 { - lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished); - } + for _ in 1..100 { + lgbm_call!(lightgbm_sys::LGBM_BoosterUpdateOneIter(handle, &mut is_finished))?; } Ok(Booster::new(handle)?) } + /// Predict results for given data. pub fn predict(&self, data: Vec>) -> LGBMResult> { let data_length = data.len(); let feature_length = data[0].len(); @@ -45,7 +48,7 @@ impl Booster { let out_result: Vec = vec![Default::default(); data.len()]; let flat_data = data.into_iter().flatten().collect::>(); - unsafe { + lgbm_call!( lightgbm_sys::LGBM_BoosterPredictForMat( self.handle, flat_data.as_ptr() as *const c_void, @@ -59,8 +62,41 @@ impl Booster { params.as_ptr() as *const c_char, &mut out_length, out_result.as_ptr() as *mut c_double - ); - } + ) + )?; Ok(out_result) } } + + +impl Drop for Booster { + fn drop(&mut self) { + lgbm_call!(lightgbm_sys::LGBM_BoosterFree(self.handle)).unwrap(); + } +} + + +#[cfg(test)] +mod tests { + use super::*; + fn read_train_file() -> LGBMResult { + Dataset::from_file("lightgbm-sys/lightgbm/examples/binary_classification/binary.train".to_string()) + } + + #[test] + fn predict() { + let dataset = read_train_file().unwrap(); + let bst = Booster::train(dataset, "objective=binary metric=auc".to_string()).unwrap(); + let feature = vec![vec![0.5; 28], vec![0.0; 28], vec![0.9; 28]]; + let result = bst.predict(feature).unwrap(); + let mut normalized_result = Vec::new(); + for r in result{ + if r > 0.5{ + normalized_result.push(1); + } else { + normalized_result.push(0); + } + } + assert_eq!(normalized_result, vec![0, 0, 1]); + } +} diff --git a/src/dataset.rs b/src/dataset.rs index 59449a5..63a4ddc 100644 --- a/src/dataset.rs +++ b/src/dataset.rs @@ -1,21 +1,64 @@ -use libc::{c_void,c_char}; - use std; use std::ffi::CString; +use libc::{c_void,c_char}; use lightgbm_sys; -use super::LGBMResult; +use super::{LGBMResult, LGBMError}; + + +/// Dataset used throughout LightGBM for training. +/// +/// # Examples +/// +/// ## from mat +/// +/// ``` +/// use lightgbm::Dataset; +/// +/// let data = vec![vec![1.0, 0.1, 0.2, 0.1], +/// vec![0.7, 0.4, 0.5, 0.1], +/// vec![0.9, 0.8, 0.5, 0.1], +/// vec![0.2, 0.2, 0.8, 0.7], +/// vec![0.1, 0.7, 1.0, 0.9]]; +/// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; +/// let dataset = Dataset::from_mat(data, label).unwrap(); +/// ``` +/// +/// ## from file +/// +/// ``` +/// use lightgbm::Dataset; +/// +/// let dataset = Dataset::from_file( +/// "lightgbm-sys/lightgbm/examples/binary_classification/binary.train" +/// .to_string()).unwrap(); +/// ``` pub struct Dataset { pub(super) handle: lightgbm_sys::DatasetHandle } + #[link(name = "c")] impl Dataset { fn new(handle: lightgbm_sys::DatasetHandle) -> LGBMResult { Ok(Dataset{handle}) } + /// Create a new `Dataset` from dense array in row-major order. + /// + /// Example + /// ``` + /// use lightgbm::Dataset; + /// + /// let data = vec![vec![1.0, 0.1, 0.2, 0.1], + /// vec![0.7, 0.4, 0.5, 0.1], + /// vec![0.9, 0.8, 0.5, 0.1], + /// vec![0.2, 0.2, 0.8, 0.7], + /// vec![0.1, 0.7, 1.0, 0.9]]; + /// let label = vec![0.0, 0.0, 0.0, 1.0, 1.0]; + /// let dataset = Dataset::from_mat(data, label).unwrap(); + /// ``` pub fn from_mat(data: Vec>, label: Vec) -> LGBMResult { let data_length = data.len(); let feature_length = data[0].len(); @@ -25,7 +68,7 @@ impl Dataset { let mut handle = std::ptr::null_mut(); let flat_data = data.into_iter().flatten().collect::>(); - unsafe{ + lgbm_call!( lightgbm_sys::LGBM_DatasetCreateFromMat( flat_data.as_ptr() as *const c_void, lightgbm_sys::C_API_DTYPE_FLOAT64 as i32, @@ -35,33 +78,88 @@ impl Dataset { params.as_ptr() as *const c_char, reference, &mut handle - ); + ) + )?; + lgbm_call!( lightgbm_sys::LGBM_DatasetSetField( handle, label_str.as_ptr() as *const c_char, label.as_ptr() as *const c_void, data_length as i32, lightgbm_sys::C_API_DTYPE_FLOAT32 as i32 - ); - } + ) + )?; + Ok(Dataset::new(handle)?) } + /// Create a new `Dataset` from file. + /// + /// file is `tsv`. + /// ```text + ///