diff --git a/examples/image_classification/mnist/src/main.rs b/examples/image_classification/mnist/src/main.rs
index 1111673..ba7e310 100644
--- a/examples/image_classification/mnist/src/main.rs
+++ b/examples/image_classification/mnist/src/main.rs
@@ -6,15 +6,18 @@
 use deltaml::data::MnistDataset;
 use deltaml::losses::SparseCategoricalCrossEntropyLoss;
 use deltaml::neuralnet::Sequential;
 use deltaml::neuralnet::{Dense, Flatten};
+use deltaml::neuralnet::layers::{Conv2D, MaxPooling2D};
 use deltaml::optimizers::Adam;
 
 #[tokio::main]
 async fn main() {
     // Create a neural network
     let mut model = Sequential::new()
-        .add(Flatten::new(Shape::new(vec![28, 28]))) // Input: 28x28, Output: 784
-        .add(Dense::new(128, Some(ReluActivation::new()), true)) // Input: 784, Output: 128
-        .add(Dense::new(10, None::<ReluActivation>, false)); // Output: 10 classes
+        .add(Conv2D::new(32, 3, 1, 1, Some(Box::new(ReluActivation::new())), true)) // Conv2D layer with 32 filters, 3x3 kernel
+        .add(MaxPooling2D::new(2, 2)) // MaxPooling2D layer with 2x2 pool size
+        .add(Flatten::new(Shape::new(vec![14, 14, 32]))) // Flatten layer: 14x14x32 -> 6272
+        .add(Dense::new(128, Some(ReluActivation::new()), true)) // Dense layer with 128 units
+        .add(Dense::new(10, None::<ReluActivation>, false)); // Output layer with 10 classes
     // Display the model summary
     model.summary();
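
The Flatten shape above follows from standard convolution/pooling arithmetic, assuming the two integer arguments after the kernel size in `Conv2D::new` are stride 1 and padding 1 (so the 28x28 input keeps its spatial size) and that `MaxPooling2D::new(2, 2)` is a 2x2 pool with stride 2 (halving each spatial dimension). This is a minimal, library-independent sketch of that arithmetic, not deltaml's internal shape inference:

```rust
// Output size of a convolution/pooling window along one spatial dimension.
fn conv_out(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

fn pool_out(input: usize, pool: usize, stride: usize) -> usize {
    (input - pool) / stride + 1
}

fn main() {
    // 3x3 kernel, stride 1, padding 1 preserves the 28x28 MNIST input.
    let after_conv = conv_out(28, 3, 1, 1);
    // 2x2 max pooling with stride 2 halves each spatial dimension.
    let after_pool = pool_out(after_conv, 2, 2);
    assert_eq!(after_conv, 28);
    assert_eq!(after_pool, 14);
    // Flatten therefore sees 14x14x32 = 6272 features.
    println!("Flatten input: {}x{}x32 = {}", after_pool, after_pool, after_pool * after_pool * 32);
}
```

If the convolution or pooling parameters differ from these assumptions, the vector passed to `Flatten::new(Shape::new(...))` has to be adjusted to match the pooled feature-map shape.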