-
Notifications
You must be signed in to change notification settings - Fork 0
/
RReLU.lua
50 lines (45 loc) · 1.17 KB
/
RReLU.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
-- FFI library; used in updateOutput to cast torch._gen to a THGenerator pointer.
local ffi = require 'ffi'
-- Declare the class 'nn.RReLU' with 'nn.Module' named as its parent class.
-- `parent` is the second value returned by torch.class; the methods below
-- call parent.__init and parent.clearState through it.
local RReLU, parent = torch.class('nn.RReLU', 'nn.Module')
--- Construct an RReLU module.
-- @param l  lower bound, stored as self.lower (defaults to 1/8 when nil/false)
-- @param u  upper bound, stored as self.upper (defaults to 1/3 when nil/false)
-- @param ip in-place flag, stored as self.inplace (defaults to false)
-- Raises an assertion error unless 0 <= lower <= upper.
-- NOTE(review): the bounds are presumably the range of the randomized
-- negative slope; the actual semantics live in the THNN backend.
function RReLU:__init(l, u, ip)
   parent.__init(self)

   local lo = l or 1/8
   local hi = u or 1/3
   self.lower = lo
   self.upper = hi

   -- Bounds must be ordered and non-negative.
   local boundsOk = lo <= hi and lo >= 0 and hi >= 0
   assert(boundsOk)

   -- Empty tensor handed to the backend as the noise buffer.
   self.noise = torch.Tensor()
   self.train = true
   self.inplace = ip or false
end
--- Forward pass: delegates to the THNN backend reachable through `input`.
-- The backend call receives the input, output and noise buffers, the
-- lower/upper bounds, the train and inplace flags, and the torch generator.
-- @param input tensor exposing a `THNN` table and a `cdata()` method
-- @return self.output
function RReLU:updateOutput(input)
   -- Reinterpret torch._gen as THGenerator** and dereference it once.
   local generator = ffi.typeof('THGenerator**')(torch._gen)[0]
   local backend = input.THNN

   backend.RReLU_updateOutput(
      input:cdata(),
      self.output:cdata(),
      self.noise:cdata(),
      self.lower,
      self.upper,
      self.train,
      self.inplace,
      generator
   )

   return self.output
end
--- Backward pass: delegates to the THNN backend reachable through `input`.
-- The backend call receives the input, gradOutput, gradInput and noise
-- buffers together with the lower/upper bounds and the train/inplace flags.
-- @param input      tensor exposing a `THNN` table and a `cdata()` method
-- @param gradOutput tensor exposing a `cdata()` method
-- @return self.gradInput
function RReLU:updateGradInput(input, gradOutput)
   local backend = input.THNN

   backend.RReLU_updateGradInput(
      input:cdata(),
      gradOutput:cdata(),
      self.gradInput:cdata(),
      self.noise:cdata(),
      self.lower,
      self.upper,
      self.train,
      self.inplace
   )

   return self.gradInput
end
--- Human-readable description: the torch type name followed by both bounds.
-- @return string of the form "<type> (l:<lower>, u:<upper>)"
function RReLU:__tostring__()
   local template = '%s (l:%f, u:%f)'
   return template:format(torch.type(self), self.lower, self.upper)
end
--- Reset the noise buffer (when present), then defer to the parent class.
-- @return whatever parent.clearState(self) returns
function RReLU:clearState()
   local noise = self.noise
   if noise then
      -- set() with no arguments is called to reset the tensor.
      noise:set()
   end
   return parent.clearState(self)
end