Skip to content

Commit 9cffea5

Browse files
apaszkesoumith
authored andcommitted
Remove unnecessary function override in unpooling modules (#749)
Doing so breaks deserialization on systems with different architectures. These operations are cheap, so it's not a problem to perform them even when there's no unpooling associated.
1 parent 130955e commit 9cffea5

4 files changed

+11
-16
lines changed

SpatialMaxPooling.lua

+5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ end
3030

3131
function SpatialMaxPooling:updateOutput(input)
3232
self.indices = self.indices or input.new()
33+
34+
local dims = input:dim()
35+
self.iheight = input:size(dims-1)
36+
self.iwidth = input:size(dims)
37+
3338
-- backward compatibility
3439
self.ceil_mode = self.ceil_mode or false
3540
self.padW = self.padW or 0

SpatialMaxUnpooling.lua

+1-8
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,9 @@ local SpatialMaxUnpooling, parent = torch.class('nn.SpatialMaxUnpooling', 'nn.Mo
22

33
function SpatialMaxUnpooling:__init(poolingModule)
44
parent.__init(self)
5-
assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SPatialMaxPooling module')
5+
assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SpatialMaxPooling module')
66
assert(poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride")
77
self.pooling = poolingModule
8-
9-
poolingModule.updateOutput = function(pool, input)
10-
local dims = input:dim()
11-
pool.iheight = input:size(dims-1)
12-
pool.iwidth = input:size(dims)
13-
return nn.SpatialMaxPooling.updateOutput(pool, input)
14-
end
158
end
169

1710
function SpatialMaxUnpooling:setParams()

VolumetricMaxPooling.lua

+5
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ function VolumetricMaxPooling:floor()
3636
end
3737

3838
function VolumetricMaxPooling:updateOutput(input)
39+
local dims = input:dim()
40+
self.itime = input:size(dims-2)
41+
self.iheight = input:size(dims-1)
42+
self.iwidth = input:size(dims)
43+
3944
self.indices = self.indices or input.new()
4045
input.THNN.VolumetricMaxPooling_updateOutput(
4146
input:cdata(),

VolumetricMaxUnpooling.lua

-8
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,6 @@ function VolumetricMaxUnpooling:__init(poolingModule)
55
assert(torch.type(poolingModule)=='nn.VolumetricMaxPooling', 'Argument must be a nn.VolumetricMaxPooling module')
66
assert(poolingModule.kT==poolingModule.dT and poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride")
77
self.pooling = poolingModule
8-
9-
poolingModule.updateOutput = function(pool, input)
10-
local dims = input:dim()
11-
pool.itime = input:size(dims-2)
12-
pool.iheight = input:size(dims-1)
13-
pool.iwidth = input:size(dims)
14-
return nn.VolumetricMaxPooling.updateOutput(pool, input)
15-
end
168
end
179

1810
function VolumetricMaxUnpooling:setParams()

0 commit comments

Comments
 (0)