diff --git a/script/functions/stn.py b/script/functions/stn.py
index 40d5f29..ea2c693 100644
--- a/script/functions/stn.py
+++ b/script/functions/stn.py
@@ -48,10 +48,10 @@ def forward(self, input1, input2):
             my_lib.BilinearSamplerBCHW_updateOutput(input1, input2, output)
         else:
-            output = output.transpose(1,2).transpose(2,3)
-            input1 = input1.transpose(1,2).transpose(2,3)
-            input2 = input2.transpose(1,2).transpose(2,3)
-
+            output = output.transpose(1,2).transpose(2,3).contiguous()
+            input1 = input1.transpose(1,2).transpose(2,3).contiguous()
+            input2 = input2.transpose(1,2).transpose(2,3).contiguous()
+            #print(output.size(), input1.size(), input2.size())
             output = output.cuda(self.device)
             my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c)
             output = output.transpose(2,3).transpose(1,2)
 
@@ -65,9 +65,9 @@ def backward(self, grad_output):
         if not grad_output.is_cuda:
             my_lib.BilinearSamplerBCHW_updateGradInput(self.input1, self.input2, grad_input1, grad_input2, grad_output)
         else:
-            grad_input1 = grad_input1.transpose(1,2).transpose(2,3)
-            grad_input2 = grad_input2.transpose(1,2).transpose(2,3)
-            grad_output = grad_output.transpose(1,2).transpose(2,3)
+            grad_input1 = grad_input1.transpose(1,2).transpose(2,3).contiguous()
+            grad_input2 = grad_input2.transpose(1,2).transpose(2,3).contiguous()
+            grad_output = grad_output.transpose(1,2).transpose(2,3).contiguous()
 
         grad_input1 = grad_input1.cuda(self.device)
         grad_input2 = grad_input2.cuda(self.device)
diff --git a/script/test.py b/script/test.py
index e28bacd..6544b18 100644
--- a/script/test.py
+++ b/script/test.py
@@ -62,7 +62,7 @@
 out.backward(input1.data)
 print(input1.grad.size(), 'time:', time.time() - start)
 
-with torch.cuda.device(3):
+with torch.cuda.device(1):
     input1 = input1.cuda()
     input2 = input2.cuda()
     start = time.time()