diff --git a/Cargo.lock b/Cargo.lock index 0a9a53b..71abb4e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + [[package]] name = "atomic_float" version = "1.1.0" @@ -10,9 +19,9 @@ checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "cfg-if" @@ -53,9 +62,9 @@ checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "genetic-rs" -version = "0.5.4" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a68bb62a836f6ea3261d77cfec4012316e206f53e7d0eab519f5f3630e86001f" +checksum = "372d080448bae68a4a8963e6acadd81621510cdf535c8eb5ecc39ab605a17e88" dependencies = [ "genetic-rs-common", "genetic-rs-macros", @@ -63,20 +72,21 @@ dependencies = [ [[package]] name = "genetic-rs-common" -version = "0.5.4" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be7aaffd4e4dc82d11819d40794f089c37d02595a401f229ed2877d1a4c401d" +checksum = "94a87c5bbc9d445ab0684eb5109b5781578c02a63f8ed2d286ca75b94848f43f" dependencies = [ "rand", "rayon", "replace_with", + "tracing", ] [[package]] name = "genetic-rs-macros" -version = "0.5.4" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e73b1f36ea3e799232e1a3141a2765fa6ee9ed7bb3fed96ccfb3bf272d1832e" +checksum = "f5d928bc6dae6aef04ff1156a4555d3313f5a6cf607235b5931c62710d066c5a" dependencies = [ "genetic-rs-common", "proc-macro2", @@ -86,12 +96,13 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.12" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" dependencies = [ "cfg-if", "libc", + "r-efi", "wasi", ] @@ -113,6 +124,21 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "log" +version = "0.4.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" + +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "memchr" version = "2.7.4" @@ -132,8 +158,38 @@ dependencies = [ "serde", "serde-big-array", "serde_json", + "tracing", + "tracing-subscriber", +] + +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", ] +[[package]] +name = "once_cell" +version = "1.21.3" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -158,22 +214,28 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "rand" -version = "0.8.5" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" dependencies = [ - "libc", "rand_chacha", "rand_core", + "zerocopy", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", "rand_core", @@ -181,9 +243,9 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ "getrandom", ] @@ -208,6 +270,50 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "replace_with" version = "0.1.7" @@ -222,9 +328,9 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] @@ -240,9 +346,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", @@ -251,9 +357,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "itoa", "memchr", @@ -261,6 +367,21 @@ dependencies = [ "serde", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "smallvec" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" + [[package]] name = "syn" version = "2.0.89" @@ -272,14 +393,145 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thread_local" +version = "1.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c" +dependencies = [ + "cfg-if", + "once_cell", +] + +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.14.2+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" +dependencies = [ + "wit-bindgen-rt", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "wit-bindgen-rt" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" +dependencies = [ + "bitflags", +] + +[[package]] +name = "zerocopy" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 4b26e0f..0233873 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,17 +20,19 @@ rustdoc-args = ["--cfg", "docsrs"] [features] default = [] serde = ["dep:serde", "dep:serde-big-array"] - +tracing = ["dep:tracing", "genetic-rs/tracing"] [dependencies] atomic_float = "1.1.0" -bitflags = "2.8.0" -genetic-rs = { version = "0.5.4", features = ["rayon", "derive"] } +bitflags = "2.9.0" +genetic-rs = { version = "0.6.0", features = ["rayon", "derive"] } lazy_static = "1.5.0" rayon = "1.10.0" replace_with = "0.1.7" -serde = { version = "1.0.217", features = ["derive"], optional = true } +serde = { version = "1.0.219", features = ["derive"], optional = true } serde-big-array = { version = "0.5.1", optional = true } +tracing = { version = "0.1.41", optional = true } [dev-dependencies] -serde_json = "1.0.138" \ No newline at end of file +serde_json = "1.0.140" +tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } \ No newline at end of file diff --git a/README.md b/README.md index 4e9828b..c57bf49 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,95 @@ Implementation of the NEAT algorithm using `genetic-rs`. *Do you like this crate and want to support it? If so, leave a ⭐* -# How To Use -TODO +# Guide +A neural network has two const generic parameters, `I` and `O`. `I` represents the number of input neurons and `O` represents the number of output neurons. To create a neural network, use `NeuralNetwork::new`: + +```rust +use neat::*; + +let mut rng = rand::thread_rng(); + +// creates a randomized neural network with 3 input neurons and 2 output neurons. 
+let net: NeuralNetwork<3, 2> = NeuralNetwork::new(MutationSettings::default(), &mut rng);
+```
+
+Once you have a neural network, you can use it to predict things:
+
+```rust
+let prediction = net.predict([1.0, 2.0, 3.0]);
+dbg!(prediction);
+```
+
+A completely random neural network isn't very useful on its own, though; you need to run a simulation that trains these networks. Let's look at the following code:
+
+```rust
+use neat::*;
+
+// derive some traits so that we can use this agent with `genetic-rs`.
+#[derive(Debug, Clone, PartialEq, CrossoverReproduction, RandomlyMutable)]
+struct MyAgentGenome {
+    brain: NeuralNetwork<3, 2>
+}
+
+impl Prunable for MyAgentGenome {}
+
+impl GenerateRandom for MyAgentGenome {
+    // allows us to use `Vec::gen_random` for the initial population.
+    fn gen_random(rng: &mut impl Rng) -> Self {
+        Self { brain: NeuralNetwork::new(MutationSettings::default(), rng) }
+    }
+}
+
+// creates a bigger number within the bounds of 0 to 1 as `actual` approaches `expected`.
+fn inverse_error(expected: f32, actual: f32) -> f32 {
+    1.0 / (1.0 + (expected - actual).abs())
+}
+
+fn fitness(agent: &MyAgentGenome) -> f32 {
+    let mut rng = rand::rng();
+    let mut fit = 0.0;
+
+    for _ in 0..10 {
+        // run the test multiple times for consistency
+
+        let inputs = [rng.random(), rng.random(), rng.random()];
+
+        // try to force the network to learn to do some basic logic
+        let expected0: f32 = (inputs[0] >= 0.5 && inputs[1] < 0.5).into();
+        let expected1: f32 = (inputs[2] >= 0.5).into();
+
+        let output = agent.brain.predict(inputs);
+
+        fit += inverse_error(expected0, output[0]);
+        fit += inverse_error(expected1, output[1]);
+    }
+
+    fit
+}
+
+fn main() {
+    let mut sim = GeneticSim::new(
+        // create a population of 100 random neural networks
+        Vec::gen_random(100),
+
+        // provide the fitness function that will
+        // test the agents individually so the nextgen
+        // function can eliminate the weaker ones.
+        fitness,
+
+        // this nextgen function will kill/drop agents
+        // that don't have a high enough fitness, and repopulate
+        // by performing crossover reproduction between the remaining ones
+        crossover_pruning_nextgen,
+    );
+
+    // fast forward 100 generations. identical to looping
+    // 100 times with `sim.next_generation()`.
+    sim.perform_generations(100);
+}
+```
+
+The struct `MyAgentGenome` wraps the `NeuralNetwork` and acts as the overall hereditary data of an agent. In a more complex scenario, you could add more `genetic-rs`-compatible types to store other hereditary information, such as an agent's size or speed.
+
+If you want insight into what the crate is doing internally, enable the optional `tracing` feature and install a subscriber such as `tracing-subscriber`; `examples/readme_ex.rs` shows a working setup.
+
+Check out the [examples](https://github.com/HyperCodec/neat/tree/main/examples) for more use cases.
 
 ### License
 This crate falls under the `MIT` license
diff --git a/examples/readme_ex.rs b/examples/readme_ex.rs
new file mode 100644
index 0000000..150d5b6
--- /dev/null
+++ b/examples/readme_ex.rs
@@ -0,0 +1,79 @@
+use neat::*;
+
+#[cfg(feature = "tracing")]
+use tracing_subscriber::EnvFilter;
+
+// derive some traits so that we can use this agent with `genetic-rs`.
+#[derive(Debug, Clone, PartialEq, CrossoverReproduction, DivisionReproduction, RandomlyMutable)]
+struct MyAgentGenome {
+    brain: NeuralNetwork<3, 2>,
+}
+
+impl Prunable for MyAgentGenome {}
+
+impl GenerateRandom for MyAgentGenome {
+    // allows us to use `Vec::gen_random` for the initial population.
+    fn gen_random(rng: &mut impl Rng) -> Self {
+        Self {
+            brain: NeuralNetwork::new(MutationSettings::default(), rng),
+        }
+    }
+}
+
+// creates a bigger number within the bounds of 0 to 1 as `actual` approaches `expected`.
+fn inverse_error(expected: f32, actual: f32) -> f32 {
+    1.0 / (1.0 + (expected - actual).abs())
+}
+
+fn fitness(agent: &MyAgentGenome) -> f32 {
+    let mut rng = rand::rng();
+    let mut fit = 0.;
+
+    for _ in 0..10 {
+        // run the test multiple times for consistency
+
+        let inputs = [rng.random(), rng.random(), rng.random()];
+
+        // try to force the network to learn to do some basic logic
+        let expected0: f32 = (inputs[0] >= 0.5 && inputs[1] < 0.5).into();
+        let expected1: f32 = (inputs[2] >= 0.5).into();
+
+        let output = agent.brain.predict(inputs);
+
+        fit += inverse_error(expected0, output[0]);
+        fit += inverse_error(expected1, output[1]);
+    }
+
+    fit
+}
+
+fn main() {
+    #[cfg(feature = "tracing")]
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            EnvFilter::try_from_default_env()
+                .unwrap_or(EnvFilter::from("DEBUG"))
+        )
+        .init();
+
+    let mut sim = GeneticSim::new(
+        // create a population of 100 random neural networks
+        Vec::gen_random(100),
+        // provide the fitness function that will
+        // test the agents individually so the nextgen
+        // function can eliminate the weaker ones.
+        fitness,
+        // this nextgen function will kill/drop agents
+        // that don't have a high enough fitness, and repopulate
+        // by dividing (cloning + mutating) the remaining ones
+        division_pruning_nextgen,
+    );
+
+    // fast forward 100 generations. identical to looping
+    // 100 times with `sim.next_generation()`.
+    sim.perform_generations(100);
+}
diff --git a/src/activation.rs b/src/activation.rs
index af9f74e..3044061 100644
--- a/src/activation.rs
+++ b/src/activation.rs
@@ -82,7 +82,7 @@ impl ActivationRegistry {
         let acts = self.activations();
 
         acts.into_iter()
-            .filter(|a| !a.scope.contains(NeuronScope::NONE) && a.scope.contains(scope))
+            .filter(|a| a.scope.contains(scope))
             .collect()
     }
 }
@@ -105,6 +105,16 @@ impl Default for ActivationRegistry {
     }
 }
 
+impl fmt::Debug for ActivationRegistry {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let keys: Vec<_> = self.fns.keys().collect();
+
+        f.debug_struct("ActivationRegistry")
+            .field("fns", &keys)
+            .finish()
+    }
+}
+
 /// A trait that represents an activation method.
 pub trait Activation {
     /// The activation function.
diff --git a/src/neuralnet.rs b/src/neuralnet.rs
index cce0d61..b87deec 100644
--- a/src/neuralnet.rs
+++ b/src/neuralnet.rs
@@ -1,14 +1,13 @@
 use std::{
     collections::HashSet,
     sync::{
-        atomic::{AtomicBool, AtomicUsize, Ordering},
+        atomic::{AtomicUsize, Ordering},
         Arc,
     },
 };
 
 use atomic_float::AtomicF32;
 use genetic_rs::prelude::*;
-use rand::Rng;
 use replace_with::replace_with_or_abort;
 
 use crate::{
@@ -24,6 +23,9 @@ use serde::{Deserialize, Serialize};
 #[cfg(feature = "serde")]
 use serde_big_array::BigArray;
 
+#[cfg(feature = "tracing")]
+use tracing::*;
+
 /// The mutation settings for [`NeuralNetwork`].
 /// Does not affect [`NeuralNetwork::mutate`], only [`NeuralNetwork::divide`] and [`NeuralNetwork::crossover`].
 #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
@@ -42,7 +44,7 @@ pub struct MutationSettings {
 impl Default for MutationSettings {
     fn default() -> Self {
         Self {
-            mutation_rate: 0.01,
+            mutation_rate: 0.05,
             mutation_passes: 3,
             weight_mutation_amount: 0.5,
         }
@@ -70,11 +72,15 @@ pub struct NeuralNetwork<const I: usize, const O: usize> {
     /// The mutation settings for the network.
     pub mutation_settings: MutationSettings,
+
+    /// The total number of connections in the network.
+    pub(crate) total_connections: usize,
 }
 
 impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
     // TODO option to set default output layer activations
     /// Creates a new random neural network with the given settings.
+    #[cfg_attr(feature = "tracing", instrument)]
     pub fn new(mutation_settings: MutationSettings, rng: &mut impl Rng) -> Self {
         let mut output_layer = Vec::with_capacity(O);
 
@@ -86,21 +92,25 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
             ));
         }
 
+        let mut total_connections = 0;
         let mut input_layer = Vec::with_capacity(I);
 
         for _ in 0..I {
             let mut already_chosen = Vec::new();
-            let outputs = (0..rng.gen_range(1..=O))
+            let conns = rng.random_range(1..=O);
+            total_connections += conns;
+
+            let outputs = (0..conns)
                 .map(|_| {
-                    let mut j = rng.gen_range(0..O);
+                    let mut j = rng.random_range(0..O);
                     while already_chosen.contains(&j) {
-                        j = rng.gen_range(0..O);
+                        j = rng.random_range(0..O);
                     }
 
                     output_layer[j].input_count += 1;
                     already_chosen.push(j);
 
-                    (NeuronLocation::Output(j), rng.gen())
+                    (NeuronLocation::Output(j), rng.random())
                 })
                 .collect();
 
@@ -119,10 +129,12 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
             hidden_layers: vec![],
             output_layer,
             mutation_settings,
+            total_connections,
         }
     }
 
     /// Runs the neural network, propagating values from input to output layer.
+    #[cfg_attr(feature = "tracing", instrument)]
     pub fn predict(&self, inputs: [f32; I]) -> [f32; O] {
         let cache = Arc::new(NeuralNetCache::from(self));
         cache.prime_inputs(inputs);
 
@@ -137,19 +149,14 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
     fn eval(&self, loc: impl AsRef<NeuronLocation>, cache: Arc<NeuralNetCache<I, O>>) {
         let loc = loc.as_ref();
 
-        if !cache.claim(loc) {
-            // some other thread is already
-            // waiting to do this task, currently doing it, or done.
-            // no need to do it again.
+        if !cache.is_ready(loc) {
+            // this neuron's value hasn't reached
+            // its final form yet, so don't evaluate it here.
+            // the connection that completes the neuron
+            // will evaluate it.
             return;
         }
 
-        while !cache.is_ready(loc) {
-            // essentially spinlocks until the dependency tasks are complete,
-            // while letting this thread do some work on random tasks.
-            rayon::yield_now();
-        }
-
         let val = cache.get(loc);
         let n = self.get_neuron(loc);
 
@@ -178,32 +185,37 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
     }
 
     /// Split a [`Connection`] into two of the same weight, joined by a new [`Neuron`] in the hidden layer(s).
+    #[cfg_attr(feature = "tracing", instrument)]
     pub fn split_connection(&mut self, connection: Connection, rng: &mut impl Rng) {
         let newloc = NeuronLocation::Hidden(self.hidden_layers.len());
 
         let a = self.get_neuron_mut(connection.from);
-        let weight = unsafe { a.remove_connection(connection.to) }.unwrap();
+        let weight = unsafe { a.remove_connection(&connection.to) }.unwrap();
 
         a.outputs.push((newloc, weight));
 
-        let n = Neuron::new(vec![(connection.to, weight)], NeuronScope::HIDDEN, rng);
+        let mut n = Neuron::new(vec![(connection.to, weight)], NeuronScope::HIDDEN, rng);
+        n.input_count += 1;
 
         self.hidden_layers.push(n);
+
+        self.total_connections += 1;
     }
 
-    /// Adds a connection but does not check for cyclic linkages.
+    /// Adds a connection but does not check for cyclic linkages or update [`total_connections`][NeuralNetwork::total_connections].
     ///
     /// # Safety
     /// This is marked as unsafe because it could cause a hang/livelock when predicting due to cyclic linkage.
-    /// There is no actual UB or unsafe code associated with it.
+    /// There is no actual UB or unsafe code associated with it. It does handle [`Neuron::input_count`] properly.
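+    ///
+    /// A sketch of careful usage (prefer [`add_connection`][NeuralNetwork::add_connection],
+    /// which runs the cycle check and updates the counters for you; the values below are
+    /// only for illustration):
+    ///
+    /// ```no_run
+    /// # use neat::*;
+    /// # let mut rng = rand::rng();
+    /// # let mut net: NeuralNetwork<3, 2> = NeuralNetwork::new(MutationSettings::default(), &mut rng);
+    /// let conn = Connection { from: NeuronLocation::Input(0), to: NeuronLocation::Output(0) };
+    /// if net.is_connection_safe(conn) {
+    ///     unsafe { net.add_connection_raw(conn, 0.5) };
+    ///     // raw adds skip `total_connections`, so rebuild the counters afterwards.
+    ///     net.recalculate_connections();
+    /// }
+    /// ```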
     pub unsafe fn add_connection_raw(&mut self, connection: Connection, weight: f32) {
         let a = self.get_neuron_mut(connection.from);
         a.outputs.push((connection.to, weight));
 
-        // let b = self.get_neuron_mut(connection.to);
-        // b.inputs.insert(connection.from);
+        let b = self.get_neuron_mut(connection.to);
+        b.input_count += 1;
     }
 
     /// Returns false if the connection is cyclic.
+    #[cfg_attr(feature = "tracing", instrument)]
     pub fn is_connection_safe(&self, connection: Connection) -> bool {
         let mut visited = HashSet::from([connection.from]);
 
@@ -211,6 +223,7 @@
     }
 
     // TODO maybe parallelize
+    #[cfg_attr(feature = "tracing", instrument)]
     fn dfs(&self, visited: &mut HashSet<NeuronLocation>, current: NeuronLocation) -> bool {
         if !visited.insert(current) {
             return false;
         }
 
@@ -236,22 +249,32 @@
             self.add_connection_raw(connection, weight);
         }
 
+        self.total_connections += 1;
+
         true
     }
 
     /// Mutates a connection's weight.
     pub fn mutate_weight(&mut self, connection: Connection, rng: &mut impl Rng) {
-        let rate = self.mutation_settings.weight_mutation_amount;
+        let max = self.mutation_settings.weight_mutation_amount;
         let n = self.get_neuron_mut(connection.from);
-        n.mutate_weight(connection.to, rate, rng).unwrap();
+        n.mutate_weight(connection.to, max, rng).unwrap();
     }
 
     /// Get a random valid location within the network.
     pub fn random_location(&self, rng: &mut impl Rng) -> NeuronLocation {
-        match rng.gen_range(0..3) {
-            0 => NeuronLocation::Input(rng.gen_range(0..self.input_layer.len())),
-            1 => NeuronLocation::Hidden(rng.gen_range(0..self.hidden_layers.len())),
-            2 => NeuronLocation::Output(rng.gen_range(0..self.output_layer.len())),
+        if self.hidden_layers.is_empty() {
+            return match rng.random_range(0..2) {
+                0 => NeuronLocation::Input(rng.random_range(0..self.input_layer.len())),
+                1 => NeuronLocation::Output(rng.random_range(0..self.output_layer.len())),
+                _ => unreachable!(),
+            };
+        }
+
+        match rng.random_range(0..3) {
+            0 => NeuronLocation::Input(rng.random_range(0..self.input_layer.len())),
+            1 => NeuronLocation::Hidden(rng.random_range(0..self.hidden_layers.len())),
+            2 => NeuronLocation::Output(rng.random_range(0..self.output_layer.len())),
             _ => unreachable!(),
         }
     }
 
@@ -261,28 +284,69 @@ impl<const I: usize, const O: usize> NeuralNetwork<I, O> {
         &self,
         rng: &mut impl Rng,
         scope: NeuronScope,
-    ) -> NeuronLocation {
-        let loc = self.random_location(rng);
+    ) -> Option<NeuronLocation> {
+        let components: Vec<_> = scope.iter().collect();
+
+        match components[rng.random_range(0..components.len())] {
+            NeuronScope::INPUT => Some(NeuronLocation::Input(rng.random_range(0..I))),
+            NeuronScope::HIDDEN => {
+                if self.hidden_layers.is_empty() {
+                    None
+                } else {
+                    Some(NeuronLocation::Hidden(
+                        rng.random_range(0..self.hidden_layers.len()),
+                    ))
+                }
+            }
+            NeuronScope::OUTPUT => Some(NeuronLocation::Output(rng.random_range(0..O))),
+            _ => unreachable!(),
+        }
+    }
 
-        // this is a lazy and slow way of donig it, TODO better version.
-        if !scope.contains(NeuronScope::from(loc)) {
-            return self.random_location_in_scope(rng, scope);
+    /// Gets a random connection and weight from the neural network.
+    pub fn random_connection(&self, rng: &mut impl Rng) -> Option<(Connection, f32)> {
+        if self.total_connections == 0 {
+            return None;
         }
 
-        loc
+        let from = self.random_location(rng);
+
+        let n = self.get_neuron(from);
+        if n.outputs.is_empty() {
+            return self.random_connection(rng);
+        }
+
+        let (to, weight) = n.random_output(rng);
+
+        Some((Connection { from, to }, weight))
     }
 
     /// Remove a connection and any hanging neurons caused by the deletion.
-    /// Returns whether there was a hanging neuron.
+    /// Returns whether a hanging neuron (i.e. a neuron with no inputs) was removed.
+    #[cfg_attr(feature = "tracing", instrument)]
     pub fn remove_connection(&mut self, connection: Connection) -> bool {
+        if self.get_neuron(connection.to).input_count == 0 {
+            #[cfg(feature = "tracing")]
+            warn!("erroneous network: {self:#?}");
+            self.recalculate_connections();
+        }
+
         let a = self.get_neuron_mut(connection.from);
-        unsafe { a.remove_connection(connection.to) }.unwrap();
+        unsafe { a.remove_connection(&connection.to) }.unwrap();
+
+        // if connection.from.is_hidden() && a.outputs.len() == 0 {
+        //     // removes neurons with no outputs
+        //     // TODO return whether this was removed
+        //     self.remove_neuron(connection.from);
+        // }
+
+        self.total_connections -= 1;
 
         let b = self.get_neuron_mut(connection.to);
         b.input_count -= 1;
 
-        if b.input_count == 0 {
-            self.remove_neuron(connection.to);
+        if b.input_count == 0 && connection.to.is_hidden() {
+            self.remove_neuron(&connection.to);
             return true;
         }
@@ -290,25 +354,38 @@
     }
 
     /// Remove a neuron and downshift all connection indexes to compensate for it.
-    pub fn remove_neuron(&mut self, loc: impl AsRef<NeuronLocation>) {
-        let loc = loc.as_ref();
-        if !loc.is_hidden() {
-            panic!("Can only remove neurons from hidden layer");
-        }
+    #[cfg_attr(feature = "tracing", instrument)]
+    pub fn remove_neuron(&mut self, loc: &NeuronLocation) {
+        if let NeuronLocation::Hidden(i) = loc {
+            let n = self.hidden_layers.remove(*i);
 
-        unsafe {
-            self.downshift_connections(loc.unwrap());
+            self.total_connections -= n.outputs.len();
+
+            for (output, _) in n.outputs {
+                let n2 = self.get_neuron_mut(output);
+                n2.input_count -= 1;
+            }
+
+            unsafe {
+                self.downshift_connections(loc.unwrap());
+            }
+        } else {
+            panic!("Can only remove neurons from hidden layer");
         }
     }
 
+    #[cfg_attr(feature = "tracing", instrument)]
     unsafe fn downshift_connections(&mut self, i: usize) {
-        self.input_layer
-            .par_iter_mut()
-            .for_each(|n| n.downshift_outputs(i));
+        let removed_connections = AtomicUsize::new(0);
+        self.input_layer.par_iter_mut().for_each(|n| {
+            removed_connections.fetch_add(n.handle_removed(i), Ordering::SeqCst);
+        });
+
+        self.hidden_layers.par_iter_mut().for_each(|n| {
+            removed_connections.fetch_add(n.handle_removed(i), Ordering::SeqCst);
+        });
 
-        self.hidden_layers
-            .par_iter_mut()
-            .for_each(|n| n.downshift_outputs(i));
+        self.total_connections -= removed_connections.into_inner();
     }
 
     // TODO maybe more parallelism and pass Connection info.
@@ -325,91 +402,167 @@
     unsafe fn clear_input_counts(&mut self) {
         // not sure whether all this parallelism is necessary or if it will just generate overhead
-        // rayon::scope(|s| {
-        //     s.spawn(|_| self.input_layer.par_iter_mut().for_each(|n| n.input_count = 0));
-        //     s.spawn(|_| self.hidden_layers.par_iter_mut().for_each(|n| n.input_count = 0));
-        //     s.spawn(|_| self.output_layer.par_iter_mut().for_each(|n| n.input_count = 0));
-        // });
-
-        self.input_layer
-            .par_iter_mut()
-            .for_each(|n| n.input_count = 0);
-        self.hidden_layers
-            .par_iter_mut()
-            .for_each(|n| n.input_count = 0);
-        self.output_layer
-            .par_iter_mut()
-            .for_each(|n| n.input_count = 0);
-    }
-
-    /// Recalculates the [`input_count`][`Neuron::input_count`] field for all neurons in the network.
-    pub fn recalculate_input_counts(&mut self) {
+        rayon::scope(|s| {
+            s.spawn(|_| {
+                self.input_layer
+                    .par_iter_mut()
+                    .for_each(|n| n.input_count = 0)
+            });
+            s.spawn(|_| {
+                self.hidden_layers
+                    .par_iter_mut()
+                    .for_each(|n| n.input_count = 0)
+            });
+            s.spawn(|_| {
+                self.output_layer
+                    .par_iter_mut()
+                    .for_each(|n| n.input_count = 0)
+            });
+        });
+    }
+
+    /// Recalculates the [`input_count`][`Neuron::input_count`] field for all neurons in the network,
+    /// as well as the [`total_connections`][`NeuralNetwork::total_connections`] field on the NeuralNetwork.
+    /// Deletes any hidden layer neurons with an [`input_count`][`Neuron::input_count`] of 0.
+    #[cfg_attr(feature = "tracing", instrument)]
+    pub fn recalculate_connections(&mut self) {
+        // TODO optimization/parallelization.
         unsafe { self.clear_input_counts() };
+        self.total_connections = 0;
+
         for i in 0..I {
-            for j in 0..self.input_layer[i].outputs.len() {
+            let conns = self.input_layer[i].outputs.len();
+            self.total_connections += conns;
+            for j in 0..conns {
                 let (loc, _) = self.input_layer[i].outputs[j];
                 self.get_neuron_mut(loc).input_count += 1;
             }
         }
 
         for i in 0..self.hidden_layers.len() {
-            for j in 0..self.hidden_layers[i].outputs.len() {
+            let conns = self.hidden_layers[i].outputs.len();
+            self.total_connections += conns;
+            for j in 0..conns {
                 let (loc, _) = self.hidden_layers[i].outputs[j];
                 self.get_neuron_mut(loc).input_count += 1;
             }
         }
-    }
-
-impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetwork<I, O> {
-    fn mutate(&mut self, rate: f32, rng: &mut impl Rng) {
-        if rng.gen::<f32>() <= rate {
-            // split connection
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let n = self.get_neuron(from);
-            let (to, _) = n.random_output(rng);
 
+        // delete hanging neurons
+        let mut i = 0;
+        while i < self.hidden_layers.len() {
+            let neuron = self.get_neuron(NeuronLocation::Hidden(i));
+
+            if neuron.input_count == 0 {
+                // `remove_neuron` also subtracts the neuron's own outputs from
+                // `total_connections` and from its targets' `input_count`.
+                self.remove_neuron(&NeuronLocation::Hidden(i));
+                continue;
+            }
 
-            self.split_connection(Connection { from, to }, rng);
+            i += 1;
         }
+    }
+
+    /// Randomly mutates all weights in the network
+    /// in parallel using [`ThreadRng`][rand::prelude::ThreadRng].
+    pub fn mutate_weights(&mut self, rate: f32) {
+        self.map_weights(|w| {
+            // TODO maybe `Send`able rng.
 
-        if rng.gen::<f32>() <= rate {
-            // add connection
-            let weight = rng.gen::<f32>();
+            let mut rng = rand::rng();
 
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let to = self.random_location_in_scope(rng, !NeuronScope::INPUT);
+            if rng.random::<f32>() <= rate {
+                *w += rng.random_range(-rate..rate);
+            }
+        });
+    }
 
-            let mut connection = Connection { from, to };
-            while !self.add_connection(connection, weight) {
-                let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-                let to = self.random_location_in_scope(rng, !NeuronScope::INPUT);
-                connection = Connection { from, to };
+    /// Creates a random valid connection, if one can be made.
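+    ///
+    /// A usage sketch (illustrative rather than canonical):
+    ///
+    /// ```no_run
+    /// # use neat::*;
+    /// # let mut rng = rand::rng();
+    /// # let mut net: NeuralNetwork<3, 2> = NeuralNetwork::new(MutationSettings::default(), &mut rng);
+    /// if let Some((connection, _weight)) = net.add_random_connection(&mut rng) {
+    ///     // the connection is live immediately; for example, nudge its weight again:
+    ///     net.mutate_weight(connection, &mut rng);
+    /// }
+    /// ```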
+    pub fn add_random_connection(&mut self, rng: &mut impl Rng) -> Option<(Connection, f32)> {
+        #[cfg(feature = "tracing")]
+        trace!("adding connection");
+        let weight = rng.random::<f32>();
+
+        // TODO make this not look nested and gross
+        if let Some(from) = self.random_location_in_scope(rng, !NeuronScope::OUTPUT) {
+            if let Some(to) = self.random_location_in_scope(rng, !NeuronScope::INPUT) {
+                let mut connection = Connection { from, to };
+                while !self.add_connection(connection, weight) {
+                    let from = self
+                        .random_location_in_scope(rng, !NeuronScope::OUTPUT)
+                        .unwrap();
+                    let to = self
+                        .random_location_in_scope(rng, !NeuronScope::INPUT)
+                        .unwrap();
+                    connection = Connection { from, to };
+                }
 
-        let mut connection = Connection { from, to };
-        while !self.add_connection(connection, weight) {
-            let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-            let to = self.random_location_in_scope(rng, !NeuronScope::INPUT);
-            connection = Connection { from, to };
+                return Some((connection, weight));
+            }
         }
 
-        if rng.gen::<f32>() <= rate {
-            // remove connection
+        #[cfg(feature = "tracing")]
+        trace!("No possible connections");
+
+        None
+    }
+
+    /// Removes a random connection from the network and returns it, if any exist.
+    pub fn remove_random_connection(&mut self, rng: &mut impl Rng) -> Option<(Connection, f32)> {
+        #[cfg(feature = "tracing")]
+        trace!("removing random connection");
+
+        if let Some(output) = self.random_connection(rng) {
+            self.remove_connection(output.0);
+
+            return Some(output);
+        }
+
+        #[cfg(feature = "tracing")]
+        trace!("no connections to remove");
 
-        let from = self.random_location_in_scope(rng, !NeuronScope::OUTPUT);
-        let a = self.get_neuron(from);
-        let (to, _) = a.random_output(rng);
+        None
+    }
+
+    /// Splits a random connection in the network, if there are any.
+    #[cfg_attr(feature = "tracing", instrument)]
+    pub fn split_random_connection(&mut self, rng: &mut impl Rng) -> bool {
+        #[cfg(feature = "tracing")]
+        trace!("splitting random connection");
 
-        self.remove_connection(Connection { from, to });
+        if let Some((conn, _)) = self.random_connection(rng) {
+            self.split_connection(conn, rng);
+            return true;
         }
 
-        self.map_weights(|w| {
-            // TODO maybe `Send`able rng.
-            let mut rng = rand::thread_rng();
+        #[cfg(feature = "tracing")]
+        trace!("no connections to split");
 
-            if rng.gen::<f32>() <= rate {
-                *w += rng.gen_range(-rate..rate);
-            }
-        });
+        false
+    }
+}
+
+impl<const I: usize, const O: usize> RandomlyMutable for NeuralNetwork<I, O> {
+    #[cfg_attr(feature = "tracing", instrument)]
+    fn mutate(&mut self, rate: f32, rng: &mut impl Rng) {
+        self.mutate_weights(rate);
+
+        if rng.random::<f32>() <= rate && self.total_connections > 0 {
+            self.split_random_connection(rng);
+        }
+
+        if rng.random::<f32>() <= rate || self.total_connections == 0 {
+            self.add_random_connection(rng);
+        }
+
+        if rng.random::<f32>() <= rate && self.total_connections > 0 {
+            self.remove_random_connection(rng);
+        }
     }
 }
 
 impl<const I: usize, const O: usize> DivisionReproduction for NeuralNetwork<I, O> {
+    #[cfg_attr(feature = "tracing", instrument)]
     fn divide(&self, rng: &mut impl Rng) -> Self {
         let mut child = self.clone();
@@ -421,49 +574,70 @@ impl<const I: usize, const O: usize> DivisionReproduction for NeuralNetwork<I, O>
 impl<const I: usize, const O: usize> CrossoverReproduction for NeuralNetwork<I, O> {
-    fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self {
+    #[cfg_attr(feature = "tracing", instrument)]
+    fn crossover(&self, other: &Self, rng: &mut impl Rng) -> Self {
         let mut output_layer = self.output_layer.clone();
 
         for (i, n) in output_layer.iter_mut().enumerate() {
-            if rng.gen::<f32>() >= 0.5 {
+            if rng.random::<f32>() >= 0.5 {
                 *n = other.output_layer[i].clone();
             }
         }
 
-        let hidden_len = self.hidden_layers.len().max(other.hidden_layers.len());
+        // TODO cleaner code
+        let hidden_len;
+        let bigger;
+        let smaller;
+
+        if self.hidden_layers.len() >= other.hidden_layers.len() {
+            hidden_len = self.hidden_layers.len();
+            bigger = self;
+            smaller = other;
+        } else {
+            hidden_len = other.hidden_layers.len();
+            bigger = other;
+            smaller = self;
+        }
+
         let mut hidden_layers = Vec::with_capacity(hidden_len);
 
         for i in 0..hidden_len {
-            if rng.gen::<f32>() >= 0.5 {
-                if let Some(n) = self.hidden_layers.get(i) {
+            if rng.random::<f32>() >= 0.5 {
+                if let Some(n) = smaller.hidden_layers.get(i) {
                     let mut n = n.clone();
+
+                    // TODO merge these two functions so that it isn't
+                    // doing extra work by looping twice.
                     n.prune_invalid_outputs(hidden_len, O);
+                    n.prune_duplicate_outputs();
 
-                    hidden_layers[i] = n;
+                    hidden_layers.push(n);
                     continue;
                 }
             }
 
-            let mut n = other.hidden_layers[i].clone();
+            // either `bigger` won the 50/50 or `i >= smaller.hidden_layers.len()`
+
+            let mut n = bigger.hidden_layers[i].clone();
             n.prune_invalid_outputs(hidden_len, O);
+            n.prune_duplicate_outputs();
 
-            hidden_layers[i] = n;
+            hidden_layers.push(n);
         }
 
         let mut input_layer = self.input_layer.clone();
 
         for (i, n) in input_layer.iter_mut().enumerate() {
-            if rng.gen::<f32>() >= 0.5 {
+            if rng.random::<f32>() >= 0.5 {
                 *n = other.input_layer[i].clone();
             }
 
             n.prune_invalid_outputs(hidden_len, O);
         }
 
         // crossover mutation settings just in case.
-        let mutation_settings = if rng.gen::<f32>() >= 0.5 {
+        let mutation_settings = if rng.random::<f32>() >= 0.5 {
             self.mutation_settings.clone()
         } else {
             other.mutation_settings.clone()
         };
@@ -474,11 +648,12 @@ impl<const I: usize, const O: usize> CrossoverReproduction for NeuralNetwork<I, O>
     , activation_fn: ActivationFn,
@@ -535,13 +711,14 @@ impl Neuron {
         Self {
             input_count: 0,
             outputs,
-            bias: rng.gen(),
+            bias: rng.random(),
             activation_fn,
         }
     }
 
     /// Creates a new neuron with the given output locations.
     /// Chooses a random activation function within the specified scope.
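+    ///
+    /// A minimal sketch, mirroring the construction used in this crate's tests:
+    ///
+    /// ```no_run
+    /// # use neat::*;
+    /// # use neat::activation::NeuronScope;
+    /// # let mut rng = rand::rng();
+    /// // a hidden neuron with a single connection of weight 1.0 to the first output neuron
+    /// let neuron = Neuron::new(
+    ///     vec![(NeuronLocation::Output(0), 1.0)],
+    ///     NeuronScope::HIDDEN,
+    ///     &mut rng,
+    /// );
+    /// assert_eq!(neuron.get_weight(&NeuronLocation::Output(0)), Some(1.0));
+    /// ```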
+ #[cfg_attr(feature = "tracing", instrument)] pub fn new( outputs: Vec<(NeuronLocation, f32)>, current_scope: NeuronScope, @@ -550,6 +727,8 @@ impl Neuron { let reg = ACTIVATION_REGISTRY.read().unwrap(); let activations = reg.activations_in_scope(current_scope); + // dbg!(current_scope, &activations); + Self::new_with_activations(outputs, activations, rng) } @@ -570,7 +749,7 @@ impl Neuron { Self::new_with_activation( outputs, - activations.remove(rng.gen_range(0..activations.len())), + activations.remove(rng.random_range(0..activations.len())), rng, ) } @@ -581,10 +760,10 @@ impl Neuron { } /// Get the weight of the provided output location. Returns `None` if not found. - pub fn get_weight(&self, output: impl AsRef) -> Option { - let loc = *output.as_ref(); + #[cfg_attr(feature = "tracing", instrument)] + pub fn get_weight(&self, output: &NeuronLocation) -> Option { for out in &self.outputs { - if out.0 == loc { + if out.0 == *output { return Some(out.1); } } @@ -597,12 +776,12 @@ impl Neuron { /// # Safety /// This is marked as unsafe because it will not update the destination's [`input_count`][Neuron::input_count]. /// Similar to [`add_connection_raw`][NeuralNetwork::add_connection_raw], this does not mean UB or anything. - pub unsafe fn remove_connection(&mut self, output: impl AsRef) -> Option { - let loc = *output.as_ref(); + #[cfg_attr(feature = "tracing", instrument)] + pub unsafe fn remove_connection(&mut self, output: &NeuronLocation) -> Option { let mut i = 0; while i < self.outputs.len() { - if self.outputs[i].0 == loc { + if self.outputs[i].0 == *output { return Some(self.outputs.remove(i).1); } i += 1; @@ -615,7 +794,7 @@ impl Neuron { pub fn mutate_weight( &mut self, output: impl AsRef, - rate: f32, + max: f32, rng: &mut impl Rng, ) -> Option { let loc = *output.as_ref(); @@ -624,7 +803,7 @@ impl Neuron { while i < self.outputs.len() { let o = &mut self.outputs[i]; if o.0 == loc { - o.1 += rng.gen_range(-rate..rate); + o.1 += rng.random_range(-max..max); return Some(o.1); } @@ -637,19 +816,32 @@ impl Neuron { /// Get a random output location and weight. pub fn random_output(&self, rng: &mut impl Rng) -> (NeuronLocation, f32) { - self.outputs[rng.gen_range(0..self.outputs.len())] + if self.outputs.is_empty() { + // TODO option type + panic!("cannot sample outputs from a neuron with no outputs"); + } + + self.outputs[rng.random_range(0..self.outputs.len())] } - pub(crate) fn downshift_outputs(&mut self, i: usize) { - // TODO par_iter_mut instead of replace + #[cfg_attr(feature = "tracing", instrument)] + pub(crate) fn handle_removed(&mut self, i: usize) -> usize { + // TODO par_iter_mut or something instead of replace + let removed = AtomicUsize::new(0); replace_with_or_abort(&mut self.outputs, |o| { o.into_par_iter() - .map(|(loc, w)| match loc { - NeuronLocation::Hidden(j) if j > i => (NeuronLocation::Hidden(j - 1), w), - _ => (loc, w), + .filter_map(|(loc, w)| match loc { + NeuronLocation::Hidden(j) if j == i => { + removed.fetch_add(1, Ordering::SeqCst); + None + } + NeuronLocation::Hidden(j) if j > i => Some((NeuronLocation::Hidden(j - 1), w)), + _ => Some((loc, w)), }) .collect() }); + + removed.into_inner() } /// Removes any outputs pointing to a nonexistent neuron. @@ -657,6 +849,32 @@ impl Neuron { self.outputs .retain(|(loc, _)| output_exists(*loc, hidden_len, output_len)); } + + /// Removes any connections pointing to the same neuron. + /// Retains the first connection. 
+    pub(crate) fn prune_duplicate_outputs(&mut self) {
+        // TODO optimize so it isn't O(N^2)
+
+        // the loops are `while`s because `remove` shifts elements,
+        // and a range iterator wouldn't track the shrinking length.
+        let mut i = 0;
+        while i < self.outputs.len() {
+            let a = self.outputs[i].0;
+
+            // start at `i + 1` so the first occurrence is retained.
+            let mut j = i + 1;
+            while j < self.outputs.len() {
+                if self.outputs[j].0 == a {
+                    self.outputs.remove(j);
+                    continue;
+                }
+                j += 1;
+            }
+
+            i += 1;
+        }
+    }
 }
 
 /// A pseudo-pointer of sorts that is used for caching.
@@ -716,9 +934,6 @@ pub struct NeuronCache {
     /// The number of inputs that have finished evaluating.
     pub finished_inputs: AtomicUsize,
-
-    /// Whether or not a thread has claimed this neuron to work on it.
-    pub claimed: AtomicBool,
 }
 
 impl NeuronCache {
@@ -739,7 +954,6 @@ impl From<&Neuron> for NeuronCache {
             value: AtomicF32::new(value.bias),
             expected_inputs: value.input_count,
             finished_inputs: AtomicUsize::new(0),
-            claimed: AtomicBool::new(false),
         }
     }
 }
@@ -769,6 +983,7 @@ impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
     /// Adds a value to the neuron at the specified location and increments [`finished_inputs`][NeuronCache::finished_inputs].
     pub fn add(&self, loc: impl AsRef<NeuronLocation>, n: f32) -> f32 {
+        // TODO panic or something if there are too many finished_inputs already.
         match loc.as_ref() {
             NeuronLocation::Input(i) => self.input_layer[*i].value.fetch_add(n, Ordering::SeqCst),
             NeuronLocation::Hidden(i) => {
@@ -791,15 +1006,15 @@ impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
         match loc.as_ref() {
             NeuronLocation::Input(i) => {
                 let c = &self.input_layer[*i];
-                c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst)
+                c.expected_inputs == c.finished_inputs.load(Ordering::SeqCst)
             }
             NeuronLocation::Hidden(i) => {
                 let c = &self.hidden_layers[*i];
-                c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst)
+                c.expected_inputs == c.finished_inputs.load(Ordering::SeqCst)
             }
             NeuronLocation::Output(i) => {
                 let c = &self.output_layer[*i];
-                c.expected_inputs >= c.finished_inputs.load(Ordering::SeqCst)
+                c.expected_inputs == c.finished_inputs.load(Ordering::SeqCst)
             }
         }
     }
@@ -821,24 +1036,6 @@ impl<const I: usize, const O: usize> NeuralNetCache<I, O> {
         output.try_into().unwrap()
     }
-
-    /// Attempts to claim a neuron. Returns false if it has already been claimed.
- pub fn claim(&self, loc: impl AsRef) -> bool { - match loc.as_ref() { - NeuronLocation::Input(i) => self.input_layer[*i] - .claimed - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - .is_ok(), - NeuronLocation::Hidden(i) => self.hidden_layers[*i] - .claimed - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - .is_ok(), - NeuronLocation::Output(i) => self.output_layer[*i] - .claimed - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - .is_ok(), - } - } } impl From<&NeuralNetwork> for NeuralNetCache { diff --git a/src/tests.rs b/src/tests.rs index 825cdee..2f4eff0 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,5 +1,4 @@ -use crate::*; -use rand::prelude::*; +use crate::{activation::NeuronScope, *}; // no support for tuple structs derive in genetic-rs yet :( #[derive(Debug, Clone, PartialEq)] @@ -14,22 +13,28 @@ impl RandomlyMutable for Agent { } impl DivisionReproduction for Agent { - fn divide(&self, rng: &mut impl rand::Rng) -> Self { + fn divide(&self, rng: &mut impl Rng) -> Self { Self(self.0.divide(rng)) } } impl CrossoverReproduction for Agent { - fn crossover(&self, other: &Self, rng: &mut impl rand::Rng) -> Self { + fn crossover(&self, other: &Self, rng: &mut impl Rng) -> Self { Self(self.0.crossover(&other.0, rng)) } } +impl GenerateRandom for Agent { + fn gen_random(rng: &mut impl Rng) -> Self { + Self(NeuralNetwork::new(MutationSettings::default(), rng)) + } +} + struct GuessTheNumber(f32); impl GuessTheNumber { fn new(rng: &mut impl Rng) -> Self { - Self(rng.gen()) + Self(rng.random()) } fn guess(&self, n: f32) -> Option { @@ -47,7 +52,7 @@ impl GuessTheNumber { } fn fitness(agent: &Agent) -> f32 { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let mut fitness = 0.; @@ -99,11 +104,7 @@ fn fitness(agent: &Agent) -> f32 { #[test] fn division() { - let mut rng = rand::thread_rng(); - - let starting_genomes = (0..100) - .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng))) - .collect(); + let starting_genomes = Vec::gen_random(100); let mut sim = GeneticSim::new(starting_genomes, fitness, division_pruning_nextgen); @@ -112,11 +113,7 @@ fn division() { #[test] fn crossover() { - let mut rng = rand::thread_rng(); - - let starting_genomes = (0..100) - .map(|_| Agent(NeuralNetwork::new(MutationSettings::default(), &mut rng))) - .collect(); + let starting_genomes = Vec::gen_random(100); let mut sim = GeneticSim::new(starting_genomes, fitness, crossover_pruning_nextgen); @@ -126,7 +123,7 @@ fn crossover() { #[cfg(feature = "serde")] #[test] fn serde() { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let net: NeuralNetwork<5, 10> = NeuralNetwork::new(MutationSettings::default(), &mut rng); let text = serde_json::to_string(&net).unwrap(); @@ -151,8 +148,6 @@ fn neural_net_cache_sync() { for i in 0..2 { let input_loc = NeuronLocation::Input(i); - assert!(cache.claim(&input_loc)); - for j in 0..3 { cache.add( NeuronLocation::Hidden(j), @@ -165,7 +160,6 @@ fn neural_net_cache_sync() { let hidden_loc = NeuronLocation::Hidden(i); assert!(cache.is_ready(&hidden_loc)); - assert!(cache.claim(&hidden_loc)); for j in 0..2 { cache.add( @@ -177,3 +171,202 @@ fn neural_net_cache_sync() { assert_eq!(cache.output(), [2.0688455, 2.0688455]); } + +fn small_test_network() -> NeuralNetwork<1, 1> { + let mut rng = rand::rng(); + + let input = Neuron::new( + vec![ + (NeuronLocation::Hidden(0), 1.), + (NeuronLocation::Hidden(1), 1.), + (NeuronLocation::Hidden(2), 1.), + ], + 
NeuronScope::INPUT, + &mut rng, + ); + + let mut hidden = Neuron::new( + vec![(NeuronLocation::Output(0), 1.)], + NeuronScope::HIDDEN, + &mut rng, + ); + hidden.input_count = 1; + + let mut output = Neuron::new(vec![], NeuronScope::OUTPUT, &mut rng); + output.input_count = 3; + + NeuralNetwork { + input_layer: [input], + hidden_layers: vec![hidden; 3], + output_layer: [output], + mutation_settings: MutationSettings::default(), + total_connections: 6, + } +} + +#[test] +fn remove_neuron() { + let mut network = small_test_network(); + + network.remove_neuron(&NeuronLocation::Hidden(1)); + + assert_eq!(network.total_connections, 4); + + let expected = vec![NeuronLocation::Hidden(0), NeuronLocation::Hidden(1)]; + let got: Vec<_> = network.input_layer[0].outputs.iter().map(|c| c.0).collect(); + + assert_eq!(got, expected); +} + +#[test] +fn recalculate_connections() { + let mut rng = rand::rng(); + + let input = Neuron::new( + vec![ + (NeuronLocation::Hidden(0), 1.), + (NeuronLocation::Hidden(1), 1.), + (NeuronLocation::Hidden(2), 1.), + ], + NeuronScope::INPUT, + &mut rng, + ); + + let hidden = Neuron::new( + vec![(NeuronLocation::Output(0), 1.)], + NeuronScope::HIDDEN, + &mut rng, + ); + + let output = Neuron::new(vec![], NeuronScope::OUTPUT, &mut rng); + + let mut network = NeuralNetwork { + input_layer: [input], + hidden_layers: vec![hidden; 3], + output_layer: [output], + mutation_settings: MutationSettings::default(), + total_connections: 0, + }; + + network.recalculate_connections(); + + assert_eq!(network.total_connections, 6); + + for n in &network.hidden_layers { + assert_eq!(n.input_count, 1); + } + + assert_eq!(network.output_layer[0].input_count, 3); +} + +#[test] +fn add_connection() { + let mut network = small_test_network(); + + assert!(network.add_connection( + Connection { + from: NeuronLocation::Hidden(0), + to: NeuronLocation::Hidden(1), + }, + 1. + )); + + assert_eq!(network.total_connections, 7); + assert_eq!(network.hidden_layers[1].input_count, 2); + + assert!(!network.add_connection( + Connection { + from: NeuronLocation::Hidden(1), + to: NeuronLocation::Hidden(0) + }, + 1. + )); + + assert_eq!(network.total_connections, 7); + + assert!(network.add_connection( + Connection { + from: NeuronLocation::Hidden(1), + to: NeuronLocation::Hidden(2), + }, + 1. + )); + + assert!(!network.add_connection( + Connection { + from: NeuronLocation::Hidden(2), + to: NeuronLocation::Hidden(0), + }, + 1. 
+ )); +} + +#[test] +fn remove_connection() { + let mut network = small_test_network(); + + assert!(!network.remove_connection(Connection { + from: NeuronLocation::Hidden(0), + to: NeuronLocation::Output(0), + })); + + assert_eq!(network.total_connections, 5); + + assert!(network.remove_connection(Connection { + from: NeuronLocation::Input(0), + to: NeuronLocation::Hidden(1), + })); + + assert_eq!(network.total_connections, 3); + assert_eq!(network.hidden_layers.len(), 2); +} + +#[test] +fn random_location_in_scope() { + let mut rng = rand::rng(); + let mut network = small_test_network(); + + assert_eq!( + network.random_location_in_scope(&mut rng, NeuronScope::INPUT), + Some(NeuronLocation::Input(0)) + ); + + // TODO `assert_matches` when it is stable + assert!(matches!( + network.random_location_in_scope(&mut rng, NeuronScope::HIDDEN), + Some(NeuronLocation::Hidden(_)) + )); + + let multi = network.random_location_in_scope(&mut rng, !NeuronScope::INPUT); + assert!( + matches!(multi, Some(NeuronLocation::Hidden(_))) + || matches!(multi, Some(NeuronLocation::Output(_))) + ); + + network.hidden_layers = vec![]; + assert!(network + .random_location_in_scope(&mut rng, NeuronScope::HIDDEN) + .is_none()); +} + +#[test] +fn split_connection() { + let mut rng = rand::rng(); + let mut network = small_test_network(); + + network.split_connection( + Connection { + from: NeuronLocation::Input(0), + to: NeuronLocation::Hidden(1), + }, + &mut rng, + ); + + assert_eq!(network.total_connections, 7); + + let n = &network.hidden_layers[3]; + assert_eq!(n.outputs[0].0, NeuronLocation::Hidden(1)); + assert_eq!(n.input_count, 1); +} + +// TODO test every method
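+
+// a first step toward the TODO above: `random_connection` on a network that is
+// known to contain connections must return one of them.
+#[test]
+fn random_connection() {
+    let mut rng = rand::rng();
+    let network = small_test_network();
+
+    // `small_test_network` has 6 connections, all with weight 1.
+    let (connection, weight) = network.random_connection(&mut rng).unwrap();
+
+    assert_eq!(weight, 1.);
+    assert!(network
+        .get_neuron(connection.from)
+        .outputs
+        .contains(&(connection.to, weight)));
+}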