Table input for nn.Bottle #1183

Open · wants to merge 1 commit into master
118 changes: 92 additions & 26 deletions Bottle.lua
@@ -13,50 +13,115 @@ function Bottle:__init(module, nInputDim, nOutputDim)
self.modules[1] = module
end

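-- Returns the dimensionality of the input; for table input, the first element is used as reference.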
function Bottle:dim(input)
if torch.type(input) == 'table' then
return input[1]:dim()
else
return input:dim()
end
end

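-- Returns the input size as a LongTensor, or a nested table of LongTensors for table input.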
function Bottle:size(input)
if torch.type(input) == 'table' then
local size = {}
for i=1,#input do
table.insert(size, self:size(input[i]))
end
return size
else
return torch.LongTensor(input:size())
end
end

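-- Computes the bottled shape: the first (batchDims - 1) dimensions are folded into the batch dimension. Returns the bottled shape, the original size and the fold factor (squeezeSize); applied per element for table input.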
function Bottle:bottle(input, batchDims)
if torch.type(input) == 'table' then
local inShape = {}
local shape, inSize, squeezeSize
for i=1,#input do
shape, inSize, squeezeSize = self:bottle(input[i], batchDims)
table.insert(inShape, shape)
end
return inShape, inSize, squeezeSize
else
local inSize = torch.LongTensor(input:size())
local squeezeSize = inSize[{{1, batchDims - 1}}]:prod()
local inShape = inSize[{{batchDims, input:dim()}}]:clone()
inShape[{{1}}]:mul(squeezeSize)
-- Forward with the module's dimension
return inShape, inSize, squeezeSize
end
end

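-- Reshapes the input (or each tensor of a table) to the given bottled shape via view.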
function Bottle:squeeze(input, inShape)
if torch.type(input) == 'table' then
local input_ = {}
for i=1,#input do
input_[i] = self:squeeze(input[i], inShape[i])
end
return input_
else
return input:view(unpack(inShape:totable()))
end
end

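-- Restores the leading batch dimensions on the module output and returns it together with its (bottled) output shape; applied per element for table output.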
function Bottle:unbottle(output, batchDims, inSize, squeezeSize)
if torch.type(output) == 'table' then
local outShape = {}
for i=1,#output do
local output_, outShape_ = self:unbottle(output[i], batchDims, inSize, squeezeSize)
table.insert(outShape, outShape_)
output[i] = output_
end
return output, outShape
else
local outShape = torch.LongTensor(output:size())
if math.abs(self.dimDelta) > 0 then
inSize = inSize[{{1, inSize:size(1) - self.dimDelta}}]:clone()
end
inSize[{{batchDims, inSize:size(1)}}]:copy(outShape)
inSize[{{batchDims}}]:div(squeezeSize)
output = output:view(unpack(torch.totable(inSize)))
return output, outShape
end
end

function Bottle:updateOutput(input)
-- first batchDims dimensions will be fused
- local batchDims = input:dim() - self.nInputDim + 1
+ local batchDims = self:dim(input) - self.nInputDim + 1
-- see if bottle is required
if batchDims > 1 then
-- bottle the first dims
- local inSize = torch.LongTensor(input:size())
- local squeezeSize = inSize[{{1, batchDims - 1}}]:prod()
- self.inShape:copy(inSize[{{batchDims, input:dim()}}])
- self.inShape[{{1}}]:mul(squeezeSize)
- -- Forward with the module's dimension
- local newInput = input:view(unpack(self.inShape:totable()))
+ local inShape, inSize, squeezeSize = self:bottle(input, batchDims)
+ self.inShape = inShape
+ local newInput = self:squeeze(input, inShape)
local output = self.modules[1]:updateOutput(newInput)
- assert(output:dim() == self.nOutputDim,
- "Wrong number of output dims on module. Expected: " ..
- self.nOutputDim .. ' but got ' ..
- tostring(output and output:dim()))
- self.outShape:copy(torch.LongTensor(output:size()))
- if math.abs(self.dimDelta) > 0 then
- inSize:resize(inSize:size(1) - self.dimDelta)
- end
- inSize[{{batchDims, inSize:size(1)}}]:copy(self.outShape)
- inSize[{{batchDims}}]:div(squeezeSize)
+ assert(self:dim(output) == self.nOutputDim,
+ "Wrong number of output dims on module. Expected: " ..
+ self.nOutputDim .. ' but got ' ..
+ tostring(output and self:dim(output)))

-- unbottle
- self.output:set(output:view(unpack(torch.totable(inSize))))
+ self.output, self.outShape = self:unbottle(output, batchDims, inSize, squeezeSize)
else
- self.output:set(self.modules[1]:updateOutput(input))
+ self.output = self.modules[1]:updateOutput(input)
end
return self.output
end

function Bottle:updateGradInput(input, gradOutput)
- if input:dim() > self.nInputDim then
- local input_ = input:view(unpack(self.inShape:totable()))
- local gradOutput_ = gradOutput:view(unpack(self.outShape:totable()))
+ if self:dim(input) > self.nInputDim then
+ local input_ = self:squeeze(input, self.inShape)
+ local gradOutput_ = self:squeeze(gradOutput, self.outShape)
self.modules[1]:updateGradInput(input_, gradOutput_)
if self.modules[1].gradInput then
- self.gradInput:set(self.modules[1].gradInput:viewAs(input))
+ self.gradInput = self:squeeze(self.modules[1].gradInput, self:size(input))
else
self.gradInput = nil
end
else
if self.modules[1].gradInput then
- self.gradInput:set(self.modules[1]:updateOutput(input))
+ self.gradInput = self.modules[1]:updateGradInput(input)
else
self.gradInput = nil
end
@@ -65,9 +130,10 @@ function Bottle:updateGradInput(input, gradOutput)
end

function Bottle:accGradParameters(input, gradOutput, scale)
- if input:dim() > self.nInputDim then
- input = input:view(unpack(self.inShape:totable()))
- gradOutput = gradOutput:view(unpack(self.outShape:totable()))
+ if self:dim(input) > self.nInputDim then
+ input = self:squeeze(input, self.inShape)
+ gradOutput = self:squeeze(gradOutput, self.outShape)
end
self.modules[1]:accGradParameters(input, gradOutput, scale)
end
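
For reference, a minimal usage sketch of what this change enables (not part of the PR; the wrapped module and tensor sizes are illustrative assumptions): with table input supported, nn.Bottle can wrap a module that consumes a table of tensors, such as nn.CMulTable, and the new bottle/squeeze/unbottle helpers fold the leading batch dimensions of every tensor in the table before forwarding and restore them afterwards.

require 'nn'

-- Wrap a table-input module; the wrapped module expects 2D inputs.
local m = nn.Bottle(nn.CMulTable(), 2)
-- Two 4x5x10 inputs: the leading 4x5 batch dims are fused into 20,
-- CMulTable sees two 20x10 tensors, and its 20x10 result is
-- unbottled back to 4x5x10.
local a = torch.randn(4, 5, 10)
local b = torch.randn(4, 5, 10)
local out = m:forward({a, b})
print(out:size())   -- expected 4x5x10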