diff --git a/LenSoftMax.lua b/LenSoftMax.lua new file mode 100644 index 000000000..c54a369ef --- /dev/null +++ b/LenSoftMax.lua @@ -0,0 +1,32 @@ +local LenSoftMax, parent = torch.class('nn.LenSoftMax', 'nn.Module') + +function LenSoftMax:__init() + parent.__init(self) + self.gradInput = {torch.Tensor()} +end + +function LenSoftMax:updateOutput(input) + local _input, _len = unpack(input) + _input.THNN.LenSoftMax_updateOutput( + _input:cdata(), + self.output:cdata(), + _len:cdata() + ) + return self.output +end + +function LenSoftMax:updateGradInput(input, gradOutput) + local _input, _len = unpack(input) + _input.THNN.LenSoftMax_updateGradInput( + _input:cdata(), + gradOutput:cdata(), + self.gradInput[1]:cdata(), + self.output:cdata(), + _len:cdata() + ) + if not self.gradInput[2] then + self.gradInput[2] = _len.new() + end + self.gradInput[2]:resizeAs(_len):zero() + return self.gradInput +end diff --git a/init.lua b/init.lua index 21ac7897a..f826e7e36 100755 --- a/init.lua +++ b/init.lua @@ -90,6 +90,7 @@ require('nn.LogSigmoid') require('nn.LogSoftMax') require('nn.Sigmoid') require('nn.SoftMax') +require('nn.LenSoftMax') require('nn.SoftMin') require('nn.SoftPlus') require('nn.SoftSign') diff --git a/lib/THNN/generic/LenSoftMax.c b/lib/THNN/generic/LenSoftMax.c new file mode 100644 index 000000000..5054319b3 --- /dev/null +++ b/lib/THNN/generic/LenSoftMax.c @@ -0,0 +1,116 @@ +#ifndef TH_GENERIC_FILE +#define TH_GENERIC_FILE "generic/LenSoftMax.c" +#else + +void THNN_(LenSoftMax_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + THIndexTensor *len) +{ + if ((input->nDimension != 2) && (len->nDimension != 1)) + { + THArgCheck(0, 2, "2D tensor expected for input, 1D tensor expected for len"); + } + + real *input_data, *output_data; + THIndex_t *len_data; + ptrdiff_t nframe = input->size[0], dim = input->size[1]; + ptrdiff_t t; + + input = THTensor_(newContiguous)(input); + THTensor_(resizeAs)(output, input); + + input_data = THTensor_(data)(input); + output_data = THTensor_(data)(output); + len_data = THIndexTensor_(data)(len); + +#pragma omp parallel for private(t) + for (t = 0; t < nframe; t++) + { + real *input_ptr = input_data + t*dim; + real *output_ptr = output_data + t*dim; + + real inputMax = -THInf; + accreal sum; + + ptrdiff_t d, ld = (ptrdiff_t)len_data[t]; + for (d = 0; d < ld; d++) + { + if (input_ptr[d] >= inputMax) inputMax = input_ptr[d]; + } + + sum = 0; + for (d = 0; d < ld; d++) + { + real z = exp(input_ptr[d] - inputMax); + output_ptr[d] = z; + sum += z; + } + for (d = ld; d < dim; d++) + { + output_ptr[d] = 0; + } + + for (d = 0; d < ld; d++) + { + output_ptr[d] *= 1/sum; + } + } + + THTensor_(free)(input); +} + +void THNN_(LenSoftMax_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + THTensor *output, + THIndexTensor *len) +{ + THNN_CHECK_SHAPE(input, gradOutput); + + if ((output->nDimension != 2) && (len->nDimension != 1)) + { + THError("2D tensor expected for input, 1D tensor expected for len"); + } + + real *gradInput_data, *gradOutput_data, *output_data; + THIndex_t *len_data; + ptrdiff_t nframe = output->size[0], dim = output->size[1]; + ptrdiff_t t; + + gradOutput = THTensor_(newContiguous)(gradOutput); + output = THTensor_(newContiguous)(output); + + THTensor_(resizeAs)(gradInput, output); + gradInput_data = THTensor_(data)(gradInput); + output_data = THTensor_(data)(output); + gradOutput_data = THTensor_(data)(gradOutput); + len_data = THIndexTensor_(data)(len); + +#pragma omp parallel for private(t) + for (t = 0; t < nframe; t++) + { + real *gradInput_ptr = gradInput_data + t*dim; + real *output_ptr = output_data + t*dim; + real *gradOutput_ptr = gradOutput_data + t*dim; + + ptrdiff_t d, ld = (ptrdiff_t)len_data[t]; + accreal sum = 0; + for (d = 0; d < ld; d++) + sum += (accreal)gradOutput_ptr[d] * output_ptr[d]; + + for (d = 0; d < ld; d++) + gradInput_ptr[d] = output_ptr[d] * (gradOutput_ptr[d] - sum); + + for (d = ld; d < dim; d++) + gradInput_ptr[d] = 0; + } + + THTensor_(free)(gradOutput); + THTensor_(free)(output); +} + +#endif \ No newline at end of file diff --git a/lib/THNN/generic/THNN.h b/lib/THNN/generic/THNN.h index 2c4aabfe8..219deb545 100644 --- a/lib/THNN/generic/THNN.h +++ b/lib/THNN/generic/THNN.h @@ -431,6 +431,19 @@ TH_API void THNN_(SoftMax_updateGradInput)( THTensor *gradInput, THTensor *output); +TH_API void THNN_(LenSoftMax_updateOutput)( + THNNState *state, + THTensor *input, + THTensor *output, + THIndexTensor *len); +TH_API void THNN_(LenSoftMax_updateGradInput)( + THNNState *state, + THTensor *input, + THTensor *gradOutput, + THTensor *gradInput, + THTensor *output, + THIndexTensor *len); + TH_API void THNN_(SoftPlus_updateOutput)( THNNState *state, THTensor *input, diff --git a/lib/THNN/init.c b/lib/THNN/init.c index cd5ddb9ce..1f2095a19 100644 --- a/lib/THNN/init.c +++ b/lib/THNN/init.c @@ -143,6 +143,9 @@ #include "generic/SoftMax.c" #include "THGenerateFloatTypes.h" +#include "generic/LenSoftMax.c" +#include "THGenerateFloatTypes.h" + #include "generic/SoftPlus.c" #include "THGenerateFloatTypes.h"