This repository has been archived by the owner on Mar 12, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 83
Example MNIST Training
Deepak Kumar Battini edited this page Nov 13, 2017
·
4 revisions
using SiaNet;
using SiaNet.Common;
using SiaNet.Model;
using SiaNet.Model.Layers;
using SiaNet.Model.Optimizers;
// Input geometry handed to the first Conv2D layer as a 3-tuple
// (presumably width, height, channels for MNIST — TODO confirm order).
var imageDim = new[] { 28, 28, 1 };
// Output width of the final Dense layer: one unit per digit class.
var numClasses = 10;

// Run on the CPU and forward library log messages to the console handler below.
GlobalParameters.Device = CNTK.DeviceDescriptor.CPUDevice;
Logging.OnWriteLog += Logging_OnWriteLog;
Download the sample MNIST dataset. (This step is optional if the dataset has already been downloaded; the library checks for the files and downloads them if they are missing.)
// Make sure the MNIST sample is available locally, then resolve its file locations.
var mnistSample = SampleDataset.MNIST;
Downloader.DownloadSample(mnistSample);
var samplePath = Downloader.GetSamplePath(mnistSample);

// Text-backed data generators: the training split and the held-out test split.
var train = ImageDataGenerator.FlowFromText(samplePath.Train);
var validation = ImageDataGenerator.FlowFromText(samplePath.Test);
// Build a small sequential CNN: two Conv2D + MaxPool2D stages and a Dense output layer.
Sequential model = new Sequential();

// Progress callbacks; each handler writes to the console.
model.OnEpochEnd += Model_OnEpochEnd;
model.OnTrainingEnd += Model_OnTrainingEnd;
model.OnBatchEnd += Model_OnBatchEnd;

// Only the first layer receives imageDim (presumably the input shape — TODO confirm).
// NOTE(review): both Conv2D layers use OptActivations.None, so the only non-linearity
// comes from max pooling; consider a ReLU-style activation if accuracy is poor.
model.Add(new Conv2D(Tuple.Create(imageDim[0], imageDim[1], imageDim[2]), 4, Tuple.Create(3, 3), Tuple.Create(2, 2), activation: OptActivations.None, weightInitializer: OptInitializers.Xavier, useBias: true, biasInitializer: OptInitializers.Ones));
model.Add(new MaxPool2D(Tuple.Create(3, 3)));
model.Add(new Conv2D(8, Tuple.Create(3, 3), Tuple.Create(2, 2), activation: OptActivations.None, weightInitializer: OptInitializers.Xavier));
model.Add(new MaxPool2D(Tuple.Create(3, 3)));
model.Add(new Dense(numClasses));

model.Compile(new SGD(0.003125), OptLosses.CrossEntropy, OptMetrics.Accuracy);

// Pass the test-split generator as the validation data. It was previously built and
// then discarded because null was passed here.
model.Train(train, 10, 64, validation);
/// <summary>
/// Handler for the model's training-end event: prints a completion message.
/// </summary>
/// <param name="trainingResult">
/// Results reported by the model (presumably per-metric history keyed by name); not used here.
/// </param>
private static void Model_OnTrainingEnd(Dictionary<string, List<double>> trainingResult) =>
    Console.Out.WriteLine("Training completed.");
/// <summary>
/// Handler for the model's epoch-end event: prints the epoch, its loss and the first metric value.
/// </summary>
/// <param name="epoch">Epoch index reported by the model.</param>
/// <param name="samplesSeen">Sample count reported by the model; not used here.</param>
/// <param name="loss">Loss value reported for the epoch.</param>
/// <param name="metrics">Metric values by name; the first entry is printed under the "Accuracy" label.</param>
/// <exception cref="InvalidOperationException">Thrown by <c>First()</c> if <paramref name="metrics"/> is empty.</exception>
private static void Model_OnEpochEnd(int epoch, uint samplesSeen, double loss, Dictionary<string, double> metrics)
{
    // The model is compiled with OptMetrics.Accuracy only, so the first entry is presumed to be accuracy.
    double firstMetric = metrics.First().Value;
    Console.WriteLine($"Epoch: {epoch}, Loss: {loss}, Accuracy: {firstMetric}");
}
/// <summary>
/// Handler for the model's batch-end event: prints progress for every 20th batch number.
/// </summary>
/// <param name="epoch">Epoch index reported by the model.</param>
/// <param name="batchNumber">Batch index; output is produced only when it is a multiple of 20.</param>
/// <param name="samplesSeen">Sample count reported by the model; not used here.</param>
/// <param name="loss">Loss value reported for the batch.</param>
/// <param name="metrics">Metric values by name; the first entry is printed under the "Accuracy" label.</param>
private static void Model_OnBatchEnd(int epoch, int batchNumber, uint samplesSeen, double loss, Dictionary<string, double> metrics)
{
    // Throttle console output; skipped batches never touch metrics.
    bool shouldReport = batchNumber % 20 == 0;
    if (!shouldReport)
    {
        return;
    }

    double firstMetric = metrics.First().Value;
    Console.WriteLine($"Epoch: {epoch}, Batch: {batchNumber}, Loss: {loss}, Accuracy: {firstMetric}");
}
/// <summary>
/// Handler for <c>Logging.OnWriteLog</c>: echoes a library log message to the console with a "Log: " prefix.
/// </summary>
/// <param name="message">The log text; a null message prints just the prefix.</param>
private static void Logging_OnWriteLog(string message)
{
    Console.WriteLine("Log: " + message);
}