|
| 1 | +local SpatialFullConvolutionMap, parent = torch.class('nn.SpatialFullConvolutionMap', 'nn.Module') |
| 2 | + |
| 3 | +function SpatialFullConvolutionMap:__init(conMatrix, kW, kH, dW, dH) |
| 4 | + parent.__init(self) |
| 5 | + |
| 6 | + dW = dW or 1 |
| 7 | + dH = dH or 1 |
| 8 | + |
| 9 | + self.kW = kW |
| 10 | + self.kH = kH |
| 11 | + self.dW = dW |
| 12 | + self.dH = dH |
| 13 | + self.connTable = conMatrix |
| 14 | + self.nInputPlane = self.connTable:select(2,1):max() |
| 15 | + self.nOutputPlane = self.connTable:select(2,2):max() |
| 16 | + |
| 17 | + self.weight = torch.Tensor(self.connTable:size(1), kH, kW) |
| 18 | + self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW) |
| 19 | + |
| 20 | + self.bias = torch.Tensor(self.nOutputPlane) |
| 21 | + self.gradBias = torch.Tensor(self.nOutputPlane) |
| 22 | + |
| 23 | + self:reset() |
| 24 | +end |
| 25 | + |
| 26 | +function SpatialFullConvolutionMap:reset(stdv) |
| 27 | + if stdv then |
| 28 | + stdv = stdv * math.sqrt(3) |
| 29 | + self.weight:apply(function() |
| 30 | + return torch.uniform(-stdv, stdv) |
| 31 | + end) |
| 32 | + self.bias:apply(function() |
| 33 | + return torch.uniform(-stdv, stdv) |
| 34 | + end) |
| 35 | + else |
| 36 | + local ninp = torch.Tensor(self.nOutputPlane):zero() |
| 37 | + for i=1,self.connTable:size(1) do ninp[self.connTable[i][2]] = ninp[self.connTable[i][2]]+1 end |
| 38 | + for k=1,self.connTable:size(1) do |
| 39 | + stdv = 1/math.sqrt(self.kW*self.kH*ninp[self.connTable[k][2]]) |
| 40 | + self.weight:select(1,k):apply(function() return torch.uniform(-stdv,stdv) end) |
| 41 | + end |
| 42 | + for k=1,self.bias:size(1) do |
| 43 | + stdv = 1/math.sqrt(self.kW*self.kH*ninp[k]) |
| 44 | + self.bias[k] = torch.uniform(-stdv,stdv) |
| 45 | + end |
| 46 | + |
| 47 | + end |
| 48 | +end |
| 49 | + |
| 50 | +function SpatialFullConvolutionMap:updateOutput(input) |
| 51 | + input.nn.SpatialFullConvolutionMap_updateOutput(self, input) |
| 52 | + return self.output |
| 53 | +end |
| 54 | + |
| 55 | +function SpatialFullConvolutionMap:updateGradInput(input, gradOutput) |
| 56 | + input.nn.SpatialFullConvolutionMap_updateGradInput(self, input, gradOutput) |
| 57 | + return self.gradInput |
| 58 | +end |
| 59 | + |
| 60 | +function SpatialFullConvolutionMap:accGradParameters(input, gradOutput, scale) |
| 61 | + return input.nn.SpatialFullConvolutionMap_accGradParameters(self, input, gradOutput, scale) |
| 62 | +end |
0 commit comments