-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils_bg.lua
201 lines (177 loc) · 5.46 KB
/
utils_bg.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
function stringToByteTensor(str, vecLength)
-- Pack one or more strings into a fixed-width 2D tensor of ASCII byte codes.
-- str: string, table of strings, or tensor of ascii codes
-- vecLength: vector length
-- output: 2D ASCII byte tensor
local strByteTensor = torch.ByteTensor()
-- normalize a single string into a one-element table so the C routine
-- always receives a table of strings
if type(str) == 'string' then
str = {str}
end
-- NOTE(review): conversion happens in a C extension reached through the
-- IntTensor-typed 'nn' dispatch table; it fills strByteTensor in place.
torch.IntTensor().nn.Utils_stringToByteTensor(str, vecLength, strByteTensor)
return strByteTensor
end
function stringToVec(str, vecLength)
-- Vectorize strings (or a ByteTensor of ASCII codes) into a fixed-width
-- 2D IntTensor of label codes via the C extension helpers.
-- str: string, table of strings, or tensor of ascii codes
-- vecLength: vector length
-- strVec: 2D-tensor of vectorized strings
local strVec = torch.IntTensor()
if type(str) == 'string' or type(str) == 'table' then
-- normalize a single string into a one-element table
if type(str) == 'string' then
str = {str}
end
strVec.nn.Utils_stringToVec(str, vecLength, strVec)
elseif str:type() == 'torch.ByteTensor' then
-- input is already a tensor of ASCII codes; use the byte-tensor path
strVec.nn.Utils_byteTensorToVec(str, vecLength, strVec)
else
error('Unrecognized input type')
end
return strVec
end
function vecToString(vec)
-- Convert a (1D or 2D) tensor of vectorized labels back into strings.
-- A 1D input is treated as a single row.
if vec:dim() == 1 then
vec = vec:view(1, vec:size(1))
end
-- delegates to the C extension dispatched via the tensor's 'nn' table
local strings = vec.nn.Utils_vecToString(vec)
return strings
end
function oneHotEmbedding(target, yDim, pruneT)
  -- Convert label sequences into one-hot embeddings.
  -- target: [n x T] IntTensor of labels; entries equal to 0 mean "no label"
  --         and leave their one-hot row all zeros.
  -- yDim:   size of the one-hot (label) dimension.
  -- pruneT: when true, trim the time dimension down to the last step that
  --         carries any non-zero label across the whole batch.
  -- Returns an [n x T x yDim] (or [n x maxT x yDim] if pruned) Tensor.
  if pruneT == nil then
    pruneT = false
  end
  local n, T = target:size(1), target:size(2)
  local embedding = torch.Tensor(n, T, yDim):zero()
  local maxT = 0
  for sample = 1, n do
    for step = 1, T do
      local label = target[sample][step]
      if label ~= 0 then
        embedding[sample][step][label] = 1.0
        if step > maxT then
          maxT = step
        end
      end
    end
  end
  if pruneT then
    -- narrow() returns a view; clone so the result owns its storage
    embedding = embedding:narrow(2, 1, maxT):clone()
  end
  return embedding
end
function setupLogger(fpath)
  -- Open the global log file (gLoggerFile) at fpath for use by
  -- logging()/shutdownLogger(). If the file already exists, interactively
  -- ask whether to overwrite ('o'), append ('a') or abort ('q'); any other
  -- answer re-prompts.
  local fileMode = 'w'
  if paths.filep(fpath) then
    local input = nil
    while not input do
      print('Logging file exists, overwrite(o)? append(a)? abort(q)?')
      input = io.read()
      if input == 'o' then
        fileMode = 'w'
      elseif input == 'a' then
        fileMode = 'a'
      elseif input == 'q' then
        os.exit()
      else
        -- Unrecognized answer: clear input so the loop asks again.
        -- (Previously this branch set fileMode = nil while leaving input
        -- set, so the loop exited and io.open(fpath, nil) fell back to
        -- the default 'r' mode, silently breaking all later logging.)
        input = nil
      end
    end
  end
  -- Intentionally a global: shared with logging() and shutdownLogger().
  gLoggerFile = io.open(fpath, fileMode)
end
function tensorInfo(x, name)
  -- Return a one-line summary of a tensor: its name, size (e.g. "3x4"),
  -- and minimum/maximum values.
  -- x: a tensor-like object providing :dim(), :size(i), :min(), :max()
  -- name: optional label included in the summary (defaults to '')
  name = name or ''
  local dims = {}
  for i = 1, x:dim() do
    dims[i] = string.format('%d', x:size(i))
  end
  local sizeStr = table.concat(dims, 'x')
  -- 'local' added: infoStr previously leaked into the global environment
  local infoStr = string.format('[%15s] size: %12s, min: %+.2e, max: %+.2e',
    name, sizeStr, x:min(), x:max())
  return infoStr
end
function shutdownLogger()
  -- Close the global log file handle opened by setupLogger(), if any.
  local handle = gLoggerFile
  if handle then
    handle:close()
  end
end
function logging(message, mute)
  -- Emit a timestamped log line. Unless mute is true it is printed to
  -- stdout; it is always appended (and flushed) to the global log file
  -- opened by setupLogger(), when one is open.
  if mute == nil then
    mute = false
  end
  local stamp = os.date('%x %X')
  local line = string.format('[%s] %s', stamp, message)
  if not mute then
    print(line)
  end
  local handle = gLoggerFile
  if handle then
    handle:write(line .. '\n')
    handle:flush()
  end
end
function modelSize(model)
  -- Count the trainable parameters of a model.
  -- Returns: total element count, and a LongTensor holding the element
  -- count of each individual parameter tensor.
  local paramTensors = model:parameters()
  local perTensor = {}
  local total = 0
  for idx = 1, #paramTensors do
    local n = paramTensors[idx]:numel()
    perTensor[idx] = n
    total = total + n
  end
  return total, torch.LongTensor(perTensor)
end
function cloneList(tensorList, setZero)
  -- Deep-copy every tensor in tensorList via :clone().
  -- If setZero is truthy, each clone is zeroed in place; the originals
  -- are never modified. Keys of the input table are preserved.
  local cloned = {}
  for key, tensor in pairs(tensorList) do
    local copy = tensor:clone()
    if setZero then
      copy:zero()
    end
    cloned[key] = copy
  end
  return cloned
end
function cloneManyTimes(module, T)
-- Create T clones of a network module that all SHARE the same parameter and
-- gradient storages (via :set), while keeping independent internal buffers.
-- Typical use: unrolling a recurrent network over T time steps.
local clones = {}
local params, gradParams = module:parameters()
-- serialize the module once into an in-memory buffer ...
local mem = torch.MemoryFile("w"):binary()
mem:writeObject(module)
for t = 1, T do
-- ... then deserialize it T times to obtain structural copies
local reader = torch.MemoryFile(mem:storage(), "r"):binary()
local clone = reader:readObject()
reader:close()
local cloneParams, cloneGradParams = clone:parameters()
for i = 1, #params do
-- re-point the clone's tensors at the original storages so a weight
-- update through any clone is visible to all of them
cloneParams[i]:set(params[i])
cloneGradParams[i]:set(gradParams[i])
end
clones[t] = clone
-- deserialization churns a lot of temporary objects; reclaim eagerly
collectgarbage()
end
mem:close()
return clones
end
function diagnoseGradients(params, gradParams)
  -- For each parameter tensor, log its min/max, its gradient's min/max,
  -- and the gradient-to-parameter norm ratio. Logged muted, i.e. written
  -- only to the log file (if open), not to stdout.
  for idx = 1, #params do
    local p = params[idx]
    local gp = gradParams[idx]
    local line = string.format(
      '%02d - params [%+.2e, %+.2e] gradParams [%+.2e, %+.2e], norm gp/p %+.2e',
      idx, p:min(), p:max(), gp:min(), gp:max(), gp:norm() / p:norm())
    logging(line, true)
  end
end
function dumpModelState(model)
-- Collect every tensor that defines the model's state: the trainable
-- parameters plus the running statistics of any batch-normalization layers
-- (those are not returned by :parameters() but are required to restore
-- inference-time behavior). Returned as a flat table of tensors, in a
-- deterministic order that loadModelState() relies on.
local state = model:parameters()
local bnLayers = model:findModules('nn.BatchNormalization')
for i = 1, #bnLayers do
table.insert(state, bnLayers[i].running_mean)
-- NOTE(review): field is 'running_std' here; newer torch/nn versions renamed
-- it to 'running_var' -- confirm against the installed nn version
table.insert(state, bnLayers[i].running_std)
end
local sbnLayers = model:findModules('nn.SpatialBatchNormalization')
for i = 1, #sbnLayers do
table.insert(state, sbnLayers[i].running_mean)
table.insert(state, sbnLayers[i].running_std)
end
return state
end
function loadModelState(model, stateToLoad)
  -- Copy a previously saved state (as produced by dumpModelState) into
  -- model, tensor by tensor, in the same deterministic order.
  -- Raises if the two states disagree in tensor count (e.g. the checkpoint
  -- was produced by a differently-structured model).
  local state = dumpModelState(model)
  assert(#state == #stateToLoad,
    string.format('state size mismatch: model has %d tensors, loaded state has %d',
      #state, #stateToLoad))
  for i = 1, #state do
    state[i]:copy(stateToLoad[i])
  end
end