diff --git a/LookupTable.lua b/LookupTable.lua
index a999e7e8e..36341f259 100644
--- a/LookupTable.lua
+++ b/LookupTable.lua
@@ -32,6 +32,13 @@ function LookupTable:setPadding(paddingValue)
    return self
 end
 
+function LookupTable:zeroPaddingWeight()
+   if self.paddingValue > 0 then
+      self.weight:select(1, self.paddingValue):fill(0)
+   end
+   return self
+end
+
 function LookupTable:scaleGradByFreq()
    self.shouldScaleGradByFreq = true
    return self
diff --git a/doc/convolution.md b/doc/convolution.md
index b7b92ac8c..0e0f407c6 100644
--- a/doc/convolution.md
+++ b/doc/convolution.md
@@ -263,6 +263,46 @@ Outputs something like:
 [torch.DoubleTensor of dimension 2x4x3]
 ```
 
+`LookupTable` internally maintains a `weight` tensor and a `gradWeight` tensor of the same size.
+To keep the sub-tensor `gradWeight[n]` unchanged during a `backward()` call, pass index `n` as `paddingValue`.
+Example:
+```lua
+-- a lookup table containing 10 tensors of size 3
+-- index 2 is interpreted as the padding value
+paddingValue = 2
+module = nn.LookupTable(10, 3, paddingValue)
+gradWeight = module.gradWeight
+input = torch.LongTensor{1,2,1,10} -- input size is 4
+
+-- before
+print(gradWeight[paddingValue])
+print(gradWeight[1])
+
+-- back-propagation
+gradOutput = torch.ones(4, 3) -- same size as the module's output
+module:backward(input, gradOutput)
+
+-- after
+print(gradWeight[paddingValue]) -- unchanged
+print(gradWeight[1]) -- changed, as input[1] = 1 and input[3] = 1
+```
+The default `paddingValue` is `0`, in which case the whole `gradWeight` is updated on every `backward()` call.
+
+The method `zeroPaddingWeight()` forces the sub-tensor `weight[n]` to all zeros for index `n = paddingValue`, which also affects the `forward()` call.
+Example:
+```lua
+paddingValue = 2
+module = nn.LookupTable(10, 3, paddingValue):zeroPaddingWeight()
+weight = module.weight
+print(weight[paddingValue]) -- all zeros
+
+input = torch.LongTensor{1,1,2,10} -- input[3] = 2, the paddingValue
+output = module:forward(input)
+print(output) -- output[3] is all zeros
+```
+
+Remark: `paddingValue` and `zeroPaddingWeight()` are useful for NLP tasks where the user wants a placeholder embedding for the "unknown" (a.k.a. out-of-vocabulary) token.
+
 ## Spatial Modules ##
 
 Excluding an optional batch dimension, spatial layers expect a 3D Tensor as input. The
diff --git a/test.lua b/test.lua
index 91d99c9f6..45fda029d 100644
--- a/test.lua
+++ b/test.lua
@@ -4179,6 +4179,23 @@ function nntest.LookupTable()
    end
    local err = padw_sum - padw:sum()
    mytester:assertlt(err,precision, 'padding update error ')
+   -- test that the padding weight stays at zero through parameter updates
+   local paddingValue = math.random(totalIndex)
+   local module = nn.LookupTable(totalIndex, entry_size, paddingValue):zeroPaddingWeight()
+   local padw = module.weight:select(1,paddingValue)
+   local input = torch.IntTensor(nIndex)
+   for i = 1, 100 do
+      input:apply(
+         function() -- randomly set about half of the input to padding
+            if torch.random(2) == 1 then return paddingValue end
+            return torch.random(totalIndex)
+         end)
+      local y = module:updateOutput(input)
+      module:updateGradInput(input, y)
+      module:accUpdateGradParameters(input, y, 0.1)
+   end
+   local err = padw:sum()
+   mytester:assertlt(err,precision, 'padding weight zeroing error ')
 end
 
 function nntest.AddConstant()
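
For anyone trying the patch end-to-end, here is a minimal batch-mode sketch of the out-of-vocabulary scenario from the doc remark. It assumes `nn` with this patch applied; the vocabulary size, embedding size, pad index, and token indices below are invented for illustration:

```lua
require 'nn'

local vocabSize, embedSize, padIdx = 100, 8, 1
local lut = nn.LookupTable(vocabSize, embedSize, padIdx):zeroPaddingWeight()

-- two "sentences" of different lengths, right-padded with padIdx
-- so they fit in a single 2x5 batch tensor
local batch = torch.LongTensor{{5, 23, 42, padIdx, padIdx},
                               {7, 99, 12, 64, 3}}

local output = lut:forward(batch) -- 2x5x8
lut:zeroGradParameters()
lut:backward(batch, torch.ones(2, 5, embedSize))

print(output[1][4]:sum())           -- 0: padded positions embed to zeros
print(lut.gradWeight[padIdx]:sum()) -- 0: the pad row received no gradient
```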