Commit 1ee64da

New NN classes

extra/nn/L1Cost.lua : L1 penalty
extra/nn/SpatialFullConvolution.lua : full convolution
extra/nn/SpatialFullConvolutionMap.lua : full convolution with connection table
extra/nn/TanhShrink.lua : shrinkage with x - tanh(x)
extra/nn/WeightedMSECriterion.lua : mean squared error with a weighting mask on the target

Adds new nn classes that are commonly used for unsupervised training of convolutional auto-encoders.
1 parent 9e6f3d3 commit 1ee64da

13 files changed: +923 −4 lines
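
Taken together, these modules supply the pieces of a convolutional auto-encoder in nn. Below is a minimal sketch, not part of this commit, of how they might be wired up; the nn.SpatialConvolution encoder, all layer sizes, and the 1 x 28 x 28 input are illustrative assumptions.

require 'nn'

-- Hypothetical encoder: 1 input plane -> 8 feature maps (5x5 kernels).
local encoder = nn.Sequential()
encoder:add(nn.SpatialConvolution(1, 8, 5, 5))
encoder:add(nn.Tanh())

-- Decoder from this commit: full convolution maps the code back to image size.
local decoder = nn.SpatialFullConvolution(8, 1, 5, 5)

-- Reconstruction cost with a weighting mask on the target, plus an L1
-- sparsity penalty on the code (both added by this commit).
local mask = torch.Tensor(1, 28, 28):fill(1)
local reconstructionCost = nn.WeightedMSECriterion(mask)
local sparsityCost = nn.L1Cost()

local input = torch.randn(1, 28, 28)
local code = encoder:forward(input)            -- 8 x 24 x 24 feature maps
local reconstruction = decoder:forward(code)   -- back to 1 x 28 x 28
local loss = reconstructionCost:forward(reconstruction, input)
loss = loss + sparsityCost:forward(code)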

L1Cost.lua

+14
@@ -0,0 +1,14 @@
local L1Cost, parent = torch.class('nn.L1Cost','nn.Criterion')

function L1Cost:__init()
   parent.__init(self)
end

function L1Cost:updateOutput(input)
   return input.nn.L1Cost_updateOutput(self,input)
end

function L1Cost:updateGradInput(input)
   return input.nn.L1Cost_updateGradInput(self,input)
end
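
A quick usage sketch, not part of the commit: the criterion is evaluated directly on an activation tensor, returning the sum of absolute values on the forward pass and the elementwise sign on the backward pass.

-- Hypothetical example of nn.L1Cost as a sparsity penalty.
require 'nn'

local penalty = nn.L1Cost()
local code = torch.randn(10)          -- any activation tensor
local l1 = penalty:forward(code)      -- sum of |code[i]|
local grad = penalty:backward(code)   -- sign(code), same shape as the input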

SpatialFullConvolution.lua

+53
@@ -0,0 +1,53 @@
local SpatialFullConvolution, parent = torch.class('nn.SpatialFullConvolution','nn.Module')

function SpatialFullConvolution:__init(nInputPlane, nOutputPlane, kW, kH, dW, dH)
   parent.__init(self)

   dW = dW or 1
   dH = dH or 1

   self.nInputPlane = nInputPlane
   self.nOutputPlane = nOutputPlane
   self.kW = kW
   self.kH = kH
   self.dW = dW
   self.dH = dH

   self.weight = torch.Tensor(nInputPlane, nOutputPlane, kH, kW)
   self.gradWeight = torch.Tensor(nInputPlane, nOutputPlane, kH, kW)
   self.bias = torch.Tensor(self.nOutputPlane)
   self.gradBias = torch.Tensor(self.nOutputPlane)

   self:reset()
end

function SpatialFullConvolution:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
   else
      local nInputPlane = self.nInputPlane
      local kH = self.kH
      local kW = self.kW
      stdv = 1/math.sqrt(kW*kH*nInputPlane)
   end
   self.weight:apply(function()
      return torch.uniform(-stdv, stdv)
   end)
   self.bias:apply(function()
      return torch.uniform(-stdv, stdv)
   end)
end

function SpatialFullConvolution:updateOutput(input)
   return input.nn.SpatialFullConvolution_updateOutput(self, input)
end

function SpatialFullConvolution:updateGradInput(input, gradOutput)
   if self.gradInput then
      return input.nn.SpatialFullConvolution_updateGradInput(self, input, gradOutput)
   end
end
function SpatialFullConvolution:accGradParameters(input, gradOutput, scale)
   return input.nn.SpatialFullConvolution_accGradParameters(self, input, gradOutput, scale)
end
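
For illustration only (not from the commit): with stride 1, the full convolution of an H x W map with a kH x kW kernel yields an (H+kH-1) x (W+kW-1) map, which is what lets this module act as the decoder of a convolutional auto-encoder. The sizes below are assumptions.

-- Hypothetical decoder: 16 code planes back to 1 image plane with 5x5 kernels.
require 'nn'

local decoder = nn.SpatialFullConvolution(16, 1, 5, 5)
local code = torch.randn(16, 12, 12)          -- nInputPlane x height x width
local reconstruction = decoder:forward(code)  -- 1 x 16 x 16 (12 + 5 - 1 = 16)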

SpatialFullConvolutionMap.lua

+62
@@ -0,0 +1,62 @@
local SpatialFullConvolutionMap, parent = torch.class('nn.SpatialFullConvolutionMap', 'nn.Module')

function SpatialFullConvolutionMap:__init(conMatrix, kW, kH, dW, dH)
   parent.__init(self)

   dW = dW or 1
   dH = dH or 1

   self.kW = kW
   self.kH = kH
   self.dW = dW
   self.dH = dH
   self.connTable = conMatrix
   self.nInputPlane = self.connTable:select(2,1):max()
   self.nOutputPlane = self.connTable:select(2,2):max()

   self.weight = torch.Tensor(self.connTable:size(1), kH, kW)
   self.gradWeight = torch.Tensor(self.connTable:size(1), kH, kW)

   self.bias = torch.Tensor(self.nOutputPlane)
   self.gradBias = torch.Tensor(self.nOutputPlane)

   self:reset()
end

function SpatialFullConvolutionMap:reset(stdv)
   if stdv then
      stdv = stdv * math.sqrt(3)
      self.weight:apply(function()
         return torch.uniform(-stdv, stdv)
      end)
      self.bias:apply(function()
         return torch.uniform(-stdv, stdv)
      end)
   else
      local ninp = torch.Tensor(self.nOutputPlane):zero()
      for i=1,self.connTable:size(1) do ninp[self.connTable[i][2]] = ninp[self.connTable[i][2]]+1 end
      for k=1,self.connTable:size(1) do
         stdv = 1/math.sqrt(self.kW*self.kH*ninp[self.connTable[k][2]])
         self.weight:select(1,k):apply(function() return torch.uniform(-stdv,stdv) end)
      end
      for k=1,self.bias:size(1) do
         stdv = 1/math.sqrt(self.kW*self.kH*ninp[k])
         self.bias[k] = torch.uniform(-stdv,stdv)
      end
   end
end

function SpatialFullConvolutionMap:updateOutput(input)
   input.nn.SpatialFullConvolutionMap_updateOutput(self, input)
   return self.output
end

function SpatialFullConvolutionMap:updateGradInput(input, gradOutput)
   input.nn.SpatialFullConvolutionMap_updateGradInput(self, input, gradOutput)
   return self.gradInput
end

function SpatialFullConvolutionMap:accGradParameters(input, gradOutput, scale)
   return input.nn.SpatialFullConvolutionMap_accGradParameters(self, input, gradOutput, scale)
end
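
Sketch only, not from the commit: the module takes a connection table in the same format as nn.SpatialConvolutionMap (column 1 = input plane, column 2 = output plane), so the standard nn.tables helpers can be used to build it.

-- Hypothetical use with a full (all-to-all) connection table.
require 'nn'

local conn = nn.tables.full(16, 1)            -- every input plane feeds the single output plane
local decoder = nn.SpatialFullConvolutionMap(conn, 5, 5)
local code = torch.randn(16, 12, 12)
local reconstruction = decoder:forward(code)  -- 1 x 16 x 16 with stride 1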

TanhShrink.lua

+20
@@ -0,0 +1,20 @@
local TanhShrink, parent = torch.class('nn.TanhShrink','nn.Module')

function TanhShrink:__init()
   parent.__init(self)
   self.tanh = nn.Tanh()
end

function TanhShrink:updateOutput(input)
   local th = self.tanh:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   self.output:add(-1,th)
   return self.output
end

function TanhShrink:updateGradInput(input, gradOutput)
   local dth = self.tanh:updateGradInput(input,gradOutput)
   self.gradInput:resizeAs(input):copy(gradOutput)
   self.gradInput:add(-1,dth)
   return self.gradInput
end
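
As the commit message notes, this module computes x - tanh(x) elementwise; a small sketch (not part of the commit):

-- Hypothetical example: TanhShrink shrinks small activations toward zero
-- (x - tanh(x) is roughly x^3/3 near zero) and is roughly linear for large x.
require 'nn'

local shrink = nn.TanhShrink()
local x = torch.linspace(-3, 3, 7)
local y = shrink:forward(x)           -- y[i] = x[i] - tanh(x[i])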

WeightedMSECriterion.lua

+31
@@ -0,0 +1,31 @@
local WeightedMSECriterion, parent = torch.class('nn.WeightedMSECriterion','nn.MSECriterion')

function WeightedMSECriterion:__init(w)
   parent.__init(self)
   self.weight = w:clone()
   self.buffer = torch.Tensor()
end

function WeightedMSECriterion:updateOutput(input,target)
   self.buffer:resizeAs(input):copy(target)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer[i]:cmul(self.weight)
      end
   else
      self.buffer:cmul(self.weight)
   end
   return input.nn.MSECriterion_updateOutput(self, input, self.buffer)
end

function WeightedMSECriterion:updateGradInput(input, target)
   self.buffer:resizeAs(input):copy(target)
   if input:dim() - 1 == self.weight:dim() then
      for i=1,input:size(1) do
         self.buffer[i]:cmul(self.weight)
      end
   else
      self.buffer:cmul(self.weight)
   end
   return input.nn.MSECriterion_updateGradInput(self, input, self.buffer)
end
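
Usage sketch, not from the commit: the weighting tensor passed to the constructor is multiplied elementwise into the target before the standard MSE forward/backward calls, so it should match the target layout (or, per the dimension check above, a single sample of a batched target).

-- Hypothetical example: zero out the contribution of the first target element.
require 'nn'

local weight = torch.Tensor(4):fill(1)
weight[1] = 0

local crit = nn.WeightedMSECriterion(weight)
local input = torch.randn(4)
local target = torch.randn(4)
local err = crit:forward(input, target)        -- MSE against the masked target
local gradInput = crit:backward(input, target)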

generic/L1Cost.c

+49
@@ -0,0 +1,49 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/L1Cost.c"
#else

static int nn_(L1Cost_updateOutput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  accreal sum;

  sum = 0;
  TH_TENSOR_APPLY(real, input, sum += fabs(*input_data););

  lua_pushnumber(L, sum);
  lua_setfield(L, 1, "output");

  lua_pushnumber(L, sum);
  return 1;
}

static int nn_(L1Cost_updateGradInput)(lua_State *L)
{
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);

  THTensor_(resizeAs)(gradInput, input);
  TH_TENSOR_APPLY2(real, gradInput, real, input,
                   if (*input_data > 0)
                     *gradInput_data = 1;
                   else if (*input_data < 0)
                     *gradInput_data = -1;
                   else
                     *gradInput_data = 0;);
  return 1;
}

static const struct luaL_Reg nn_(L1Cost__) [] = {
  {"L1Cost_updateOutput", nn_(L1Cost_updateOutput)},
  {"L1Cost_updateGradInput", nn_(L1Cost_updateGradInput)},
  {NULL, NULL}
};

static void nn_(L1Cost_init)(lua_State *L)
{
  luaT_pushmetatable(L, torch_Tensor);
  luaT_registeratname(L, nn_(L1Cost__), "nn");
  lua_pop(L,1);
}

#endif

generic/SpatialConvolutionMap.c

+1 −1
@@ -18,7 +18,7 @@ static int nn_(SpatialConvolutionMap_updateOutput)(lua_State *L)
   THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

   luaL_argcheck(L, input->nDimension == 3, 2, "3D tensor expected");
-  luaL_argcheck(L, input->size[0] == nInputPlane, 2, "invalid number of input planes");
+  luaL_argcheck(L, input->size[0] >= nInputPlane, 2, "invalid number of input planes");
   luaL_argcheck(L, input->size[2] >= kW && input->size[1] >= kH, 2, "input image smaller than kernel size");

   THTensor_(resize3d)(output, nOutputPlane,
