Commit 61334b6

Channel-Wise RReLU

xternalz authored and jasonkuen committed
1 parent 9c8f2bb · commit 61334b6

File tree: 5 files changed, +239 -82 lines changed


RReLU.lua (+6 -3)

```diff
@@ -1,14 +1,15 @@
 local ffi = require 'ffi'
 local RReLU, parent = torch.class('nn.RReLU', 'nn.Module')
 
-function RReLU:__init(l, u, ip)
+function RReLU:__init(l, u, ip, cw)
    parent.__init(self)
    self.lower = l or 1/8
    self.upper = u or 1/3
    assert(self.lower <= self.upper and self.lower >= 0 and self.upper >= 0)
    self.noise = torch.Tensor()
    self.train = true
    self.inplace = ip or false
+   self.channelwise = cw or false
 end
 
 function RReLU:updateOutput(input)
@@ -21,6 +22,7 @@ function RReLU:updateOutput(input)
       self.upper,
       self.train,
       self.inplace,
+      self.channelwise,
       gen
    )
    return self.output
@@ -35,13 +37,14 @@ function RReLU:updateGradInput(input, gradOutput)
       self.lower,
       self.upper,
       self.train,
-      self.inplace
+      self.inplace,
+      self.channelwise
    )
    return self.gradInput
 end
 
 function RReLU:__tostring__()
-   return string.format('%s (l:%f, u:%f)', torch.type(self), self.lower, self.upper)
+   return string.format('%s (l:%f, u:%f, channel-wise:%s)', torch.type(self), self.lower, self.upper, self.channelwise)
 end
 
 function RReLU:clearState()
```
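
On the Lua side the visible change is the fourth constructor argument plus the extended `__tostring__`. A minimal usage sketch (the tensor sizes are illustrative, not from the commit):

```lua
require 'nn'

-- fourth argument switches on channel-wise slope sampling
local m = nn.RReLU(1/8, 1/3, false, true)
print(m)  -- nn.RReLU (l:0.125000, u:0.333333, channel-wise:true)

-- in training mode each channel draws one slope shared by all its elements
m.train = true
local y = m:forward(torch.randn(4, 3, 8, 8))  -- batch x channels x height x width
```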

doc/transfer.md (+1)

````diff
@@ -290,6 +290,7 @@ m=nn.RReLU(
    l,       -- minimum factor for negative inputs, default: 1/8;
    u,       -- maximum factor for negative inputs, default: 1/3;
    inplace  -- if true the result will be written to the input tensor, default: false;
+   cw       -- if true all elements of the same channel share the same `a`, default: false;
 )
 ```
 If `l == u` a RReLU effectively becomes a LeakyReLU. Regardless of operating in in-place mode a RReLU will internally allocate an input-sized `noise` tensor to store random factors for negative inputs. The backward() operation assumes that forward() has been called before.
````
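
To make the channel-sharing concrete: with `cw = true`, every element of a channel is multiplied by the same sampled `a`, so the output/input ratio is constant within a channel for negative inputs. A quick hedged check (out-of-place, training mode; behavior follows from the C implementation below):

```lua
local m = nn.RReLU(1/8, 1/3, false, true)
local x = -torch.rand(3, 5, 5)          -- 3 channels, all-negative so every element is scaled
local y = m:forward(x)

local ratio = torch.cdiv(y[1], x[1])    -- per-element slope of channel 1
print(ratio:max() - ratio:min())        -- ~0: one shared a per channel
print(m.noise:size())                   -- 3: one factor per channel (non-inplace path)
```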

lib/THNN/generic/RReLU.c (+201 -53)

```diff
@@ -11,68 +11,156 @@ void THNN_(RReLU_updateOutput)(
           real upper,
           bool train,
           bool inplace,
+          bool channelwise,
           THGenerator *generator)
 {
-  if (train)
+  if (channelwise && train)
   {
-    // get default random generator
-    THTensor_(resizeAs)(noise, input);
-    if (inplace)
+    long bs, ks;
+    long nOutputPlane;
     {
-      TH_TENSOR_APPLY2(real, input, real, noise,
-        if (*input_data <= 0)
-        {
-          const real r = (real)THRandom_uniform(generator, lower, upper);
-          *input_data = (*input_data) * r;
-          *noise_data = r;
-        }
-        else
-        {
-          *noise_data = 1;
-        }
-      );
-      THTensor_(set)(output, input);
+      long input_ndim = THTensor_(nDimension)(input);
+      switch (input_ndim)
+      {
+        case 1:
+          bs = 1;
+          ks = 1;
+          break;
+        case 2:
+          bs = input->size[0];
+          ks = 1;
+          break;
+        case 3:
+          bs = 1;
+          ks = input->size[1] * input->size[2];
+          break;
+        case 4:
+          bs = input->size[0];
+          ks = input->size[2] * input->size[3];
+          break;
+      }
+      nOutputPlane = input->size[(input_ndim + 1) % 2];
     }
+    // get default random generator
+    if (inplace)
+      THTensor_(resizeAs)(noise, input);
     else
+      THTensor_(resize1d)(noise, nOutputPlane);
+
+    real *output_data = NULL;
+    real *input_data = THTensor_(data)(input);
+    real *noise_data = THTensor_(data)(noise);
+    if (!inplace)
     {
       THTensor_(resizeAs)(output, input);
-      TH_TENSOR_APPLY3(real, input, real, output, real, noise,
-        if (*input_data <= 0)
-        {
-          const real r = (real)THRandom_uniform(generator, lower, upper);
-          *output_data = (*input_data) * r;
-          *noise_data = r;
-        }
+      output_data = THTensor_(data)(output);
+    }
+    THTensor *channel_noise = THTensor_(newWithSize1d)(nOutputPlane);
+    real *channel_noise_data = THTensor_(data)(channel_noise);
+
+    THIndex_t i, j, k;
+#pragma omp parallel for private(j)
+    for (j = 0; j < nOutputPlane; ++j)
+      channel_noise_data[j] = (real)THRandom_uniform(generator, lower, upper);
+#pragma omp parallel for private(j,k)
+    for (i = 0; i < bs; ++i)
+    {
+      real* n_input_data = input_data + i*nOutputPlane*ks;
+      real* n_output_data = NULL;
+      real* n_noise_data = NULL;
+      if (inplace)
+        n_noise_data = noise_data + i*nOutputPlane*ks;
+      else
+        n_output_data = output_data + i*nOutputPlane*ks;
+      for (j = 0; j < nOutputPlane; ++j)
+      {
+        const real r = channel_noise_data[j];
+        for (k = 0; k < ks; ++k)
+          if (inplace)
+            if (n_input_data[k] <= 0)
+            {
+              n_input_data[k] = r * n_input_data[k];
+              n_noise_data[k] = r;
+            }
+            else
+              n_noise_data[k] = 1;
+          else
+            n_output_data[k] = (n_input_data[k] > 0) ? n_input_data[k] : r * n_input_data[k];
+        n_input_data += ks;
+        if (inplace)
+          n_noise_data += ks;
         else
-        {
-          *output_data = *input_data;
-          *noise_data = 1;
-        }
-      );
+          n_output_data += ks;
+      }
     }
+    if (inplace)
+      THTensor_(set)(output, input);
+    else
+      THTensor_(set)(noise, channel_noise);
   }
   else
   {
-    const real negSlope = (lower + upper) / 2;
-    if (inplace)
+    if (train)
     {
-      TH_TENSOR_APPLY(real, input,
-        if (*input_data <= 0)
-        {
-          *input_data = *input_data * negSlope;
-        }
-      );
-      THTensor_(set)(output, input);
+      // get default random generator
+      THTensor_(resizeAs)(noise, input);
+      if (inplace)
+      {
+        TH_TENSOR_APPLY2(real, input, real, noise,
+          if (*input_data <= 0)
+          {
+            const real r = (real)THRandom_uniform(generator, lower, upper);
+            *input_data = (*input_data) * r;
+            *noise_data = r;
+          }
+          else
+          {
+            *noise_data = 1;
+          }
+        );
+        THTensor_(set)(output, input);
+      }
+      else
+      {
+        THTensor_(resizeAs)(output, input);
+        TH_TENSOR_APPLY3(real, input, real, output, real, noise,
+          if (*input_data <= 0)
+          {
+            const real r = (real)THRandom_uniform(generator, lower, upper);
+            *output_data = (*input_data) * r;
+            *noise_data = r;
+          }
+          else
+          {
+            *output_data = *input_data;
+            *noise_data = 1;
+          }
+        );
+      }
     }
     else
     {
-      THTensor_(resizeAs)(output, input);
-      TH_TENSOR_APPLY2(real, input, real, output,
-        const real r = (*input_data) <= 0 ? negSlope : 1;
-        *output_data = *input_data * r;
-      );
+      const real negSlope = (lower + upper) / 2;
+      if (inplace)
+      {
+        TH_TENSOR_APPLY(real, input,
+          if (*input_data <= 0)
+          {
+            *input_data = *input_data * negSlope;
+          }
+        );
+        THTensor_(set)(output, input);
+      }
+      else
+      {
+        THTensor_(resizeAs)(output, input);
+        TH_TENSOR_APPLY2(real, input, real, output,
+          const real r = (*input_data) <= 0 ? negSlope : 1;
+          *output_data = *input_data * r;
+        );
+      }
     }
-  }
+  }
 }
 
 void THNN_(RReLU_updateGradInput)(
@@ -84,24 +172,84 @@ void THNN_(RReLU_updateGradInput)(
           real lower,
           real upper,
           bool train,
-          bool inplace)
+          bool inplace,
+          bool channelwise)
 {
   if (train && upper - lower > 1E-6) // e.g. if upper == lower, RReLU behaves like LeakyReLU
   {
-    // multiply the gradient by the noise tensor
-    if (inplace)
+    if (channelwise && !inplace)
     {
-      THTensor_(cmul)(gradOutput, gradOutput, noise);
-      THTensor_(set)(gradInput, gradOutput);
+      long bs, ks;
+      long nOutputPlane;
+      {
+        long input_ndim = THTensor_(nDimension)(input);
+        switch (input_ndim)
+        {
+          case 1:
+            bs = 1;
+            ks = 1;
+            break;
+          case 2:
+            bs = input->size[0];
+            ks = 1;
+            break;
+          case 3:
+            bs = 1;
+            ks = input->size[1] * input->size[2];
+            break;
+          case 4:
+            bs = input->size[0];
+            ks = input->size[2] * input->size[3];
+            break;
+        }
+        nOutputPlane = input->size[(input_ndim + 1) % 2];
+      }
+
+      const real *input_data = THTensor_(data)(input);
+      const real *gradOutput_data = THTensor_(data)(gradOutput);
+      THTensor_(resizeAs)(gradInput, input);
+      real *gradInput_data = THTensor_(data)(gradInput);
+      const real *noise_data = THTensor_(data)(noise);
+
+      THIndex_t i, j, k;
+#pragma omp parallel for private(j,k)
+      for (i = 0; i < bs; ++i)
+      {
+        const real *n_input_data = input_data + i*nOutputPlane*ks;
+        const real *n_gradOutput_data = gradOutput_data + i*nOutputPlane*ks;
+        real *n_gradInput_data = gradInput_data + i*nOutputPlane*ks;
+
+        for (j = 0; j < nOutputPlane; ++j)
+        {
+          const real r = noise_data[j];
+          for (k = 0; k < ks; ++k)
+            if (n_input_data[k] > 0)
+              n_gradInput_data[k] = n_gradOutput_data[k];
+            else
+              n_gradInput_data[k] = n_gradOutput_data[k] * r;
+          n_input_data += ks;
+          n_gradInput_data += ks;
+          n_gradOutput_data += ks;
+        }
+      }
     }
     else
     {
-      THTensor_(resizeAs)(gradInput, input);
-      THTensor_(cmul)(gradInput, gradOutput, noise);
-    }
+      // multiply the gradient by the noise tensor
+      if (inplace)
+      {
+        THTensor_(cmul)(gradOutput, gradOutput, noise);
+        THTensor_(set)(gradInput, gradOutput);
+      }
+      else
+      {
+        THTensor_(resizeAs)(gradInput, input);
+        THTensor_(cmul)(gradInput, gradOutput, noise);
+      }
+    }
   }
   else
-  {
+  {
     // use constant factor for negative input values
     const real negSlope = (lower + upper) / 2;
     if (inplace)
```
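
Both kernels open with the same shape bookkeeping: the input is viewed as `bs` samples of `nOutputPlane` channels with `ks` elements each, where the channel dimension is index 0 for 1D/3D inputs and index 1 for 2D/4D inputs (hence `size[(input_ndim + 1) % 2]`). A hedged Lua sketch of the equivalent logic, using a hypothetical helper and 1-based dimensions:

```lua
-- mirrors the C switch: bs = batch count, ks = spatial elements per channel
local function shapeInfo(input)
   local ndim = input:dim()
   local bs, ks = 1, 1
   if ndim == 2 then
      bs = input:size(1)
   elseif ndim == 3 then
      ks = input:size(2) * input:size(3)
   elseif ndim == 4 then
      bs = input:size(1)
      ks = input:size(3) * input:size(4)
   end
   -- channels: dim 1 for 1D/3D input, dim 2 for 2D/4D input
   local nOutputPlane = input:size(ndim % 2 == 1 and 1 or 2)
   return bs, ks, nOutputPlane
end

print(shapeInfo(torch.randn(4, 3, 8, 8)))  -- 4  64  3
```

Note that the backward kernel only takes the channel-wise path when `!inplace`: in-place training stores an element-wise noise tensor, so the gradient falls through to the ordinary `cmul` branch.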

lib/THNN/generic/THNN.h (+3 -1)

```diff
@@ -291,6 +291,7 @@ TH_API void THNN_(RReLU_updateOutput)(
           real upper,
           bool train,
           bool inplace,
+          bool channelwise,
           THGenerator *generator);
 TH_API void THNN_(RReLU_updateGradInput)(
           THNNState *state,
@@ -301,7 +302,8 @@ TH_API void THNN_(RReLU_updateGradInput)(
           real lower,
           real upper,
           bool train,
-          bool inplace);
+          bool inplace,
+          bool channelwise);
 
 TH_API void THNN_(Sigmoid_updateOutput)(
           THNNState *state,
```
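
The Lua module passes its arguments positionally, so the call site in RReLU.lua must stay in sync with these prototypes. Roughly, assuming the standard THNN Lua bindings (the `:cdata()` plumbing is the usual convention, not shown in this diff):

```lua
input.THNN.RReLU_updateOutput(
   input:cdata(),
   self.output:cdata(),
   self.noise:cdata(),
   self.lower,
   self.upper,
   self.train,
   self.inplace,
   self.channelwise,  -- new flag, inserted before the generator argument
   gen                -- THGenerator handle
)
```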
