Skip to content

Commit 1d5b59b

Browse files
Fixed Tensor alloc for some modules (For CUDA)
1 parent b8addac commit 1d5b59b

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

CAddTable.lua

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616

1717
function CAddTable:updateGradInput(input, gradOutput)
1818
for i=1,#input do
19-
self.gradInput[i] = self.gradInput[i] or torch.Tensor()
19+
self.gradInput[i] = self.gradInput[i] or input[1].new()
2020
self.gradInput[i]:resizeAs(input[i])
2121
self.gradInput[i]:copy(gradOutput)
2222
end

CDivTable.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ function CDivTable:updateOutput(input)
1313
end
1414

1515
function CDivTable:updateGradInput(input, gradOutput)
16-
self.gradInput[1] = self.gradInput[1] or torch.Tensor()
17-
self.gradInput[2] = self.gradInput[2] or torch.Tensor()
16+
self.gradInput[1] = self.gradInput[1] or input[1].new()
17+
self.gradInput[2] = self.gradInput[2] or input[1].new()
1818
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput):cdiv(input[2])
1919
self.gradInput[2]:resizeAs(input[2]):zero():addcdiv(-1,self.gradInput[1],input[2]):cmul(input[1])
2020
return self.gradInput

CMulTable.lua

+5-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ function CMulTable:updateOutput(input)
1515
end
1616

1717
function CMulTable:updateGradInput(input, gradOutput)
18-
local tout = torch.Tensor():resizeAs(self.output)
18+
self.tout = self.tout or input[1].new()
19+
self.tout:resizeAs(self.output)
1920
for i=1,#input do
20-
self.gradInput[i] = self.gradInput[i] or torch.Tensor()
21+
self.gradInput[i] = self.gradInput[i] or input[1].new()
2122
self.gradInput[i]:resizeAs(input[i]):copy(gradOutput)
22-
tout:copy(self.output):cdiv(input[i])
23-
self.gradInput[i]:cmul(tout)
23+
self.tout:copy(self.output):cdiv(input[i])
24+
self.gradInput[i]:cmul(self.tout)
2425
end
2526
return self.gradInput
2627
end

CSubTable.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ function CSubTable:updateOutput(input)
1313
end
1414

1515
function CSubTable:updateGradInput(input, gradOutput)
16-
self.gradInput[1] = self.gradInput[1] or torch.Tensor()
17-
self.gradInput[2] = self.gradInput[2] or torch.Tensor()
16+
self.gradInput[1] = self.gradInput[1] or input[1].new()
17+
self.gradInput[2] = self.gradInput[2] or input[1].new()
1818
self.gradInput[1]:resizeAs(input[1]):copy(gradOutput)
1919
self.gradInput[2]:resizeAs(input[1]):copy(gradOutput):mul(-1)
2020
return self.gradInput

0 commit comments

Comments
 (0)