----------------------------------------------------------------------
-- train.lua
-- original version taken from Clement
----------------------------------------------------------------------
require 'optim'   -- optimization package: optim.sgd, optim.ConfusionMatrix, optim.Logger
require 'xlua'    -- xlua.progress (progress bar)
require 'sys'     -- sys.clock (timing) and sys.dirname
require 'cutorch' -- CUDA backend for torch.CudaTensor
----------------------------------------------------------------------
-- Model + Loss:
local t = require 'model'
local model = t.model       -- the network
local loss = t.loss         -- the training criterion
local dropout = t.dropout   -- list of dropout modules, toggled on/off below
local lrs = t.lrs           -- per-parameter learning-rate multipliers
local wds = t.wds           -- per-parameter weight decays
----------------------------------------------------------------------
print '==> defining some tools'
-- classes
local classes = {'airplane', 'automobile', 'bird', 'cat',
                 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
-- This matrix records the current confusion across classes
local confusion = optim.ConfusionMatrix(classes)
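-- (the ConfusionMatrix accumulates prediction-vs-target counts per class;
-- its totalValid field, logged at the end of each epoch below, is the
-- overall fraction of correct predictions)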
-- Log results to files
local trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))
----------------------------------------------------------------------
print '==> flattening model parameters'
-- Retrieve parameters and gradients:
-- this extracts and flattens all the trainable parameters of the model
-- into a 1-dim vector
local w,dE_dw = model:getParameters()
local optimState = {
   learningRate = opt.learningRate,
   momentum = opt.momentum,
   dampening = 0,
   weightDecay = opt.weightDecay,
   learningRateDecay = opt.learningRateDecay,
   learningRates = lrs,
   weightDecays = wds
}
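-- Note: besides the scalar hyper-parameters, optim.sgd also accepts
-- `learningRates` and `weightDecays` as vectors with one entry per
-- flattened parameter; the learning-rate vector acts as a per-parameter
-- multiplier on the (decayed) global rate, which is how layer-specific
-- settings from model.lua take effect.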
print '==> allocating minibatch memory'
local x = torch.CudaTensor(opt.batchSize,3,32,32)
local yt = torch.CudaTensor(opt.batchSize)
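-- x and yt are allocated once and reused: each iteration below only copies
-- a minibatch into them, avoiding fresh GPU allocations on every step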
----------------------------------------------------------------------
print '==> defining training procedure'
local epoch
local function train(trainData)
   -- epoch tracker
   epoch = epoch or 1
   local time = sys.clock()
   -- shuffle the sample order at each epoch
   local shuffle = torch.randperm(trainData:size())
   -- do one epoch
   print('==> doing epoch on training data:')
   print("==> online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ']')
   for t = 1,trainData:size(),opt.batchSize do
      -- display progress
      xlua.progress(t, trainData:size())
      collectgarbage()
      -- dropout on (training mode)
      for _,d in ipairs(dropout) do
         d.train = true
      end
      -- skip the last, incomplete minibatch
      if (t + opt.batchSize - 1) > trainData:size() then
         break
      end
      -- copy the shuffled samples into the preallocated minibatch
      local idx = 1
      for i = t,t+opt.batchSize-1 do
         x[idx] = trainData.data[shuffle[i]]
         yt[idx] = trainData.labels[shuffle[i]]
         idx = idx + 1
      end
      -- create closure to evaluate E(w) and dE/dw on the current minibatch
      local eval_E = function(w)
         -- reset gradients
         dE_dw:zero()
         -- forward the complete minibatch
         local y = model:forward(x)
         -- backpropagate to estimate dE/dw
         local dE_dy = loss:backward(y,yt)
         model:backward(x,dE_dy)
         -- the loss value itself is never used by optim.sgd, so 0 is
         -- returned instead of computing loss:forward(y,yt)
         return 0,dE_dw
      end
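      -- optim.sgd expects a closure from parameters to (loss, gradient) and
      -- updates w in place, applying momentum, weight decay, learning-rate
      -- decay, and the per-parameter vectors configured in optimState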
      optim.sgd(eval_E, w, optimState)
      -- dropout off (evaluation mode), then re-run the forward pass to
      -- score clean predictions for the confusion matrix
      for _,d in ipairs(dropout) do
         d.train = false
      end
      local y = model:forward(x)
      for i = 1,opt.batchSize do
         confusion:add(y[i],yt[i])
      end
   end
   -- time taken
   time = sys.clock() - time
   time = time / trainData:size()
   print("\n==> time to learn 1 sample = " .. (time*1000) .. 'ms')
   -- print confusion matrix
   print(tostring(confusion))
   -- update logger/plot (totalValid is the overall training accuracy)
   trainLogger:add{['% overall accuracy (train set)'] = confusion.totalValid * 100}
   -- save/log current net
   local filename = paths.concat(opt.save, 'model.net')
   os.execute('mkdir -p ' .. sys.dirname(filename))
   print('==> saving model to '..filename)
   torch.save(filename, model)
   -- next epoch
   confusion:zero()
   epoch = epoch + 1
end
-- Export:
return train
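----------------------------------------------------------------------
-- Example driver (a sketch, not part of the original): it assumes a
-- global `opt` with the fields used above and a `trainData` object
-- exposing :size(), .data and .labels, e.g. as built by a companion
-- data-loading script:
--
--   opt = {save = 'results', batchSize = 128, learningRate = 1e-2,
--          momentum = 0.9, weightDecay = 5e-4, learningRateDecay = 1e-7}
--   local train = require 'train'
--   local trainData = ... -- load the CIFAR-10 training set here
--   for i = 1,100 do
--      train(trainData)
--   end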