-- train.lua - training loop for the grayscale-to-RGB (colorization) model.

require 'cutorch'
require 'image'
require 'optim'

local imageLoader = require('imageLoader')
local torchUtil = require('torchUtil')

-- Batches at which to dump the training graph for debugging, e.g.:
-- local debugBatchIndices = {[1]=true, [100]=true, [200]=true}
-- local debugBatchIndices = {[5]=true}
local debugBatchIndices = {}

-- Reused optimization state (for adam/sgd); the learning rate is set
-- per-epoch by paramsForEpoch below.
local optimState = {
    learningRate = 0.0
}
-- Learning rate annealing schedule. A fresh optimizer state is built at the
-- start of each regime (see train below).
--
-- Return values:
--   settings to apply to optimState,
--   true iff this is the first epoch of a new regime
local function paramsForEpoch(epoch)
    local regimes = {
        -- start, end,  LR,   WD,
        {  1,      1,   1e-3, 0 },
        {  2,      2,   1e-3, 0 },
        {  3,      3,   5e-4, 0 },
        {  4,     10,   4e-5, 0 },
        { 11,     20,   2e-5, 0 },
        { 21,     30,   1e-5, 0 },
        { 31,     40,   5e-6, 0 },
        { 41,     1e8,  1e-6, 0 },
    }
    for _, row in ipairs(regimes) do
        if epoch >= row[1] and epoch <= row[2] then
            return { learningRate = row[3], weightDecay = row[4] }, epoch == row[1]
        end
    end
end
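-- For example, paramsForEpoch(15) falls in the { 11, 20, 2e-5, 0 } row, so it
-- returns { learningRate = 2e-5, weightDecay = 0 }, false; paramsForEpoch(11)
-- returns the same settings with true, since epoch 11 opens that regime.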
-- Logging state
local trainLogger = nil
local batchNumber -- Current batch in current epoch
local totalBatchCount = 0 -- Total # of batches across all epochs
local epochStats = {}
-- GPU inputs (preallocate)
local grayscaleInputs = torch.CudaTensor()
local RGBTargets = torch.CudaTensor()
local classLabels = torch.CudaTensor()
local randomness = torch.CudaTensor()
local timer = torch.Timer()
local dataTimer = torch.Timer()
-- trainSuperBatch - used by train() to train a single superbatch
-- (opt.superBatches batches accumulated into one optimizer step).
local function trainSuperBatch(model, imgLoader, opt, epoch)
    local parameters, gradParameters = model.trainingNet:getParameters()

    cutorch.synchronize()

    local dataLoadingTime = 0
    timer:reset()

    local classLossSum, pixelRGBLossSum, contentLossSum, kldLossSum, totalLossSum = 0, 0, 0, 0, 0
    local top1, top5 = 0, 0

    local feval = function(x)
        model.trainingNet:zeroGradParameters()

        for superBatch = 1, opt.superBatches do
            local loadTimeStart = dataTimer:time().real
            local batch = imageLoader.sampleBatch(imgLoader)
            local loadTimeEnd = dataTimer:time().real
            dataLoadingTime = dataLoadingTime + (loadTimeEnd - loadTimeStart)

            local randomnessCPU
            if opt.useRandomness then
                randomnessCPU = torch.randn(opt.batchSize, 512, 28, 28)
                -- [Fully-connected bottleneck version]
                -- randomnessCPU = torch.randn(opt.batchSize, 4704)
            else
                -- filling with ones (rather than zeros) lets the network
                -- make some use of the sigma terms
                randomnessCPU = torch.FloatTensor(opt.batchSize, 512, 28, 28):zero():add(1.0)
            end
            -- transfer over to GPU
            grayscaleInputs:resize(batch.grayscaleInputs:size()):copy(batch.grayscaleInputs)
            RGBTargets:resize(batch.RGBTargets:size()):copy(batch.RGBTargets)
            classLabels:resize(batch.classLabels:size()):copy(batch.classLabels)
            randomness:resize(randomnessCPU:size()):copy(randomnessCPU)

            local contentTargets = model.vggNet:forward(RGBTargets):clone()

            local outputLoss = model.trainingNet:forward({grayscaleInputs, randomness, RGBTargets, contentTargets, classLabels})

            local classLoss = outputLoss[1][1]
            local pixelRGBLoss = outputLoss[2][1]
            local contentLoss = outputLoss[3][1]
            local kldLoss = outputLoss[4][1]

            classLossSum = classLossSum + classLoss
            pixelRGBLossSum = pixelRGBLossSum + pixelRGBLoss
            contentLossSum = contentLossSum + contentLoss
            kldLossSum = kldLossSum + kldLoss
            totalLossSum = totalLossSum + classLoss + pixelRGBLoss + contentLoss + kldLoss

            -- Check for NaNs (NaN is the only value that does not equal itself)
            assert(totalLossSum == totalLossSum, 'NaN in loss!')

            local classProbabilities = model.classProbabilities.data.module.output

            model.trainingNet:backward({grayscaleInputs, randomness, RGBTargets, contentTargets, classLabels}, outputLoss)

            if superBatch == 1 then
                if debugBatchIndices[totalBatchCount] then
                    torchUtil.dumpGraph(model.trainingNet, opt.outDir .. 'graphDump' .. totalBatchCount .. '.csv')
                end

                do
                    local _, predictions = classProbabilities:float():sort(2, true) -- descending
                    for b = 1, opt.batchSize do
                        --print(predictions[b][1] .. ' vs ' .. classLabelsCPU[b][1])
                        if predictions[b][1] == batch.classLabels[b] then
                            top1 = top1 + 1
                        end
                        if predictions[b][1] == batch.classLabels[b] or
                           predictions[b][2] == batch.classLabels[b] or
                           predictions[b][3] == batch.classLabels[b] or
                           predictions[b][4] == batch.classLabels[b] or
                           predictions[b][5] == batch.classLabels[b] then
                            top5 = top5 + 1
                        end
                    end
                    top1 = top1 * 100 / opt.batchSize
                    top5 = top5 * 100 / opt.batchSize
                end
            end

            -- Output test samples
            if superBatch == 1 and totalBatchCount % 100 == 0 then
                -- Copy image #1 into the entire batch for grayscaleInputs.
                -- This lets us output N random samples from the network, where N = batch size.
                for batchIndex = 2, opt.batchSize do
                    grayscaleInputs[batchIndex]:copy(grayscaleInputs[1])
                end

                -- Save the ground truth RGB image
                local inClone = RGBTargets[1]:clone()
                inClone = torchUtil.caffeDeprocess(inClone)
                image.save(opt.outDir .. 'samples/iter' .. totalBatchCount .. '_groundTruth.jpg', inClone)

                -- Save predicted images
                local prediction = model.predictionNet:forward({grayscaleInputs, randomness})
                for testSampleIndex = 1, opt.numTestSamples do
                    local predictionRGB = torchUtil.caffeDeprocess(prediction[testSampleIndex]:clone())
                    image.save(opt.outDir .. 'samples/iter' .. totalBatchCount .. '_sample' .. testSampleIndex .. '_predictedRGBDebug.jpg', predictionRGB)

                    predictionRGB = torchUtil.predictionCorrectedRGB(grayscaleInputs[1], predictionRGB)
                    image.save(opt.outDir .. 'samples/iter' .. totalBatchCount .. '_sample' .. testSampleIndex .. '_predictedRGB.jpg', predictionRGB)
                end
            end
        end

        model.vggNet:zeroGradParameters()
        return totalLossSum, gradParameters
    end
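    -- feval follows the optim closure contract: it returns the loss and the
    -- gradient with respect to the flattened parameters, so optim.adam can
    -- take a single step using the current optimState (whose learning rate
    -- and weight decay come from paramsForEpoch).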
    optim.adam(feval, parameters, optimState)

    cutorch.synchronize()
    batchNumber = batchNumber + 1

    epochStats.total = epochStats.total + totalLossSum
    epochStats.class = epochStats.class + classLossSum
    epochStats.pixelRGB = epochStats.pixelRGB + pixelRGBLossSum
    epochStats.content = epochStats.content + contentLossSum
    epochStats.kld = epochStats.kld + kldLossSum
    epochStats.top1Accuracy = top1
    epochStats.top5Accuracy = top5

    print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format(
        epoch, batchNumber, opt.epochSize, timer:time().real, totalLossSum,
        optimState.learningRate, dataLoadingTime))
    print(string.format(' Top 1 accuracy: %f%%', top1))
    print(string.format(' Top 5 accuracy: %f%%', top5))
    print(string.format(' Class loss: %f', classLossSum))
    print(string.format(' RGB loss: %f', pixelRGBLossSum))
    print(string.format(' Content loss: %f', contentLossSum))
    print(string.format(' KLD loss: %f', kldLossSum))

    dataTimer:reset()
    totalBatchCount = totalBatchCount + 1
end
-- trainSuperBatchClassifier - used by train() to train a classifier-only superbatch.
local function trainSuperBatchClassifier(model, imgLoader, opt, epoch)
    local parameters, gradParameters = model.classifierTrainingNet:getParameters()

    cutorch.synchronize()

    local dataLoadingTime = 0
    timer:reset()

    local classLossSum = 0
    local top1, top5 = 0, 0

    local feval = function(x)
        model.classifierTrainingNet:zeroGradParameters()

        for superBatch = 1, opt.superBatches do
            local loadTimeStart = dataTimer:time().real
            local batch = imageLoader.sampleBatch(imgLoader)
            local loadTimeEnd = dataTimer:time().real
            dataLoadingTime = dataLoadingTime + (loadTimeEnd - loadTimeStart)

            -- transfer over to GPU
            grayscaleInputs:resize(batch.grayscaleInputs:size()):copy(batch.grayscaleInputs)
            classLabels:resize(batch.classLabels:size()):copy(batch.classLabels)

            local outputLoss = model.classifierTrainingNet:forward({grayscaleInputs, classLabels})
            local classLoss = outputLoss[1]
            classLossSum = classLossSum + classLoss

            local classProbabilities = model.classifierClassProbabilities.data.module.output

            model.classifierTrainingNet:backward({grayscaleInputs, classLabels}, outputLoss)

            if superBatch == 1 then
                if debugBatchIndices[totalBatchCount] then
                    torchUtil.dumpGraph(model.classifierTrainingNet, opt.outDir .. 'graphDump' .. totalBatchCount .. '.csv')
                end

                do
                    local _, predictions = classProbabilities:float():sort(2, true) -- descending
                    for b = 1, opt.batchSize do
                        --print(predictions[b][1] .. ' vs ' .. classLabelsCPU[b][1])
                        if predictions[b][1] == batch.classLabels[b] then
                            top1 = top1 + 1
                        end
                        if predictions[b][1] == batch.classLabels[b] or
                           predictions[b][2] == batch.classLabels[b] or
                           predictions[b][3] == batch.classLabels[b] or
                           predictions[b][4] == batch.classLabels[b] or
                           predictions[b][5] == batch.classLabels[b] then
                            top5 = top5 + 1
                        end
                    end
                    top1 = top1 * 100 / opt.batchSize
                    top5 = top5 * 100 / opt.batchSize
                end
            end
        end

        return classLossSum, gradParameters
    end
    optim.adam(feval, parameters, optimState)

    cutorch.synchronize()
    batchNumber = batchNumber + 1

    epochStats.class = epochStats.class + classLossSum
    epochStats.top1Accuracy = top1
    epochStats.top5Accuracy = top5

    print(('Epoch: [%d][%d/%d]\tTime %.3f Err %.4f LR %.0e DataLoadingTime %.3f'):format(
        epoch, batchNumber, opt.epochSize, timer:time().real, classLossSum,
        optimState.learningRate, dataLoadingTime))
    print(string.format(' Top 1 accuracy: %f%%', top1))
    print(string.format(' Top 5 accuracy: %f%%', top5))
    print(string.format(' Class loss: %f', classLossSum))

    dataTimer:reset()
    totalBatchCount = totalBatchCount + 1
end
-------------------------------------------------------------------------------------------
-- train - this function handles the high-level training loop,
-- i.e. load data, train the model, and save the model and state to disk
local function train(model, imgLoader, opt, epoch)
    -- Initialize logging
    if trainLogger == nil then
        trainLogger = optim.Logger(paths.concat(opt.outDir, 'train.log'))
    end

    batchNumber = 0

    -- Save the model. Ideally this would happen at the end of the epoch,
    -- but saving has been fragile, so it happens first instead.
    collectgarbage()

    -- Clear the intermediate states in the model before saving to disk;
    -- this saves a lot of disk space.
    --model.trainingNet:clearState()
    model.encoder:clearState()
    model.decoder:clearState()
    model.classifier:clearState()
    model.vggNet:clearState()
    torch.save(opt.outDir .. 'models/transform' .. epoch .. '.t7', model.trainingNet)
    print('==> doing epoch on training data:')
    print("==> online epoch # " .. epoch)

    local params, newRegime = paramsForEpoch(epoch)
    if newRegime then
        optimState = {
            learningRate = params.learningRate,
            weightDecay = params.weightDecay
        }
    end
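    -- Note: rebuilding optimState at a regime boundary also discards adam's
    -- internal moment estimates (optim.adam stores them in this same table),
    -- so the optimizer effectively restarts at each new regime.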
    cutorch.synchronize()

    -- set the dropouts to training mode
    model.classifierTrainingNet:training()
    model.trainingNet:training()

    local tm = torch.Timer()

    epochStats.total = 0
    epochStats.class = 0
    epochStats.pixelRGB = 0
    epochStats.content = 0
    epochStats.kld = 0
    for i = 1, opt.epochSize do
        if opt.classifierOnly then
            trainSuperBatchClassifier(model, imgLoader, opt, epoch)
        else
            trainSuperBatch(model, imgLoader, opt, epoch)
        end
    end

    cutorch.synchronize()

    local scaleFactor = 1.0 / (opt.batchSize * opt.superBatches * opt.epochSize)
    epochStats.total = epochStats.total * scaleFactor
    epochStats.class = epochStats.class * scaleFactor
    epochStats.pixelRGB = epochStats.pixelRGB * scaleFactor
    epochStats.content = epochStats.content * scaleFactor
    epochStats.kld = epochStats.kld * scaleFactor
    trainLogger:add{
        ['total loss (train set)'] = epochStats.total,
        ['class loss (train set)'] = epochStats.class,
        ['RGB loss (train set)'] = epochStats.pixelRGB,
        ['content loss (train set)'] = epochStats.content,
        ['KLD loss (train set)'] = epochStats.kld,
    }
    print(string.format('Epoch: [%d][TRAINING SUMMARY] Total Time(s): %.2f\t'
        .. 'average loss (per batch): %.2f \t '
        .. 'accuracy(%%):\t top-1 %.2f\t',
        epoch, tm:time().real, epochStats.total, epochStats.top1Accuracy))
    print('\n')
end
-------------------------------------------------------------------------------------------
return train
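
-- A minimal driver sketch (hypothetical; the actual entry point, option
-- parsing, and model construction live elsewhere in this repo):
--
--   local train = require('train')
--   for epoch = 1, 50 do
--       train(model, imgLoader, opt, epoch)
--   end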