File tree 4 files changed +10
-9
lines changed
4 files changed +10
-9
lines changed Original file line number Diff line number Diff line change 16
16
17
17
function CAddTable :updateGradInput (input , gradOutput )
18
18
for i = 1 ,# input do
19
- self .gradInput [i ] = self .gradInput [i ] or torch . Tensor ()
19
+ self .gradInput [i ] = self .gradInput [i ] or input [ 1 ]. new ()
20
20
self .gradInput [i ]:resizeAs (input [i ])
21
21
self .gradInput [i ]:copy (gradOutput )
22
22
end
Original file line number Diff line number Diff line change @@ -13,8 +13,8 @@ function CDivTable:updateOutput(input)
13
13
end
14
14
15
15
function CDivTable :updateGradInput (input , gradOutput )
16
- self .gradInput [1 ] = self .gradInput [1 ] or torch . Tensor ()
17
- self .gradInput [2 ] = self .gradInput [2 ] or torch . Tensor ()
16
+ self .gradInput [1 ] = self .gradInput [1 ] or input [ 1 ]. new ()
17
+ self .gradInput [2 ] = self .gradInput [2 ] or input [ 1 ]. new ()
18
18
self .gradInput [1 ]:resizeAs (input [1 ]):copy (gradOutput ):cdiv (input [2 ])
19
19
self .gradInput [2 ]:resizeAs (input [2 ]):zero ():addcdiv (- 1 ,self .gradInput [1 ],input [2 ]):cmul (input [1 ])
20
20
return self .gradInput
Original file line number Diff line number Diff line change @@ -15,12 +15,13 @@ function CMulTable:updateOutput(input)
15
15
end
16
16
17
17
function CMulTable :updateGradInput (input , gradOutput )
18
- local tout = torch .Tensor ():resizeAs (self .output )
18
+ self .tout = self .tout or input [1 ].new ()
19
+ self .tout :resizeAs (self .output )
19
20
for i = 1 ,# input do
20
- self .gradInput [i ] = self .gradInput [i ] or torch . Tensor ()
21
+ self .gradInput [i ] = self .gradInput [i ] or input [ 1 ]. new ()
21
22
self .gradInput [i ]:resizeAs (input [i ]):copy (gradOutput )
22
- tout :copy (self .output ):cdiv (input [i ])
23
- self .gradInput [i ]:cmul (tout )
23
+ self . tout :copy (self .output ):cdiv (input [i ])
24
+ self .gradInput [i ]:cmul (self . tout )
24
25
end
25
26
return self .gradInput
26
27
end
Original file line number Diff line number Diff line change @@ -13,8 +13,8 @@ function CSubTable:updateOutput(input)
13
13
end
14
14
15
15
function CSubTable :updateGradInput (input , gradOutput )
16
- self .gradInput [1 ] = self .gradInput [1 ] or torch . Tensor ()
17
- self .gradInput [2 ] = self .gradInput [2 ] or torch . Tensor ()
16
+ self .gradInput [1 ] = self .gradInput [1 ] or input [ 1 ]. new ()
17
+ self .gradInput [2 ] = self .gradInput [2 ] or input [ 1 ]. new ()
18
18
self .gradInput [1 ]:resizeAs (input [1 ]):copy (gradOutput )
19
19
self .gradInput [2 ]:resizeAs (input [1 ]):copy (gradOutput ):mul (- 1 )
20
20
return self .gradInput
You can’t perform that action at this time.
0 commit comments