Skip to content

Commit 1efff4d

Browse files
committed
refactoring all the common container code into nn.Container
1 parent a38407a commit 1efff4d

File tree

6 files changed

+86
-176
lines changed

6 files changed

+86
-176
lines changed

Concat.lua

+2-64
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
1-
local Concat, parent = torch.class('nn.Concat', 'nn.Module')
1+
local Concat, parent = torch.class('nn.Concat', 'nn.Container')
22

33
function Concat:__init(dimension)
4-
parent.__init(self)
5-
self.modules = {}
4+
parent.__init(self, dimension)
65
self.size = torch.LongStorage()
76
self.dimension = dimension
87
end
98

10-
function Concat:add(module)
11-
table.insert(self.modules, module)
12-
return self
13-
end
14-
15-
function Concat:get(index)
16-
return self.modules[index]
17-
end
18-
199
function Concat:updateOutput(input)
2010
local outs = {}
2111
for i=1,#self.modules do
@@ -83,58 +73,6 @@ function Concat:accUpdateGradParameters(input, gradOutput, lr)
8373
end
8474
end
8575

86-
function Concat:zeroGradParameters()
87-
for _,module in ipairs(self.modules) do
88-
module:zeroGradParameters()
89-
end
90-
end
91-
92-
function Concat:updateParameters(learningRate)
93-
for _,module in ipairs(self.modules) do
94-
module:updateParameters(learningRate)
95-
end
96-
end
97-
98-
function Concat:training()
99-
for i=1,#self.modules do
100-
self.modules[i]:training()
101-
end
102-
end
103-
104-
function Concat:evaluate()
105-
for i=1,#self.modules do
106-
self.modules[i]:evaluate()
107-
end
108-
end
109-
110-
function Concat:share(mlp,...)
111-
for i=1,#self.modules do
112-
self.modules[i]:share(mlp.modules[i],...);
113-
end
114-
end
115-
116-
function Concat:parameters()
117-
local function tinsert(to, from)
118-
if type(from) == 'table' then
119-
for i=1,#from do
120-
tinsert(to,from[i])
121-
end
122-
else
123-
table.insert(to,from)
124-
end
125-
end
126-
local w = {}
127-
local gw = {}
128-
for i=1,#self.modules do
129-
local mw,mgw = self.modules[i]:parameters()
130-
if mw then
131-
tinsert(w,mw)
132-
tinsert(gw,mgw)
133-
end
134-
end
135-
return w,gw
136-
end
137-
13876
function Concat:__tostring__()
13977
local tab = ' '
14078
local line = '\n'

Container.lua

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
-- This is code common to container modules, which are collections of
2+
-- smaller constituent modules like Parallel, Sequential, etc.
3+
local Container, parent =
4+
torch.class('nn.Container', 'nn.Module')
5+
6+
function Container:__init(...)
7+
parent.__init(self, ...)
8+
self.modules = {}
9+
end
10+
11+
function Container:add(module)
12+
table.insert(self.modules, module)
13+
return self
14+
end
15+
16+
function Container:get(index)
17+
return self.modules[index]
18+
end
19+
20+
function Container:size()
21+
return #self.modules
22+
end
23+
24+
function Container:zeroGradParameters()
25+
for i=1,#self.modules do
26+
self.modules[i]:zeroGradParameters()
27+
end
28+
end
29+
30+
function Container:updateParameters(learningRate)
31+
for _,module in ipairs(self.modules) do
32+
module:updateParameters(learningRate)
33+
end
34+
end
35+
36+
function Container:training()
37+
for i=1,#self.modules do
38+
self.modules[i]:training()
39+
end
40+
end
41+
42+
function Container:evaluate()
43+
for i=1,#self.modules do
44+
self.modules[i]:evaluate()
45+
end
46+
end
47+
48+
function Container:share(mlp, ...)
49+
for i=1,#self.modules do
50+
self.modules[i]:share(mlp.modules[i], ...);
51+
end
52+
end
53+
54+
function Container:reset(stdv)
55+
for i=1,#self.modules do
56+
self.modules[i]:reset(stdv)
57+
end
58+
end
59+
60+
function Container:parameters()
61+
local function tinsert(to, from)
62+
if type(from) == 'table' then
63+
for i=1,#from do
64+
tinsert(to,from[i])
65+
end
66+
else
67+
table.insert(to,from)
68+
end
69+
end
70+
local w = {}
71+
local gw = {}
72+
for i=1,#self.modules do
73+
local mw,mgw = self.modules[i]:parameters()
74+
if mw then
75+
tinsert(w,mw)
76+
tinsert(gw,mgw)
77+
end
78+
end
79+
return w,gw
80+
end

Parallel.lua

+1-40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
local Parallel, parent = torch.class('nn.Parallel', 'nn.Module')
1+
local Parallel, parent = torch.class('nn.Parallel', 'nn.Container')
22

33
function Parallel:__init(inputDimension,outputDimension)
44
parent.__init(self)
@@ -8,15 +8,6 @@ function Parallel:__init(inputDimension,outputDimension)
88
self.outputDimension = outputDimension
99
end
1010

11-
function Parallel:add(module)
12-
table.insert(self.modules, module)
13-
return self
14-
end
15-
16-
function Parallel:get(index)
17-
return self.modules[index]
18-
end
19-
2011
function Parallel:updateOutput(input)
2112

2213
local modules=input:size(self.inputDimension)
@@ -99,36 +90,6 @@ function Parallel:accUpdateGradParameters(input, gradOutput, lr)
9990
end
10091
end
10192

102-
function Parallel:zeroGradParameters()
103-
for _,module in ipairs(self.modules) do
104-
module:zeroGradParameters()
105-
end
106-
end
107-
108-
function Parallel:updateParameters(learningRate)
109-
for _,module in ipairs(self.modules) do
110-
module:updateParameters(learningRate)
111-
end
112-
end
113-
114-
function Parallel:training()
115-
for i=1,#self.modules do
116-
self.modules[i]:training()
117-
end
118-
end
119-
120-
function Parallel:evaluate()
121-
for i=1,#self.modules do
122-
self.modules[i]:evaluate()
123-
end
124-
end
125-
126-
function Parallel:share(mlp,...)
127-
for i=1,#self.modules do
128-
self.modules[i]:share(mlp.modules[i],...);
129-
end
130-
end
131-
13293
function Parallel:parameters()
13394
local function tinsert(to, from)
13495
if type(from) == 'table' then

Sequential.lua

+1-71
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
local Sequential, parent = torch.class('nn.Sequential', 'nn.Module')
2-
3-
function Sequential:__init()
4-
parent.__init(self)
5-
self.modules = {}
6-
end
1+
local Sequential, _ = torch.class('nn.Sequential', 'nn.Container')
72

83
function Sequential:add(module)
94
if #self.modules == 0 then
@@ -24,14 +19,6 @@ function Sequential:insert(module, index)
2419
self.gradInput = self.modules[1].gradInput
2520
end
2621

27-
function Sequential:size()
28-
return #self.modules
29-
end
30-
31-
function Sequential:get(index)
32-
return self.modules[index]
33-
end
34-
3522
function Sequential:updateOutput(input)
3623
local currentOutput = input
3724
for i=1,#self.modules do
@@ -82,63 +69,6 @@ function Sequential:accUpdateGradParameters(input, gradOutput, lr)
8269
currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
8370
end
8471

85-
function Sequential:zeroGradParameters()
86-
for i=1,#self.modules do
87-
self.modules[i]:zeroGradParameters()
88-
end
89-
end
90-
91-
function Sequential:updateParameters(learningRate)
92-
for i=1,#self.modules do
93-
self.modules[i]:updateParameters(learningRate)
94-
end
95-
end
96-
97-
function Sequential:training()
98-
for i=1,#self.modules do
99-
self.modules[i]:training()
100-
end
101-
end
102-
103-
function Sequential:evaluate()
104-
for i=1,#self.modules do
105-
self.modules[i]:evaluate()
106-
end
107-
end
108-
109-
function Sequential:share(mlp,...)
110-
for i=1,#self.modules do
111-
self.modules[i]:share(mlp.modules[i],...);
112-
end
113-
end
114-
115-
function Sequential:reset(stdv)
116-
for i=1,#self.modules do
117-
self.modules[i]:reset(stdv)
118-
end
119-
end
120-
121-
function Sequential:parameters()
122-
local function tinsert(to, from)
123-
if type(from) == 'table' then
124-
for i=1,#from do
125-
tinsert(to,from[i])
126-
end
127-
else
128-
table.insert(to,from)
129-
end
130-
end
131-
local w = {}
132-
local gw = {}
133-
for i=1,#self.modules do
134-
local mw,mgw = self.modules[i]:parameters()
135-
if mw then
136-
tinsert(w,mw)
137-
tinsert(gw,mgw)
138-
end
139-
end
140-
return w,gw
141-
end
14272

14373
function Sequential:__tostring__()
14474
local tab = ' '

init.lua

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ require('libnn')
44
include('ErrorMessages.lua')
55
include('Module.lua')
66

7+
include('Container.lua')
78
include('Concat.lua')
89
include('Parallel.lua')
910
include('Sequential.lua')

test.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ function nntest.WeightedEuclidean()
477477
local inj = math.random(13,5)
478478
local input = torch.Tensor(ini):zero()
479479
local module = nn.WeightedEuclidean(ini,inj)
480-
480+
481481
local err = jac.testJacobian(module,input)
482482
mytester:assertlt(err,precision, 'error on state ')
483483

0 commit comments

Comments (0)