diff --git a/lib/THCUNN/LenSoftMax.cu b/lib/THCUNN/LenSoftMax.cu
new file mode 100644
index 00000000..07f3b6f5
--- /dev/null
+++ b/lib/THCUNN/LenSoftMax.cu
@@ -0,0 +1,121 @@
+#include "THCUNN.h"
+#include "common.h"
+#include "THCHalf.h"
+#include "THCHalfAutoNumerics.cuh"
+
+#define LENSOFTMAX_THREADS 128
+
+// Length-masked softmax forward kernel.
+// One block per row of the (nframe x dim) input; each row is normalized
+// over its first len[row] entries only (the host wrapper zeroes the output
+// before launch, so entries at or past len[row] remain 0).
+// Shared memory: LENSOFTMAX_THREADS+1 accumulators; the extra slot holds
+// the block-wide max / sum between reduction phases.
+template <typename T, typename AccumT, typename IndexT>
+__global__ void cunn_LenSoftMax_updateOutput_kernel(
+  T *output, T *input, int nframe, int dim, IndexT *len)
+{
+  __shared__ AccumT buffer[LENSOFTMAX_THREADS+1];
+  T *input_k = input + blockIdx.x*dim + blockIdx.y + blockIdx.z;
+  T *output_k = output + blockIdx.x*dim + blockIdx.y + blockIdx.z;
+
+  int i_start = threadIdx.x;
+  int i_end = ScalarConvert<IndexT, int>::to(len[blockIdx.x]);
+  int i_step = blockDim.x;
+
+  // per-thread max over this row's valid prefix
+  buffer[threadIdx.x] = -THCNumerics<AccumT>::max();
+  for (int i=i_start; i<i_end; i+=i_step)
+  {
+    T z = input_k[i];
+    AccumT zAcc = ScalarConvert<T, AccumT>::to(z);
+    if (buffer[threadIdx.x] < zAcc)
+      buffer[threadIdx.x] = zAcc;
+  }
+
+  __syncthreads();
+
+  // reduce to the row max (single-thread reduction over the block's partials)
+  if (threadIdx.x == 0)
+  {
+    AccumT max_k = -THCNumerics<AccumT>::max();
+    for (int i=0; i<blockDim.x; i++)
+    {
+      if (max_k < buffer[i])
+        max_k = buffer[i];
+    }
+    buffer[LENSOFTMAX_THREADS] = max_k;
+  }
+
+  __syncthreads();
+
+  // exp(x - max) for numerical stability, accumulating per-thread partial sums
+  T max_k = ScalarConvert<AccumT, T>::to(buffer[LENSOFTMAX_THREADS]);
+  buffer[threadIdx.x] = ScalarConvert<int, AccumT>::to(0);
+  for (int i=i_start; i<i_end; i+=i_step)
+  {
+    T z = THCNumerics<T>::exp(input_k[i]-max_k);
+    buffer[threadIdx.x] += ScalarConvert<T, AccumT>::to(z);
+    output_k[i] = z;
+  }
+
+  __syncthreads();
+
+  // reduce to the row sum
+  if (threadIdx.x == 0)
+  {
+    AccumT sum_k = ScalarConvert<int, AccumT>::to(0);
+    for (int i=0; i<blockDim.x; i++)
+      sum_k += buffer[i];
+    buffer[LENSOFTMAX_THREADS] = sum_k;
+  }
+
+  __syncthreads();
+
+  // normalize the valid prefix
+  T sum_k = ScalarConvert<AccumT, T>::to(buffer[LENSOFTMAX_THREADS]);
+  for (int i=i_start; i<i_end; i+=i_step)
+    output_k[i] = output_k[i] / sum_k;
+}
+
+// Length-masked softmax backward kernel.
+// gradInput[i] = output[i] * (gradOutput[i] - sum_j gradOutput[j]*output[j]),
+// restricted to the first len[row] entries of each row.
+template <typename T, typename AccumT, typename IndexT>
+__global__ void cunn_LenSoftMax_updateGradInput_kernel(
+  T *gradInput, T *output, T *gradOutput, int nframe, int dim, IndexT *len)
+{
+  __shared__ AccumT buffer[LENSOFTMAX_THREADS];
+  T *gradInput_k = gradInput + blockIdx.x*dim + blockIdx.y + blockIdx.z;
+  T *output_k = output + blockIdx.x*dim + blockIdx.y + blockIdx.z;
+  T *gradOutput_k = gradOutput + blockIdx.x*dim + blockIdx.y + blockIdx.z;
+
+  int i_start = threadIdx.x;
+  int i_end = ScalarConvert<IndexT, int>::to(len[blockIdx.x]);
+  int i_step = blockDim.x;
+
+  // per-thread partial of dot(gradOutput, output) over the valid prefix
+  buffer[threadIdx.x] = ScalarConvert<int, AccumT>::to(0);
+  for (int i=i_start; i<i_end; i+=i_step)
+    buffer[threadIdx.x] += ScalarConvert<T, AccumT>::to(gradOutput_k[i] * output_k[i]);
+
+  __syncthreads();
+
+  // reduce to the row dot product
+  if (threadIdx.x == 0)
+  {
+    AccumT sum_k = ScalarConvert<int, AccumT>::to(0);
+    for (int i=0; i<blockDim.x; i++)
+      sum_k += buffer[i];
+    buffer[0] = sum_k;
+  }
+
+  __syncthreads();
+
+  T sum_k = ScalarConvert<AccumT, T>::to(buffer[0]);
+  for (int i=i_start; i<i_end; i+=i_step)
+    gradInput_k[i] = output_k[i] * (gradOutput_k[i] - sum_k);
+}
+
+#include "generic/LenSoftMax.cu"
+#include "THCGenerateFloatTypes.h"
diff --git a/lib/THCUNN/generic/LenSoftMax.cu b/lib/THCUNN/generic/LenSoftMax.cu
new file mode 100644
index 00000000..00000000
--- /dev/null
+++ b/lib/THCUNN/generic/LenSoftMax.cu
@@ -0,0 +1,80 @@
+#ifndef THC_GENERIC_FILE
+#define THC_GENERIC_FILE "generic/LenSoftMax.cu"
+#else
+
+// Forward pass: softmax over the first len[b] entries of each input row.
+// input: 2D (batchSize x dim) tensor; len: 1D tensor of per-row lengths.
+// Output entries at or beyond len[b] stay zero.
+void THNN_(LenSoftMax_updateOutput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *output,
+           THCIndexTensor *len)
+{
+  THCUNN_assertSameGPU(state, 2, input, output);
+
+  if ((input->nDimension != 2) && (len->nDimension != 1))
+  {
+    THError("2D tensor expected for input, 1D tensor expected for len");
+  }
+
+  input = THCTensor_(newContiguous)(state, input);
+  THCTensor_(resizeAs)(state, output, input);
+  THCTensor_(zero)(state, output);
+  long batchSize = input->size[0], dim = input->size[1];
+  long blocksY = 1, blocksZ = 1;
+
+  // one block per row; the kernel reduces each row with LENSOFTMAX_THREADS threads
+  dim3 blocks(batchSize, blocksY, blocksZ);
+  dim3 threads(LENSOFTMAX_THREADS);
+  cunn_LenSoftMax_updateOutput_kernel<real, accreal, THCIndex_t><<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
+    THCTensor_(data)(state, output),
+    THCTensor_(data)(state, input),
+    batchSize, dim, THCIndexTensor_(data)(state, len)
+  );
+  THCudaCheck(cudaGetLastError());
+
+  THCTensor_(free)(state, input);
+}
+
+// Backward pass: per row, gradInput = output * (gradOutput - dot(gradOutput, output)),
+// restricted to the first len[b] entries (the rest stay zero).
+void THNN_(LenSoftMax_updateGradInput)(
+           THCState *state,
+           THCTensor *input,
+           THCTensor *gradOutput,
+           THCTensor *gradInput,
+           THCTensor *output,
+           THCIndexTensor *len)
+{
+  THCUNN_check_nElement(state, input, gradOutput);
+  THCUNN_assertSameGPU(state, 3, output, gradOutput, gradInput);
+
+  if ((gradInput->nDimension != 2) && (len->nDimension != 1))
+  {
+    THError("2D tensor expected for input, 1D tensor expected for len");
+  }
+
+  output = THCTensor_(newContiguous)(state, output);
+  gradOutput = THCTensor_(newContiguous)(state, gradOutput);
+
+  THCTensor_(resizeAs)(state, gradInput, output);
+  THCTensor_(zero)(state, gradInput);
+  long batchSize = gradInput->size[0], dim = gradInput->size[1];
+  long blocksY = 1, blocksZ = 1;
+
+  dim3 blocks(batchSize, blocksY, blocksZ);
+  dim3 threads(LENSOFTMAX_THREADS);
+  cunn_LenSoftMax_updateGradInput_kernel<real, accreal, THCIndex_t><<<blocks, threads, 0, THCState_getCurrentStream(state)>>>(
+    THCTensor_(data)(state, gradInput),
+    THCTensor_(data)(state, output),
+    THCTensor_(data)(state, gradOutput),
+    batchSize, dim, THCIndexTensor_(data)(state, len)
+  );
+  THCudaCheck(cudaGetLastError());
+
+  THCTensor_(free)(state, gradOutput);
+  THCTensor_(free)(state, output);
+}
+
+#endif
diff --git a/lib/THCUNN/generic/THCUNN.h b/lib/THCUNN/generic/THCUNN.h
index df186b18..779ae0e4 100644
--- a/lib/THCUNN/generic/THCUNN.h
+++ b/lib/THCUNN/generic/THCUNN.h
@@ -1070,9 +1070,23 @@ TH_API void THNN_(SoftMax_updateGradInput)(
           THCState *state,
           THCTensor *input,
           THCTensor *gradOutput,
           THCTensor *gradInput,
           THCTensor *output);
 
+TH_API void THNN_(LenSoftMax_updateOutput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *output,
+          THCIndexTensor *len);
+
+TH_API void THNN_(LenSoftMax_updateGradInput)(
+          THCState *state,
+          THCTensor *input,
+          THCTensor *gradOutput,
+          THCTensor *gradInput,
+          THCTensor *output,
+          THCIndexTensor *len);
+
 TH_API void THNN_(SoftPlus_updateOutput)(
           THCState *state,
           THCTensor *input,