-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.cpp
66 lines (59 loc) · 2.05 KB
/
main.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// Train a neural network classifier on the UCI Wine dataset.
//
// Pipeline: load CSV -> one-hot the labels -> shuffled train/validation
// split -> min-max feature scaling -> train a small dropout-regularized
// network with mini-batch SGD (linearly decaying learning rate, momentum
// raised halfway through training) -> report validation accuracy.
#include <Eigen/Core>
#include <algorithm>  // std::max
#include <cstdlib>    // srand
#include <ctime>      // time
#include <iostream>
#include <string>
#include "nn.h"
#include "optimizers.h"
#include "utils.h"

int main()
{
    // First column is the class label; the remaining columns are the
    // continuous features. https://archive.ics.uci.edu/ml/datasets/wine
    std::string path = "wine.data";
    Eigen::MatrixXd csv = LoadCSV<Eigen::MatrixXd>(path);

    // Shuffled train/validation split, 30% held out for validation.
    Eigen::MatrixXd x_train;
    Eigen::MatrixXd x_val;
    Eigen::MatrixXd y_train;
    Eigen::MatrixXd y_val;
    TrainTestSplit(csv.rightCols(csv.cols() - 1), OneHot(csv.leftCols(1).cast<int>()),
                   x_train, x_val, y_train, y_val, 0.3);
    std::cout << "Train set size: " << y_train.rows() << std::endl;
    std::cout << " Val set size: " << y_val.rows() << std::endl;
    std::cout << std::endl;

    // Scale features to [0, 1]. Fit on the training set only so validation
    // statistics do not leak into training.
    MinMaxScaler scaler(0, 1);
    scaler.fit(x_train);
    scaler.transform(x_train);
    scaler.transform(x_val);

    // Define the network: one hidden layer of 10 units with 50% dropout,
    // softmax output sized to the number of classes. Seed rand() once for
    // whatever randomized initialization the layers perform.
    srand(static_cast<unsigned>(time(nullptr)));
    Hidden h1(x_train.cols(), 10);
    Dropout d1(0.5);
    Softmax softmax(10, y_train.cols());
    NeuralNet net(&h1, &d1, &softmax);

    // Train the network. steps_per_epoch rounds up so a partial final
    // batch still counts toward completing an epoch.
    size_t batch_size = 8;
    size_t epochs = 15;
    size_t steps_per_epoch = (y_train.rows() + batch_size - 1) / batch_size;
    size_t epoch = 0;
    Eigen::MatrixXd x_batch;
    Eigen::MatrixXd y_batch;
    SGD sgd(net, 0.04, 0.5);  // learning rate 0.04, momentum 0.5
    Batcher batcher(batch_size, x_train, y_train);
    for (size_t i = 1; i <= epochs * steps_per_epoch; ++i)
    {
        batcher.batch(x_batch, y_batch);
        sgd.fit(x_batch, y_batch);
        if (!(i % steps_per_epoch))  // epoch boundary
        {
            epoch += 1;
            PrintTrainingMetrics(net, epoch, x_train, x_val, y_train, y_val);
            // Linear learning-rate decay, floored at 0.001.
            sgd.lr(std::max(sgd.lr() - 0.002, 0.001));
            // Switch to heavier momentum halfway through training.
            if (epoch == epochs / 2)
            {
                sgd.momentum(0.9);
            }
        }
    }

    std::cout << "\nFinal Validation Accuracy: ";
    std::cout << Accuracy(Predict(y_val), Predict(net.probs(x_val)));
    std::cout << std::endl;
}