Skip to content

Commit

Permalink
Updating MNIST example project
Browse files Browse the repository at this point in the history
  • Loading branch information
mjovanc committed Dec 6, 2024
1 parent ff2b6c6 commit 1d4331f
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions examples/image_classification/mnist/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ use deltaml::data::MnistDataset;
use deltaml::losses::SparseCategoricalCrossEntropyLoss;
use deltaml::neuralnet::Sequential;
use deltaml::neuralnet::{Dense, Flatten};
use deltaml::neuralnet::layers::{Conv2D, MaxPooling2D};
use deltaml::optimizers::Adam;

#[tokio::main]
async fn main() {
// Create a neural network
let mut model = Sequential::new()
.add(Flatten::new(Shape::new(vec![28, 28]))) // Input: 28x28, Output: 784
.add(Dense::new(128, Some(ReluActivation::new()), true)) // Input: 784, Output: 128
.add(Dense::new(10, None::<SoftmaxActivation>, false)); // Output: 10 classes
.add(Conv2D::new(32, 3, 1, 1, Some(Box::new(ReluActivation::new())), true)) // Conv2D layer with 32 filters, kernel size 3x3
.add(MaxPooling2D::new(2, 2)) // MaxPooling2D layer with pool size 2x2
.add(Flatten::new(Shape::new(vec![28, 28, 32]))) // Flatten layer
.add(Dense::new(128, Some(ReluActivation::new()), true)) // Dense layer with 128 units
.add(Dense::new(10, None::<SoftmaxActivation>, false)); // Output layer with 10 classes

// Display the model summary
model.summary();
Expand Down

0 comments on commit 1d4331f

Please sign in to comment.