LinearTensorD3.lua
-- LinearTensorD3: applies a Linear layer to every time step of a 3D tensor,
-- mapping [bz x L x xDim] --> [bz x L x oDim].
require 'nn'

local LinearTensorD3, parent = torch.class('nn.LinearTensorD3', 'nn.Linear')

function LinearTensorD3:__init(inputSize, outputSize)
   parent.__init(self, inputSize, outputSize)
end
function LinearTensorD3:updateOutput(input)
   -- input:  bz x L x xDim
   -- output: bz x L x oDim
   assert(input:dim() == 3, 'LinearTensorD3 expects a 3D input (bz x L x xDim)')
   local bz, L, xDim = unpack(input:size():totable())
   -- fold the batch and sequence dimensions together, apply the plain Linear,
   -- then reshape the result back to 3D
   local inputView = input:view(bz * L, xDim)
   parent.updateOutput(self, inputView)   -- writes into self.output (bz*L x oDim)
   self.output:resize(bz, L, self.bias:size(1))
   return self.output
end
function LinearTensorD3:updateGradInput(input, gradOutput)
   assert(gradOutput:dim() == 3, 'LinearTensorD3 expects a 3D gradOutput (bz x L x oDim)')
   local bz, L, xDim = unpack(input:size():totable())
   local inputView = input:view(bz * L, xDim)
   local gradOutputView = gradOutput:view(bz * L, self.bias:size(1))
   parent.updateGradInput(self, inputView, gradOutputView)   -- writes into self.gradInput (bz*L x xDim)
   self.gradInput:resize(bz, L, xDim)
   return self.gradInput
end
function LinearTensorD3:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   local bz, L, xDim = unpack(input:size():totable())
   -- the weight and bias are shared across time steps, so accumulating over the
   -- flattened 2D views is equivalent to accumulating over the 3D tensors
   local inputView = input:view(bz * L, xDim)
   local gradOutputView = gradOutput:view(bz * L, self.bias:size(1))
   parent.accGradParameters(self, inputView, gradOutputView, scale)
end