-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtester.lua
181 lines (136 loc) · 4.65 KB
/
tester.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
-- Command-line options, parsed by lapp from the usage text below. The text itself is
-- the option specification (flag names and defaults), so edit it with care.
-- NOTE(review): `lapp` is not required in this file; presumably it is provided as a
-- global by the Torch/Penlight environment -- confirm.
-- NOTE(review): --base and --clientIP are not read anywhere in the visible code.
-- --numEpochs is later compared against the string 'inf' and mapped to 1/0.
local opt = lapp [[
Train a CNN classifier on CIFAR-10 using Asynchronous.
--numNodes (default 1) num nodes spawned in parallel
--batchSize (default 32) batch size, per node
--numEpochs (default inf) Total Number of epochs
--cuda use cuda
--gpu (default 1) which gpu to use (only when using cuda)
--host (default '127.0.0.1') host name of the server
--port (default 8080) port number of the server
--base (default 2) power of 2 base of the tree of nodes
--clientIP (default '127.0.0.1') host name of the client
--server Client/Server
--tester Tester
--verbose Print Communication details
--save (default 'log') Save location
--visualize Visualization
]]
-- Requires
-- The GPU backends are loaded, and the device selected, only for --cuda runs.
if opt.cuda then
   for _, backend in ipairs({ 'cutorch', 'cunn' }) do
      require(backend)
   end
   cutorch.setDevice(opt.gpu)
end
local grad = require 'autograd'
local util = require 'autograd.util'
local lossFuns = require 'autograd.loss'
local optim = require 'optim'
local Dataset = require 'dataset.Dataset'
-- Resolve the output directory. The default 'log' is replaced by the current
-- date/time string with its spaces removed; the result is always placed under ./Results.
if opt.save == 'log' then
   -- gsub returns (string, count); the parentheses keep only the string.
   opt.save = (os.date():gsub(' ', ''))
end
opt.save = paths.concat('./Results', opt.save)
-- Single-quote the path for the shell, escaping any embedded single quote as '\''.
-- This way a --save value containing spaces or shell metacharacters reaches mkdir
-- as one literal argument instead of being split or interpreted by the shell.
local quotedSave = "'" .. (opt.save:gsub("'", "'\\''")) .. "'"
os.execute('mkdir -p ' .. quotedSave)
-- Record the parsed options in <save>/Log.txt via torch.CmdLine.
local cmd = torch.CmdLine()
cmd:log(opt.save .. '/Log.txt', opt)
-- Output file paths.
-- NOTE(review): netFilename and optStateFilename are not referenced in the visible code.
local netFilename = paths.concat(opt.save, 'Net')
local logFilename = paths.concat(opt.save, 'ErrorRate.log')
local optStateFilename = paths.concat(opt.save, 'optState')
-- Error-rate logger writing to <save>/ErrorRate.log.
local Log = optim.Logger(logFilename)
require 'colorPrint' -- Print Server and Client in colors
-- Without --verbose, replace the communication printers with no-ops.
-- They are assigned as globals on purpose so they override the printers used elsewhere.
-- NOTE(review): assumes colorPrint installs printServer/printClient as globals -- confirm.
-- The vararg parameter avoids shadowing the `string` standard library.
if not opt.verbose then
   function printServer(...) end
   function printClient(...) end
end
-- Build the Network
-- The tester process only opens the test-channel client. The other endpoints below are
-- never assigned in the visible code, so they are passed to AsyncEA as nil.
local ipc = require 'libipc'
-- NOTE(review): Tree is required but not referenced in the visible code.
local Tree = require 'ipc.Tree'
local client, server
local clientTest, serverTest
local serverBroadcast, clientBroadcast
-- initialize tester
-- The test channel listens on --port + numNodes + 1 on --host.
clientTest = ipc.client(opt.host, opt.port + opt.numNodes + 1)
-- NOTE(review): the trailing positional arguments (1, 10, 0.2) are interpreted by
-- Async-EASGD/AsyncEA; their meaning is not visible here -- confirm before changing.
local AsyncEA = require 'Async-EASGD.AsyncEA'(server, serverBroadcast, client, clientBroadcast, serverTest, clientTest, opt.numNodes, 1,10, 0.2)
-- Print only in server and tester nodes!
-- Unless --tester or --server is set, the globals xlua.progress and print become no-ops.
if not (opt.tester or opt.server) then
xlua.progress = function() end
print = function() end
end
-- Load dataset
-- NOTE(review): the two leading 1s are positional settings interpreted by the Data
-- module (possibly node index / node count) -- confirm there.
local data = require 'Data'(1, 1, opt.batchSize, opt.cuda)
local getTrainingBatch, numTrainingBatches = data.getTrainingBatch, data.numTrainingBatches
local getTestBatch, numTestBatches = data.getTestBatch, data.numTestBatches
local classes = data.classes
-- One confusion matrix over the dataset's classes, reused for every evaluation.
local confusionMatrix = optim.ConfusionMatrix(classes)
-- Model supplies the parameters and the function f(params, x, y) used for prediction.
-- df is not referenced in the visible code.
local Model = require 'Model'
local params, f, df = Model.params, Model.f, Model.df
-- Cast the parameters to the 'cuda' type under --cuda, otherwise to 'float'.
local tensorType = 'float'
if opt.cuda then
   tensorType = 'cuda'
end
params = grad.util.cast(params, tensorType)
AsyncEA.initTester(params)
local epoch = 1
-- Train a neural network
-- Map the textual default 'inf' to a numeric infinity.
-- NOTE(review): opt.numEpochs is not read by the loop below; the loop ends only when
-- AsyncEA.startTest returns a truthy value.
if opt.numEpochs == 'inf' then
   opt.numEpochs = 1/0
end

--- Run the current params over one full pass of a data split and return its error rate.
-- Predictions are accumulated into the shared confusion matrix. The matrix is printed
-- and then reset, so consecutive calls never mix splits.
-- @tparam function getBatch returns the next batch (fields input, target, batchSize)
-- @tparam function numBatches returns the number of batches in the split
-- @treturn number error rate, i.e. 1 - confusionMatrix.totalValid
local function evaluate(getBatch, numBatches)
   local total = numBatches()
   for i = 1, total do
      local batch = getBatch()
      local y = batch.target
      -- Only the predictions are needed; the loss is ignored.
      local _, prediction = f(params, batch.input, y)
      for b = 1, batch.batchSize do
         confusionMatrix:add(prediction[b], y[b])
      end
      -- Progress is reported against this split's own batch count.
      xlua.progress(i, total)
   end
   print(confusionMatrix)
   -- Refresh totalValid explicitly. Otherwise it is only recomputed when the matrix is
   -- converted to a string, which never happens once `print` has been made a no-op.
   confusionMatrix:updateValids()
   local err = 1 - confusionMatrix.totalValid
   confusionMatrix:zero()
   return err
end

while true do
   -- A truthy result from startTest ends testing.
   -- NOTE(review): startTest presumably also refreshes `params` from the server -- confirm.
   if AsyncEA.startTest(params) then
      break
   end

   -- Check Training Error
   print('\nTraining Error Trial #'..epoch .. '\n')
   local ErrTrain = evaluate(getTrainingBatch, numTrainingBatches)
   print('Training Error = ' .. ErrTrain)

   -- Check Test Error
   print('\nTesting Error Trial #' ..epoch .. '\n')
   local ErrTest = evaluate(getTestBatch, numTestBatches)
   print('Test Error = ' .. ErrTest .. '\n')

   Log:add{['Training Error'] = ErrTrain, ['Test Error'] = ErrTest}
   torch.save(opt.save .. '/net.t7', params)
   if opt.visualize then
      Log:style{['Training Error'] = '-', ['Test Error'] = '-'}
      Log:plot()
   end
   epoch = epoch + 1
   AsyncEA.finishTest()
end
-- Close Connection
clientTest:close()