trainer.lua
require 'paths'
require 'optim'
local maker = require 'data_maker' -- not referenced elsewhere in this file
local tablex = require 'pl.tablex'

local Trainer = torch.class('Trainer')

function Trainer:__init(model)
  self.model = model
  -- expects flat parameter / gradient tensors, as required by the optim
  -- closures used in Trainer:train
  self.params, self.grad_params = self.model:parameters()
end
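
--[[
Sketch of the duck-typed interface this Trainer relies on, inferred from the
calls below (names are not defined in this file):
  model:parameters()      -> flat params tensor, flat gradParams tensor
  model:trainb(opt, ...)  -> optim-style closure feval(params) -> loss, gradParams
  model:evalb(opt, ...)   -> scalar loss for one batch
  model:save(path)        -> serialize the model to disk
`train` / `valid` are shard iterators: reset(), shuffle(), nshard(), next();
each shard supports cuda(), reverse(), nbatch(), lbatch(), nonzero(), and
shard[j] holds the arguments unpacked into trainb / evalb.
--]]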
-- Run one training epoch over all shards of `train` and return the average
-- per-batch loss.
function Trainer:train(epoch, train, opt)
  local timer = torch.Timer()
  local tIter = 0
  local tLoss = {}
  train:reset()
  -- keep the original shard order for the first opt.curriculum epochs,
  -- then shuffle
  if opt.curriculum < epoch then
    train:shuffle()
  end
  local nbOfshard = train:nshard()
  for i = 1, nbOfshard do
    local shard = train:next()
    if opt.cuda then shard:cuda() end
    if opt.reverse then shard:reverse() end
    local nbOfbatch = shard:nbatch()
    local lOfbatch = shard:lbatch()
    local nbOfnonzero = shard:nonzero()
    for j = 1, nbOfbatch do
      timer:reset()
      -- trainb returns an optim-style closure computing loss and gradients
      local feval = self.model:trainb(opt, unpack(shard[j]))
      local _, loss = optim[opt.optim](feval, self.params, self.optim_config)
      -- rescale the batch loss by its length over its number of nonzero
      -- (non-padding) entries
      tLoss[#tLoss + 1] = loss[1] * lOfbatch[j] / nbOfnonzero[j]
      tIter = tIter + 1
      if tIter % opt.nprint == 0 then
        print(string.format(
          '%3d/%d/%d/%d (epoch %d), err = %6.4e, grad = %6.4e, time = %.4fs',
          j, nbOfbatch, i, nbOfshard, epoch, tLoss[#tLoss],
          self.grad_params:norm() / self.params:norm(),
          timer:time().real
        ))
      end
    end
  end
  local loss = tablex.reduce('+', tLoss)
  loss = loss / tIter
  return loss
end
function Trainer:eval(epoch, valid, opt)
  local timer = torch.Timer()
  local vIter = 0
  local vLoss = {}
  valid:reset()
  local nbOfshard = valid:nshard()
  for i = 1, nbOfshard do
    local shard = valid:next()
    if opt.cuda then shard:cuda() end
    if opt.reverse then shard:reverse() end
    local nbOfbatch = shard:nbatch()
    local lOfbatch = shard:lbatch()
    local nbOfnonzero = shard:nonzero()
    for j = 1, nbOfbatch do
      timer:reset()
      local loss = self.model:evalb(opt, unpack(shard[j]))
      vLoss[#vLoss + 1] = loss * lOfbatch[j] / nbOfnonzero[j]
      vIter = vIter + 1
    end
  end
  local loss = tablex.reduce('+', vLoss)
  loss = loss / vIter
  return loss
end
-- Full training loop: for each epoch, train, evaluate, checkpoint the model,
-- and adjust the learning rate; finally save the loss history.
function Trainer:run(train, valid, opt)
  local tLosses = {}
  local vLosses = {}
  local shrink_factor = opt.shrink_factor
  local shrink_multiplier = opt.shrink_multiplier
  if not paths.dirp(opt.save) then
    os.execute('mkdir -p ' .. opt.save)
  end
  local timer = torch.Timer()
  self.optim_config = {learningRate = opt.learningRate}
  for i = 1, opt.nepoch do
    timer:reset()
    local lr = self.optim_config.learningRate
    local tLoss = self:train(i, train, opt)
    print(string.format(
      '=>[epoch %d] training loss = %6.4e, lr = %.4f, time = %.4fs',
      i, tLoss, lr, timer:time().real
    ))
    timer:reset()
    collectgarbage()
    local vLoss = self:eval(i, valid, opt)
    print(string.format(
      '=>[epoch %d] valid loss = %6.4e, time = %.4fs',
      i, vLoss, timer:time().real
    ))
    collectgarbage()
    -- checkpoint the model after every epoch
    local name = string.format(
      'model-%s-%s-epoch%.2f-t%.4e-v%.4e-%s.t7',
      opt.name, torch.type(self.model), i, tLoss, vLoss,
      os.date('%Y%m%d[%H%M]')
    )
    self.model:save(paths.concat(opt.save, name))
    -- shrink the SGD learning rate when the previous validation loss exceeds
    -- the current one by more than the shrink_multiplier factor
    if opt.optim == 'sgd' and #vLosses > 1 and
      vLosses[#vLosses] > vLoss * shrink_multiplier
    then
      lr = lr / shrink_factor
      lr = math.max(lr, opt.minLearningRate)
      self.optim_config.learningRate = lr
    end
    -- optional linear annealing of the learning rate after start_epoch
    if opt.anneal and i > opt.start_epoch then
      lr = lr - (opt.learningRate - opt.minLearningRate) / opt.saturate_epoch
      lr = math.max(lr, opt.minLearningRate)
      self.optim_config.learningRate = lr
    end
    tLosses[#tLosses + 1] = tLoss
    vLosses[#vLosses + 1] = vLoss
  end
  -- save the per-epoch training / validation loss history
  local name = string.format('loss-%s-%s-nepoch%.2f-%s.t7',
    opt.name, torch.type(self.model), opt.nepoch, os.date('%Y%m%d[%H%M]')
  )
  torch.save(paths.concat(opt.save, name), {tLosses, vLosses})
end
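
--[[
Minimal usage sketch (an assumption, not part of this file): `maker.make` is a
hypothetical constructor for the train/valid shard iterators, and the `opt`
field values below are illustrative placeholders for the options read by this
Trainer; the real names and defaults live elsewhere in the repository.

  require 'trainer'
  local maker = require 'data_maker'
  local opt = {
    cuda = false, reverse = false, curriculum = 1, nprint = 50,
    optim = 'sgd', learningRate = 1.0, minLearningRate = 1e-3,
    shrink_factor = 2, shrink_multiplier = 1.0,
    anneal = false, start_epoch = 5, saturate_epoch = 20,
    nepoch = 10, save = 'checkpoints', name = 'baseline',
  }
  local train, valid = maker.make(opt)  -- hypothetical
  local model = ...                     -- any object satisfying the interface sketched above
  local trainer = Trainer(model)
  trainer:run(train, valid, opt)
--]]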