Skip to content

Commit

Permalink
add BCHW layout
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed Jun 12, 2017
1 parent d8ef434 commit 4543489
Show file tree
Hide file tree
Showing 7 changed files with 339 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
**/*.egg-info
**/.eggs
**/.ipynb_checkpoints/*
*.o
45 changes: 45 additions & 0 deletions script/functions/stn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,48 @@ def backward(self, grad_output):
grad_input2 = grad_input2.cuda(self.device)
my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1, grad_input2, grad_output, self.device_c)
return grad_input1, grad_input2



class STNFunctionBCHW(Function):
def forward(self, input1, input2):
self.input1 = input1
self.input2 = input2
self.device_c = ffi.new("int *")
output = torch.zeros(input1.size()[0], input1.size()[1], input2.size()[2], input2.size()[3])
#print('decice %d' % torch.cuda.current_device())
self.device = torch.cuda.current_device()
self.device_c[0] = self.device
if not input1.is_cuda:
my_lib.BilinearSamplerBCHW_updateOutput(input1, input2, output)
else:

output = output.transpose(1,2).transpose(2,3)
input1 = input1.transpose(1,2).transpose(2,3)
input2 = input2.transpose(1,2).transpose(2,3)

output = output.cuda(self.device)
my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c)
output = output.transpose(2,3).transpose(1,2)

return output

def backward(self, grad_output):
grad_input1 = torch.zeros(self.input1.size())
grad_input2 = torch.zeros(self.input2.size())
#print('backward decice %d' % self.device)
if not grad_output.is_cuda:
my_lib.BilinearSamplerBCHW_updateGradInput(self.input1, self.input2, grad_input1, grad_input2, grad_output)
else:
grad_input1 = grad_input1.transpose(1,2).transpose(2,3)
grad_input2 = grad_input2.transpose(1,2).transpose(2,3)
grad_output = grad_output.transpose(1,2).transpose(2,3)

grad_input1 = grad_input1.cuda(self.device)
grad_input2 = grad_input2.cuda(self.device)
my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1, grad_input2, grad_output, self.device_c)

grad_input1 = grad_input1.transpose(2,3).transpose(1,2)
grad_input2 = grad_input2.transpose(2,3).transpose(1,2)

return grad_input1, grad_input2
9 changes: 6 additions & 3 deletions script/modules/stn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from torch.nn.modules.module import Module
from functions.stn import STNFunction
from functions.stn import STNFunction, STNFunctionBCHW

class STN(Module):
def __init__(self):
def __init__(self, layout = 'BHWD'):
super(STN, self).__init__()
self.f = STNFunction()
if layout == 'BHWD':
self.f = STNFunction()
else:
self.f = STNFunctionBCHW()
def forward(self, input1, input2):
return self.f(input1, input2)
256 changes: 256 additions & 0 deletions script/src/my_lib.c
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,259 @@ int BilinearSamplerBHWD_updateGradInput(THFloatTensor *inputImages, THFloatTenso
return 1;
}


int BilinearSamplerBCHW_updateOutput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *output)
{

int batchsize = inputImages->size[0];
int inputImages_height = inputImages->size[2];
int inputImages_width = inputImages->size[3];

int output_height = output->size[2];
int output_width = output->size[3];
int inputImages_channels = inputImages->size[1];

int output_strideBatch = output->stride[0];
int output_strideHeight = output->stride[2];
int output_strideWidth = output->stride[3];
int output_strideChannel = output->stride[1];


int inputImages_strideBatch = inputImages->stride[0];
int inputImages_strideHeight = inputImages->stride[2];
int inputImages_strideWidth = inputImages->stride[3];
int inputImages_strideChannel = inputImages->stride[1];

int grids_strideBatch = grids->stride[0];
int grids_strideHeight = grids->stride[2];
int grids_strideWidth = grids->stride[3];
int grids_strideChannel = grids->stride[1];


real *inputImages_data, *output_data, *grids_data;
inputImages_data = THFloatTensor_data(inputImages);
output_data = THFloatTensor_data(output);
grids_data = THFloatTensor_data(grids);

int b, yOut, xOut;

for(b=0; b < batchsize; b++)
{
for(yOut=0; yOut < output_height; yOut++)
{
for(xOut=0; xOut < output_width; xOut++)
{
//read the grid

real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + grids_strideChannel];
real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];

// get the weights for interpolation
int yInTopLeft, xInTopLeft;
real yWeightTopLeft, xWeightTopLeft;

real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
xInTopLeft = floor(xcoord);
xWeightTopLeft = 1 - (xcoord - xInTopLeft);

real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
yInTopLeft = floor(ycoord);
yWeightTopLeft = 1 - (ycoord - yInTopLeft);



const int outAddress = output_strideBatch * b + output_strideHeight * yOut + output_strideWidth * xOut;
const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;

real v=0;
real inTopLeft=0;
real inTopRight=0;
real inBottomLeft=0;
real inBottomRight=0;

// we are careful with the boundaries
bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;

int t;
// interpolation happens here
for(t=0; t<inputImages_channels; t++)
{
if(topLeftIsIn) inTopLeft = inputImages_data[inTopLeftAddress + t * inputImages_strideChannel];
if(topRightIsIn) inTopRight = inputImages_data[inTopRightAddress + t * inputImages_strideChannel];
if(bottomLeftIsIn) inBottomLeft = inputImages_data[inBottomLeftAddress + t * inputImages_strideChannel];
if(bottomRightIsIn) inBottomRight = inputImages_data[inBottomRightAddress + t * inputImages_strideChannel];

v = xWeightTopLeft * yWeightTopLeft * inTopLeft
+ (1 - xWeightTopLeft) * yWeightTopLeft * inTopRight
+ xWeightTopLeft * (1 - yWeightTopLeft) * inBottomLeft
+ (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * inBottomRight;

output_data[outAddress + t * output_strideChannel] = v;
}

}
}
}

return 1;
}



int BilinearSamplerBCHW_updateGradInput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *gradInputImages,
THFloatTensor *gradGrids, THFloatTensor *gradOutput)
{
bool onlyGrid=false;

int batchsize = inputImages->size[0];
int inputImages_height = inputImages->size[2];
int inputImages_width = inputImages->size[3];
int gradOutput_height = gradOutput->size[2];
int gradOutput_width = gradOutput->size[3];
int inputImages_channels = inputImages->size[1];

int gradOutput_strideBatch = gradOutput->stride[0];
int gradOutput_strideHeight = gradOutput->stride[2];
int gradOutput_strideWidth = gradOutput->stride[3];
int gradOutput_strideChannel = gradOutput->stride[1];

int inputImages_strideBatch = inputImages->stride[0];
int inputImages_strideHeight = inputImages->stride[2];
int inputImages_strideWidth = inputImages->stride[3];
int inputImages_strideChannel = inputImages->stride[1];


int gradInputImages_strideBatch = gradInputImages->stride[0];
int gradInputImages_strideHeight = gradInputImages->stride[2];
int gradInputImages_strideWidth = gradInputImages->stride[3];
int gradInputImages_strideChannel = gradInputImages->stride[1];

int grids_strideBatch = grids->stride[0];
int grids_strideHeight = grids->stride[2];
int grids_strideWidth = grids->stride[3];
int grids_strideChannel = grids->stride[1];

int gradGrids_strideBatch = gradGrids->stride[0];
int gradGrids_strideHeight = gradGrids->stride[2];
int gradGrids_strideWidth = gradGrids->stride[3];
int gradGrids_strideChannel = gradGrids->stride[1];

real *inputImages_data, *gradOutput_data, *grids_data, *gradGrids_data, *gradInputImages_data;
inputImages_data = THFloatTensor_data(inputImages);
gradOutput_data = THFloatTensor_data(gradOutput);
grids_data = THFloatTensor_data(grids);
gradGrids_data = THFloatTensor_data(gradGrids);
gradInputImages_data = THFloatTensor_data(gradInputImages);

int b, yOut, xOut;

for(b=0; b < batchsize; b++)
{
for(yOut=0; yOut < gradOutput_height; yOut++)
{
for(xOut=0; xOut < gradOutput_width; xOut++)
{
//read the grid
real xf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth + grids_strideChannel];
real yf = grids_data[b*grids_strideBatch + yOut*grids_strideHeight + xOut*grids_strideWidth];

// get the weights for interpolation
int yInTopLeft, xInTopLeft;
real yWeightTopLeft, xWeightTopLeft;

real xcoord = (xf + 1) * (inputImages_width - 1) / 2;
xInTopLeft = floor(xcoord);
xWeightTopLeft = 1 - (xcoord - xInTopLeft);

real ycoord = (yf + 1) * (inputImages_height - 1) / 2;
yInTopLeft = floor(ycoord);
yWeightTopLeft = 1 - (ycoord - yInTopLeft);


const int inTopLeftAddress = inputImages_strideBatch * b + inputImages_strideHeight * yInTopLeft + inputImages_strideWidth * xInTopLeft;
const int inTopRightAddress = inTopLeftAddress + inputImages_strideWidth;
const int inBottomLeftAddress = inTopLeftAddress + inputImages_strideHeight;
const int inBottomRightAddress = inBottomLeftAddress + inputImages_strideWidth;

const int gradInputImagesTopLeftAddress = gradInputImages_strideBatch * b + gradInputImages_strideHeight * yInTopLeft + gradInputImages_strideWidth * xInTopLeft;
const int gradInputImagesTopRightAddress = gradInputImagesTopLeftAddress + gradInputImages_strideWidth;
const int gradInputImagesBottomLeftAddress = gradInputImagesTopLeftAddress + gradInputImages_strideHeight;
const int gradInputImagesBottomRightAddress = gradInputImagesBottomLeftAddress + gradInputImages_strideWidth;

const int gradOutputAddress = gradOutput_strideBatch * b + gradOutput_strideHeight * yOut + gradOutput_strideWidth * xOut;

real topLeftDotProduct = 0;
real topRightDotProduct = 0;
real bottomLeftDotProduct = 0;
real bottomRightDotProduct = 0;

real v=0;
real inTopLeft=0;
real inTopRight=0;
real inBottomLeft=0;
real inBottomRight=0;

// we are careful with the boundaries
bool topLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
bool topRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft >= 0 && yInTopLeft <= inputImages_height-1;
bool bottomLeftIsIn = xInTopLeft >= 0 && xInTopLeft <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;
bool bottomRightIsIn = xInTopLeft+1 >= 0 && xInTopLeft+1 <= inputImages_width-1 && yInTopLeft+1 >= 0 && yInTopLeft+1 <= inputImages_height-1;

int t;

for(t=0; t<inputImages_channels; t++)
{
real gradOutValue = gradOutput_data[gradOutputAddress + t * gradOutput_strideChannel];
if(topLeftIsIn)
{
real inTopLeft = inputImages_data[inTopLeftAddress + t * inputImages_strideChannel];
topLeftDotProduct += inTopLeft * gradOutValue;
if(!onlyGrid) gradInputImages_data[gradInputImagesTopLeftAddress + t * gradInputImages_strideChannel] += xWeightTopLeft * yWeightTopLeft * gradOutValue;
}

if(topRightIsIn)
{
real inTopRight = inputImages_data[inTopRightAddress + t * inputImages_strideChannel];
topRightDotProduct += inTopRight * gradOutValue;
if(!onlyGrid) gradInputImages_data[gradInputImagesTopRightAddress + t * gradInputImages_strideChannel] += (1 - xWeightTopLeft) * yWeightTopLeft * gradOutValue;
}

if(bottomLeftIsIn)
{
real inBottomLeft = inputImages_data[inBottomLeftAddress + t * inputImages_strideChannel];
bottomLeftDotProduct += inBottomLeft * gradOutValue;
if(!onlyGrid) gradInputImages_data[gradInputImagesBottomLeftAddress + t * gradInputImages_strideChannel] += xWeightTopLeft * (1 - yWeightTopLeft) * gradOutValue;
}

if(bottomRightIsIn)
{
real inBottomRight = inputImages_data[inBottomRightAddress + t * inputImages_strideChannel];
bottomRightDotProduct += inBottomRight * gradOutValue;
if(!onlyGrid) gradInputImages_data[gradInputImagesBottomRightAddress + t * gradInputImages_strideChannel] += (1 - xWeightTopLeft) * (1 - yWeightTopLeft) * gradOutValue;
}
}

xf = - yWeightTopLeft * topLeftDotProduct + yWeightTopLeft * topRightDotProduct - (1-yWeightTopLeft) * bottomLeftDotProduct + (1-yWeightTopLeft) * bottomRightDotProduct;

yf = - xWeightTopLeft * topLeftDotProduct + xWeightTopLeft * bottomLeftDotProduct - (1-xWeightTopLeft) * topRightDotProduct + (1-xWeightTopLeft) * bottomRightDotProduct;


gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth + gradGrids_strideChannel] = xf * (inputImages_width-1) / 2;

gradGrids_data[b*gradGrids_strideBatch + yOut*gradGrids_strideHeight + xOut*gradGrids_strideWidth] = yf * (inputImages_height-1) / 2;


}
}
}

return 1;
}


7 changes: 7 additions & 0 deletions script/src/my_lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@ int BilinearSamplerBHWD_updateOutput(THFloatTensor *inputImages, THFloatTensor *

int BilinearSamplerBHWD_updateGradInput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *gradInputImages,
THFloatTensor *gradGrids, THFloatTensor *gradOutput);



int BilinearSamplerBCHW_updateOutput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *output);

int BilinearSamplerBCHW_updateGradInput(THFloatTensor *inputImages, THFloatTensor *grids, THFloatTensor *gradInputImages,
THFloatTensor *gradGrids, THFloatTensor *gradOutput);
4 changes: 4 additions & 0 deletions script/src/my_lib_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,7 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages,
}
return 1;
}




Loading

0 comments on commit 4543489

Please sign in to comment.