-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathsplit_network.lua
67 lines (61 loc) · 1.53 KB
/
split_network.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
require 'torch'
require 'nn'
require 'optim'
require 'paths'
assert(pcall(function () mat = require('fb.mattorch') end) or pcall(function() mat = require('matio') end), 'no mat IO interface available')
--- Convert the raw string of an environment variable to the type of the
-- matching option default.
-- Numeric defaults are parsed with tonumber. Boolean defaults accept
-- true/false/1/0 (case-insensitive); a plain truthiness check would treat
-- the string "false" as true. Any other default keeps the raw string.
-- @tparam string|nil raw value returned by os.getenv (nil when unset)
-- @param default the option's default value; its Lua type selects the parser
-- @tparam string key option name, used only in error messages
-- @return the converted value, or `default` when `raw` is nil
-- @raise error when `raw` cannot be converted to the default's type
local function coerce_env_value(raw, default, key)
  if raw == nil then
    return default
  end
  local kind = type(default)
  if kind == 'number' then
    local number = tonumber(raw)
    if number == nil then
      error(string.format("option '%s' expects a number, got '%s'",
                          tostring(key), raw))
    end
    return number
  elseif kind == 'boolean' then
    local lowered = raw:lower()
    if lowered == 'true' or lowered == '1' then
      return true
    elseif lowered == 'false' or lowered == '0' then
      return false
    end
    error(string.format("option '%s' expects true/false, got '%s'",
                        tostring(key), raw))
  end
  return raw
end

-- Default options. Each one can be overridden by an environment variable of
-- the same name, e.g. `splitIndex=5 removeDropout=true` in the environment of
-- the process running this script.
local opt = {
  nz = 200,
  name = 'shapenet101',
  ext = 'net_C',
  checkpointd = '/data/jjliu/checkpoints/',
  checkpointf = 'checkpoints_64class100',
  epoch = 25,
  gpu = 1,
  parallel = false,
  splitIndex = 7,
  removeDropout = false,
}
-- Only existing keys are reassigned, which is allowed while iterating pairs().
for key, default in pairs(opt) do
  opt[key] = coerce_env_value(os.getenv(key), default, key)
end
-- CUDA setup: a positive opt.gpu loads the CUDA-backed packages (in the same
-- order as before) and selects that device before the checkpoint is read.
-- NOTE(review): presumably required so torch.load can deserialize CUDA /
-- cudnn modules stored in the checkpoint — confirm.
if opt.gpu > 0 then
  for _, package_name in ipairs({ 'cunn', 'cudnn', 'cutorch' }) do
    require(package_name)
  end
  -- `cutorch` is the global table the cutorch package installs.
  cutorch.setDevice(opt.gpu)
end
-- Load the trained checkpoint from
--   <checkpointd>/<checkpointf>/<name>_<epoch>_<ext>.t7
-- Passing the directories to paths.concat as separate components inserts the
-- separator, so checkpointd works with or without a trailing slash.
local in_path = paths.concat(opt.checkpointd, opt.checkpointf,
  string.format('%s_%s_%s.t7', opt.name, opt.epoch, opt.ext))
assert(paths.filep(in_path), 'checkpoint not found: ' .. in_path)
local net = torch.load(in_path)
-- Edit a deep copy rather than the freshly loaded object.
net = net:clone()
if opt.parallel then
  -- NOTE(review): assumes the checkpoint was saved wrapped in a parallel
  -- container whose first child is the actual network — confirm against the
  -- training script.
  net = net:get(1)
end
print(net)
-- Truncate the top-level container so that only its first opt.splitIndex
-- modules remain. net:remove() with no index drops the last module.
local num_modules = net:size()
local split = opt.splitIndex
assert(type(split) == 'number' and split == math.floor(split)
         and split >= 1 and split <= num_modules,
       string.format('splitIndex must be an integer in [1, %d], got %s',
                     num_modules, tostring(split)))
for _ = 1, num_modules - split do
  net:remove()
end
if opt.removeDropout then
  --- Build the deterministic module that reproduces a dropout layer's
  -- evaluation-time behaviour.
  -- Per torch/nn (confirm against the installed version):
  --  * nn.Dropout with v2 == true (the default) rescales during training and
  --    is the identity at evaluation time, so it becomes nn.Identity.
  --  * Otherwise (v1 nn.Dropout, SpatialDropout, ...) the layer multiplies by
  --    (1 - p) at evaluation time, which nn.MulConstant reproduces.
  -- nn.Mul must not be used here: it ignores its constructor argument and
  -- holds a randomly initialised learnable weight.
  -- @param module the dropout module being replaced
  -- @return replacement module cast to the same tensor type as `module`
  local function dropout_replacement(module)
    local replacement
    if module.v2 then
      replacement = nn.Identity()
    else
      replacement = nn.MulConstant(1 - module.p)
    end
    -- Match the tensor type (e.g. CUDA) of the layer being replaced so the
    -- new module's output buffer agrees with its neighbours.
    if torch.isTensor(module.output) then
      replacement = replacement:type(module.output:type())
    end
    return replacement
  end

  -- A single pass suffices: remove + insert at the same index keeps the
  -- container size and the positions of later modules unchanged. Only direct
  -- children of `net` are inspected; dropout nested inside sub-containers is
  -- left untouched.
  for i = 1, net:size() do
    local module = net:get(i)
    if torch.type(module):find('Dropout') then
      local replacement = dropout_replacement(module)
      net:remove(i)
      net:insert(replacement, i)
    end
  end
end
print(net)
print(net:size())
-- Save the truncated network as
--   <checkpointd>/<checkpointf>/<name>_<epoch>_<ext>_split<splitIndex>.t7
-- paths.concat inserts the separator between components, so checkpointd
-- works with or without a trailing slash.
local out_path = paths.concat(opt.checkpointd, opt.checkpointf,
  string.format('%s_%s_%s_split%s.t7',
                opt.name, opt.epoch, opt.ext, opt.splitIndex))
torch.save(out_path, net)
print(out_path)