Skip to content
This repository has been archived by the owner on Mar 12, 2020. It is now read-only.

Example MNIST Training

Deepak Kumar Battini edited this page Nov 13, 2017 · 4 revisions

Import the namespaces for this example

using System;
using System.Collections.Generic;
using System.Linq;
using SiaNet;
using SiaNet.Common;
using SiaNet.Model;
using SiaNet.Model.Layers;
using SiaNet.Model.Optimizers;

Initialize variables

// Input shape handed to the first Conv2D layer (presumably 28x28 pixels, 1 channel for MNIST — confirm axis order).
var imageDim = new[] { 28, 28, 1 };
// Number of output classes produced by the final Dense layer.
int numClasses = 10;

Select the compute device (CPU or GPU)

// Run on the CPU. To train on a GPU, assign a GPU descriptor instead
// (e.g. CNTK.DeviceDescriptor.GPUDevice(0) — confirm against the CNTK version in use).
var device = CNTK.DeviceDescriptor.CPUDevice;
GlobalParameters.Device = device;

Configure logging (Optional step)

// Forward every library log message to the console handler defined with the event functions.
Logging.OnWriteLog += message => Logging_OnWriteLog(message);

Download the sample MNIST dataset (optional if it has already been downloaded; the library checks for the files and downloads them if they are missing)

// Fetch the bundled MNIST sample data.
var mnistSample = SampleDataset.MNIST;
Downloader.DownloadSample(mnistSample);

Get the downloaded sample path

// Resolve where the MNIST sample lives; its Train and Test paths are used to load the datasets.
var samplePath = Downloader.GetSamplePath(SampleDataset.MNIST);

Load the training and validation datasets

// Training data comes from the sample's Train file; validation data comes from its Test file.
// NOTE(review): FlowFromText presumably reads a text-encoded (CNTK text format?) file — confirm.
var train = ImageDataGenerator.FlowFromText(samplePath.Train);
var validation = ImageDataGenerator.FlowFromText(samplePath.Test);

Create a new instance of the sequential model

// A Sequential model stacks layers in the order they are added.
var model = new Sequential();

Configure the training callbacks

// Hook the progress handlers, listed from most to least frequent event.
model.OnBatchEnd += Model_OnBatchEnd;
model.OnEpochEnd += Model_OnEpochEnd;
model.OnTrainingEnd += Model_OnTrainingEnd;

Build the model by stacking various neural network layers

// Only the first layer is given the input shape; later layers infer theirs from the previous layer.
var inputShape = Tuple.Create(imageDim[0], imageDim[1], imageDim[2]);

// Positional args presumably are: filters, kernel size, strides — confirm against SiaNet's Conv2D signature.
// NOTE(review): both convolutions use OptActivations.None, so no activation function is applied after them.
model.Add(new Conv2D(
    inputShape,
    4,
    Tuple.Create(3, 3),
    Tuple.Create(2, 2),
    activation: OptActivations.None,
    weightInitializer: OptInitializers.Xavier,
    useBias: true,
    biasInitializer: OptInitializers.Ones));
model.Add(new MaxPool2D(Tuple.Create(3, 3)));

model.Add(new Conv2D(
    8,
    Tuple.Create(3, 3),
    Tuple.Create(2, 2),
    activation: OptActivations.None,
    weightInitializer: OptInitializers.Xavier));
model.Add(new MaxPool2D(Tuple.Create(3, 3)));

// One output unit per class.
model.Add(new Dense(numClasses));

Compile and train the model

// SGD optimizer (argument presumably the learning rate), cross-entropy loss, accuracy as the reported metric.
model.Compile(new SGD(0.003125), OptLosses.CrossEntropy, OptMetrics.Accuracy);

// Pass the held-out validation set that was loaded from the sample's Test file;
// the original example loaded it but handed null to Train, so it was never used.
const int epochs = 10;
const int batchSize = 64;
model.Train(train, epochs, batchSize, validation);

Event handler functions

/// <summary>Invoked once training finishes; prints a completion message.</summary>
/// <param name="trainingResult">Per-metric history supplied by the model (not used here).</param>
private static void Model_OnTrainingEnd(Dictionary<string, List<double>> trainingResult) =>
    Console.WriteLine("Training completed.");
/// <summary>Invoked after each epoch; prints the epoch number, loss and every reported metric.</summary>
/// <param name="epoch">Epoch number supplied by the model.</param>
/// <param name="samplesSeen">Samples processed so far (not used here).</param>
/// <param name="loss">Loss value reported for the epoch.</param>
/// <param name="metrics">Metric values keyed by metric name; may be empty.</param>
private static void Model_OnEpochEnd(int epoch, uint samplesSeen, double loss, Dictionary<string, double> metrics)
{
    // Label each metric with its own key. Dictionary enumeration order is unspecified, so the old
    // metrics.First() labelled "Accuracy" could name the wrong metric, and First() throws when empty.
    var metricText = string.Concat(
        (metrics ?? new Dictionary<string, double>()).Select(m => $", {m.Key}: {m.Value}"));
    Console.WriteLine($"Epoch: {epoch}, Loss: {loss}{metricText}");
}
/// <summary>Invoked after each batch; prints progress for every 20th batch number.</summary>
/// <param name="epoch">Epoch number supplied by the model.</param>
/// <param name="batchNumber">Batch number within the epoch.</param>
/// <param name="samplesSeen">Samples processed so far (not used here).</param>
/// <param name="loss">Loss value reported for the batch.</param>
/// <param name="metrics">Metric values keyed by metric name; may be empty.</param>
private static void Model_OnBatchEnd(int epoch, int batchNumber, uint samplesSeen, double loss, Dictionary<string, double> metrics)
{
    // Throttle console output to one line per 20 batches.
    if (batchNumber % 20 != 0)
    {
        return;
    }

    // Label each metric with its own key rather than assuming the first entry is accuracy;
    // this also avoids First() throwing when no metrics are reported.
    var metricText = string.Concat(
        (metrics ?? new Dictionary<string, double>()).Select(m => $", {m.Key}: {m.Value}"));
    Console.WriteLine($"Epoch: {epoch}, Batch: {batchNumber}, Loss: {loss}{metricText}");
}
/// <summary>Echoes a library log message to the console with a "Log: " prefix.</summary>
/// <param name="message">The message raised by the logging event.</param>
private static void Logging_OnWriteLog(string message) => Console.WriteLine($"Log: {message}");
Clone this wiki locally