diff --git a/.data/plots/.gitkeep b/.data/plots/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/.gitignore b/.gitignore index ff9a9f4..95567c5 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,9 @@ .data/output/* !.data/output/.gitkeep !.data/output/*.dvc +!.data/plots +.data/plots/* +!.data/plots/.gitkeep # IDE - VSCode .vscode/* diff --git a/Cargo.lock b/Cargo.lock index d5ef05c..3f8ca83 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -175,6 +175,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" + [[package]] name = "bitvec" version = "1.0.1" @@ -223,6 +229,18 @@ version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" +[[package]] +name = "bytemuck" +version = "1.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.5.0" @@ -261,8 +279,10 @@ checksum = "5bc015644b92d5890fab7489e49d21f879d5c990186827d42ec511919404f38b" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-targets 0.52.0", ] @@ -293,6 +313,12 @@ dependencies = [ "half", ] +[[package]] +name = "color_quant" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" + [[package]] name = "colored" version = "2.1.0" @@ -303,6 +329,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "const-cstr" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed3d0b5ff30645a68f35ece8cea4556ca14ef8a1651455f789a099a0513532a6" + [[package]] name = "convert_case" version = "0.4.0" @@ -325,6 +357,42 @@ version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +[[package]] +name = "core-graphics" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2581bbab3b8ffc6fcbd550bf46c355135d16e9ff2a6ea032ad6b9bf1d7efe4fb" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-graphics-types", + "foreign-types", + "libc", +] + +[[package]] +name = "core-graphics-types" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "libc", +] + +[[package]] +name = "core-text" +version = "19.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d74ada66e07c1cefa18f8abfba765b486f250de2e4a999e5727fc0dd4b4a25" +dependencies = [ + "core-foundation", + "core-graphics", + "foreign-types", + "libc", +] + [[package]] name = "cpufeatures" version = "0.2.12" @@ -334,6 +402,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-deque" version = "0.8.5" @@ -481,6 +558,48 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs-next" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1" +dependencies = [ + "cfg-if", + "dirs-sys-next", +] + +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "dlib" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" +dependencies = [ + "libloading", +] + +[[package]] +name = "dwrote" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439a1c2ba5611ad3ed731280541d36d2e9c4ac5e7fb818a27b604bdc5a6aa65b" +dependencies = [ + "lazy_static", + "libc", + "winapi", + "wio", +] + [[package]] name = "dyn-clone" version = "1.0.16" @@ -520,18 +639,83 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fdeflate" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f9bfee30e4dedf0ab8b422f03af778d9612b63f502710fc500a334ebe2de645" +dependencies = [ + "simd-adler32", +] + [[package]] name = "finl_unicode" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fcfdc7a0362c9f4444381a9e697c79d435fe65b52a37466fc2c1184cee9edc6" +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "float-ord" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bad48618fdb549078c333a7a8528acb57af271d0433bdecd523eb620628364e" + [[package]] name = "fnv" version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "font-kit" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21fe28504d371085fae9ac7a3450f0b289ab71e07c8e57baa3fb68b9e57d6ce5" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "core-foundation", + "core-graphics", + "core-text", + "dirs-next", + "dwrote", + "float-ord", + "freetype", + "lazy_static", + "libc", + "log", + "pathfinder_geometry", + "pathfinder_simd", + "walkdir", + "winapi", + "yeslogic-fontconfig-sys", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -541,6 +725,27 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "freetype" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efc8599a3078adf8edeb86c71e9f8fa7d88af5ca31e806a867756081f90f5d83" +dependencies = [ + "freetype-sys", + "libc", +] + +[[package]] +name = "freetype-sys" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66ee28c39a43d89fbed8b4798fb4ba56722cfd2b5af81f9326c27614ba88ecd5" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "funty" version = "2.0.0" @@ -659,6 +864,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "gif" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80792593675e051cf94a4b111980da2ba60d4a83e43e0048c5693baab3977045" +dependencies = [ + "color_quant", + "weezl", +] + [[package]] name = "gimli" version = "0.28.1" @@ -866,6 +1081,20 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "image" +version = "0.24.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5690139d2f55868e080017335e4b94cb7414274c74f1669c84fb5feba2c9f69d" +dependencies = [ + "bytemuck", + "byteorder", + "color_quant", + "jpeg-decoder", + "num-traits", + "png", +] + [[package]] name = "indexmap" version = "1.9.3" @@ -931,6 +1160,12 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "jpeg-decoder" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" + [[package]] name = "js-sys" version = "0.3.68" @@ -966,12 +1201,33 @@ version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" +dependencies = [ + "cfg-if", + "windows-sys", +] + [[package]] name = "libm" version = "0.2.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" +[[package]] +name = "libredox" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" +dependencies = [ + "bitflags 2.4.2", + "libc", + "redox_syscall 0.4.1", +] + [[package]] name = "linfa" version = "0.7.0" @@ -1126,6 +1382,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", + "simd-adler32", ] [[package]] @@ -1147,7 +1404,7 @@ checksum = "de59562e5c71656c098d8e966641b31da87b89dc3dcb6e761d3b37dcdfa0cb72" dependencies = [ "async-trait", "base64 0.13.1", - "bitflags", + "bitflags 1.3.2", "bson", "chrono", "derivative", @@ -1420,6 +1677,25 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "pathfinder_geometry" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b7e7b4ea703700ce73ebf128e1450eb69c3a8329199ffbfb9b2a0418e5ad3" +dependencies = [ + "log", + "pathfinder_simd", +] + +[[package]] +name = "pathfinder_simd" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0444332826c70dc47be74a7c6a5fc44e23a7905ad6858d4162b658320455ef93" +dependencies = [ + "rustc_version 0.4.0", +] + [[package]] name = "pbkdf2" version = "0.11.0" @@ -1456,6 +1732,71 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "chrono", + "font-kit", + "image", + "lazy_static", + "num-traits", + "pathfinder_geometry", + "plotters-backend", + "plotters-bitmap", + "plotters-svg", + "ttf-parser", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-bitmap" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cebbe1f70205299abc69e8b295035bb52a6a70ee35474ad10011f0a4efb8543" +dependencies = [ + "gif", + "image", + "plotters-backend", +] + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "png" +version = "0.17.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06e4b0d3d1312775e782c86c91a111aa1f910cbb65e1337f9975b5f9a554b5e1" +dependencies = [ + "bitflags 1.3.2", + "crc32fast", + "fdeflate", + "flate2", + "miniz_oxide", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -1582,7 +1923,7 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -1591,7 +1932,18 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ - "bitflags", + "bitflags 1.3.2", +] + +[[package]] +name = "redox_users" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a18479200779601e498ada4e8c1e1f50e3ee19deb0259c25825a98b5603b2cb4" +dependencies = [ + "getrandom", + "libredox", + "thiserror", ] [[package]] @@ -1783,7 +2135,7 @@ dependencies = [ [[package]] name = "rust-workspace" -version = "0.7.23" +version = "0.8.0" dependencies = [ "cargo-husky", "ciborium", @@ -1797,6 +2149,7 @@ dependencies = [ "mongodb", "ndarray", "octorust", + "plotters", "rand", "regex", "serde_json", @@ -1875,6 +2228,15 @@ version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "schemars" version = "0.8.16" @@ -2057,6 +2419,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "simple_asn1" version = "0.6.2" @@ -2218,7 +2586,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ - "bitflags", + "bitflags 1.3.2", "core-foundation", "system-configuration-sys", ] @@ -2463,6 +2831,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "ttf-parser" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff" + [[package]] name = "typed-builder" version = "0.10.0" @@ -2550,6 +2924,16 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + [[package]] name = "want" version = "0.3.1" @@ -2662,6 +3046,12 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "weezl" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" + [[package]] name = "widestring" version = "1.0.2" @@ -2684,6 +3074,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" @@ -2842,6 +3241,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "wio" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d129932f4644ac2396cb456385cbf9e63b5b30c6e8dc4820bdca4eb082037a5" +dependencies = [ + "winapi", +] + [[package]] name = "wyz" version = "0.5.1" @@ -2851,6 +3259,18 @@ dependencies = [ "tap", ] +[[package]] +name = "yeslogic-fontconfig-sys" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2bbd69036d397ebbff671b1b8e4d918610c181c5a16073b96f984a38d08c386" +dependencies = [ + "const-cstr", + "dlib", + "once_cell", + "pkg-config", +] + [[package]] name = "zerocopy" version = "0.7.32" diff --git a/Cargo.toml b/Cargo.toml index 1541b86..6e4badc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust-workspace" -version = "0.7.23" +version = "0.8.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -19,6 +19,7 @@ linfa-trees = { version = "0.7.0", features = ["serde"] } csv = "1.3.0" ndarray = "0.15.6" ciborium = "0.2.2" +plotters = "0.3.5" # GitHub data pipeline octorust = "0.7.0" serde_json = "1.0.113" diff --git a/src/linfa_train/decision_tree/mod.rs b/src/linfa_train/decision_tree/mod.rs index 7c0ea1f..1571220 100644 --- a/src/linfa_train/decision_tree/mod.rs +++ b/src/linfa_train/decision_tree/mod.rs @@ -1,4 +1,4 @@ -//! Logistic regression module. +//! Decision tree module. use ciborium::{cbor, value}; use colored::Colorize; @@ -59,7 +59,7 @@ impl LinfaTrainDecisionTree { InuputArguments { max_depth } } - /// The dataset headers + /// The dataset headers. fn headers(&mut self, reader: &mut Reader) -> Vec { let result = reader .headers() @@ -71,7 +71,7 @@ impl LinfaTrainDecisionTree { result } - /// The dataset data + /// The dataset data. fn data(&mut self, reader: &mut Reader) -> Vec> { let result = reader .records() @@ -90,7 +90,7 @@ impl LinfaTrainDecisionTree { result } - /// The dataset records + /// The dataset records. fn records(&mut self, data: &[Vec], target_index: usize) -> Array2 { let mut records: Vec = vec![]; for record in data.iter() { @@ -110,7 +110,7 @@ impl LinfaTrainDecisionTree { result } - /// The dataset targets + /// The dataset targets. fn targets(&mut self, data: &[Vec], target_index: usize) -> Array1 { let targets = data .iter() @@ -124,7 +124,7 @@ impl LinfaTrainDecisionTree { Array::from(targets) } - /// The dataset + /// The dataset. /// Data source: https:///github.com/plotly/datasets/blob/master/diabetes.csv fn dataset(&mut self) -> Dataset> { let file_path = ".data/input/diabetes.csv"; @@ -138,7 +138,7 @@ impl LinfaTrainDecisionTree { Dataset::new(records, targets).with_feature_names(features) } - /// Trains the model + /// Trains the model. fn train(&mut self, max_depth: usize) { println!("\n{}", "Training the model...".yellow().bold()); let dataset = self.dataset(); @@ -160,7 +160,7 @@ impl LinfaTrainDecisionTree { println!("\n{} {:?}", "Model saved, path:".yellow(), output.as_path()); } - /// Loads the model + /// Loads the model. fn load_model(&mut self) { println!("\n{}", "Testing the model...".yellow().bold()); let dataset = self.dataset(); diff --git a/src/linfa_train/logistic_regression/mod.rs b/src/linfa_train/logistic_regression/mod.rs index b3757d5..c7b522a 100644 --- a/src/linfa_train/logistic_regression/mod.rs +++ b/src/linfa_train/logistic_regression/mod.rs @@ -7,7 +7,12 @@ use linfa::prelude::*; use linfa::Dataset; use linfa_logistic::FittedLogisticRegression; use linfa_logistic::LogisticRegression; +use ndarray::ArrayBase; +use ndarray::Axis; +use ndarray::Dim; +use ndarray::OwnedRepr; use ndarray::{Array, Array1, Array2}; +use plotters::prelude::*; use std::io::Read; use std::path::Path; use std::{env::args, fs, fs::File}; @@ -41,6 +46,7 @@ impl LinfaTrainLogisticRegression { self.train(args.max_iterations); self.load_model(); + self.generate_plots(); } /// Parses arguments passed to the program. @@ -60,7 +66,7 @@ impl LinfaTrainLogisticRegression { InuputArguments { max_iterations } } - /// The dataset headers + /// The dataset headers. fn headers(&mut self, reader: &mut Reader) -> Vec { let result = reader .headers() @@ -72,7 +78,7 @@ impl LinfaTrainLogisticRegression { result } - /// The dataset data + /// The dataset data. fn data(&mut self, reader: &mut Reader) -> Vec> { let result = reader .records() @@ -91,7 +97,7 @@ impl LinfaTrainLogisticRegression { result } - /// The dataset records + /// The dataset records. fn records(&mut self, data: &[Vec], target_index: usize) -> Array2 { let mut records: Vec = vec![]; for record in data.iter() { @@ -111,7 +117,7 @@ impl LinfaTrainLogisticRegression { result } - /// The dataset targets + /// The dataset targets. fn targets(&mut self, data: &[Vec], target_index: usize) -> Array1 { let targets = data .iter() @@ -125,7 +131,7 @@ impl LinfaTrainLogisticRegression { Array::from(targets) } - /// The dataset + /// The dataset. /// Data source: https:///github.com/plotly/datasets/blob/master/diabetes.csv fn dataset(&mut self) -> Dataset> { let file_path = ".data/input/diabetes.csv"; @@ -139,7 +145,7 @@ impl LinfaTrainLogisticRegression { Dataset::new(records, targets).with_feature_names(features) } - /// Trains the model + /// Trains the model. fn train(&mut self, max_iterations: u64) { println!("\n{}", "Training the model...".yellow().bold()); let dataset = self.dataset(); @@ -162,7 +168,7 @@ impl LinfaTrainLogisticRegression { println!("\n{} {:?}", "Model saved, path:".yellow(), output.as_path()); } - /// Loads the model + /// Loads the model. fn load_model(&mut self) { println!("\n{}", "Testing the model...".yellow().bold()); let dataset = self.dataset(); @@ -182,4 +188,105 @@ impl LinfaTrainLogisticRegression { prediction ); } + + /// Generates plots (2D Cartesian coordinate system). + fn generate_plots(&mut self) { + let dataset = self.dataset(); + let records = dataset.records().to_owned(); + let length = records.index_axis(Axis(0), 0).len(); + let x_axis_column_index = length - 1; + for index in 0..x_axis_column_index { + let y_axis_column_index = index; + match self.plot(dataset.clone(), x_axis_column_index, y_axis_column_index) { + Ok(()) => { + println!("Generated a plot {:?}", index); + } + Err(error) => { + panic!("Plot error\n{:?}", error.source()); + } + } + } + } + + /// Generates a plot (2D Cartesian coordinate system). + fn plot( + &mut self, + dataset: Dataset>, + x_axis_column_index: usize, + y_axis_column_index: usize, + ) -> Result<(), Box> { + let records = dataset.records().to_owned(); + let x_axis = records.column(x_axis_column_index); + let x_values = x_axis.to_owned().to_vec(); + let x_range = x_values + .clone() + .into_iter() + .reduce(f32::min) + .unwrap() + .to_owned() + ..x_values + .clone() + .into_iter() + .reduce(f32::max) + .unwrap() + .to_owned(); + let y_axis = records.column(y_axis_column_index); + let y_values = y_axis.to_owned().to_vec(); + let y_range = y_values + .clone() + .into_iter() + .reduce(f32::min) + .unwrap() + .to_owned() + ..y_values + .clone() + .into_iter() + .reduce(f32::max) + .unwrap() + .to_owned(); + + let features = dataset.feature_names(); + let x_feature_default = String::from("x"); + let x_feature = features + .get(x_axis_column_index) + .unwrap_or(&x_feature_default); + let y_feature_default = String::from("y"); + let y_feature = features + .get(y_axis_column_index) + .unwrap_or(&y_feature_default); + let caption = y_feature.to_owned() + " / " + x_feature; + + let file_name = + y_feature.to_owned().to_lowercase() + "-by-" + x_feature.to_lowercase().as_str(); + let plot_path = Path::new(".data") + .join("plots") + .join(file_name + "_diabetes_model.png"); + let root = BitMapBackend::new(&plot_path, (1600, 1200)).into_drawing_area(); + root.fill(&WHITE)?; + + let mut chart = ChartBuilder::on(&root) + .caption(caption, ("sans-serif", 50).into_font()) + .margin(10) + .x_label_area_size(30) + .y_label_area_size(30) + .build_cartesian_2d(x_range, y_range)?; + + let binding = x_values.to_owned().to_vec(); + let plot_data = binding.iter().enumerate().map(|(index, value)| { + ( + value.to_owned(), + y_values.get(index).unwrap_or(&0.0).to_owned(), + ) + }); + + chart.configure_mesh().draw()?; + + chart.draw_series(PointSeries::of_element(plot_data, 2, &RED, &|c, s, st| { + EmptyElement::at(c) + Circle::new((0, 0), s, st.filled()) + }))?; + + root.present()?; + + Ok(()) + } }