diff --git a/.gitignore b/.gitignore
index e78a767..5d5c812 100755
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,5 @@ _ext
 *.o
 work
 work/*
-_ext/
\ No newline at end of file
+_ext/
+model_weights
\ No newline at end of file
diff --git a/install.sh b/install.sh
index 538b02e..264bfa4 100755
--- a/install.sh
+++ b/install.sh
@@ -11,4 +11,8 @@ cd ../channelnorm_package
 rm -rf *_cuda.egg-info build dist __pycache__
 python3 setup.py install --user
+cd ../correlation_cpp_package
+rm -rf *.egg-info build dist __pycache__
+python3 setup.py install --user
+
 cd ..
diff --git a/models.py b/models.py
index 8457d2f..4d24387 100755
--- a/models.py
+++ b/models.py
@@ -8,6 +8,11 @@ try:
     from networks.resample2d_package.resample2d import Resample2d
     from networks.channelnorm_package.channelnorm import ChannelNorm
+    # Pure-PyTorch versions.
+    # To use the CPU implementations of Resample2d and ChannelNorm, uncomment
+    # the two lines below and comment out the two lines above.
+    # from networks.channelnorm import ChannelNorm
+    # from networks.resample2d import Resample2d
 
     from networks import FlowNetC
     from networks import FlowNetS
@@ -18,7 +23,12 @@ except:
     from .networks.resample2d_package.resample2d import Resample2d
     from .networks.channelnorm_package.channelnorm import ChannelNorm
-
+    # Pure-PyTorch versions.
+    # To use the CPU implementations of Resample2d and ChannelNorm, uncomment
+    # the two lines below and comment out the two lines above.
+    # from .networks.channelnorm import ChannelNorm
+    # from .networks.resample2d import Resample2d
+
     from .networks import FlowNetC
     from .networks import FlowNetS
     from .networks import FlowNetSD
diff --git a/networks/FlowNetC.py b/networks/FlowNetC.py
index 61e117a..e9628b3 100755
--- a/networks/FlowNetC.py
+++ b/networks/FlowNetC.py
@@ -6,6 +6,11 @@ import numpy as np
 
 from .correlation_package.correlation import Correlation
+# To use the C++ CPU implementation of correlation, comment the line above
+# and uncomment the line below.
+# from .correlation_cpp_package.correlation import Correlation
+# Pure-PyTorch version:
+# from .correlation import Correlation
 
 from .submodules import *
 'Parameter count , 39,175,298 '
diff --git a/networks/channelnorm.py b/networks/channelnorm.py
new file mode 100755
index 0000000..4aff992
--- /dev/null
+++ b/networks/channelnorm.py
@@ -0,0 +1,39 @@
+import torch
+from torch.autograd import Function, Variable
+from torch.nn.modules.module import Module
+# import channelnorm_cpp
+
+class ChannelNormFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input1, norm_deg=2):
+        assert input1.is_contiguous()
+        b, c, h, w = input1.size()
+        # Norm along the channel dimension; keepdim preserves the expected
+        # (b, 1, h, w) output shape for any batch size.
+        output = torch.pow(input1, norm_deg)
+        output = torch.sqrt(torch.sum(output, dim=1, keepdim=True))
+        ctx.save_for_backward(input1, output)
+        ctx.norm_deg = norm_deg
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        Not Implemented!
+ """ + input1, output = ctx.saved_tensors + + grad_input1 = Variable(input1.new(input1.size()).zero_()) + + return grad_input1, None + + +class ChannelNorm(Module): + + def __init__(self, norm_deg=2): + super(ChannelNorm, self).__init__() + self.norm_deg = norm_deg + + def forward(self, input1): + return ChannelNormFunction.apply(input1, self.norm_deg) + diff --git a/networks/correlation.py b/networks/correlation.py new file mode 100644 index 0000000..9966400 --- /dev/null +++ b/networks/correlation.py @@ -0,0 +1,121 @@ +""" +Copyright 2020 Samim Taray + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 +""" + +import torch +from torch.nn.modules.module import Module +from torch.autograd import Function +from torch.nn import ZeroPad2d +# import correlation_cuda +import code + +class CorrelationFunction(Function): + + def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): + super(CorrelationFunction, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def extractwindow(self, f2pad, i, j): + hindex = torch.tensor( range(i,i+(2*self.max_displacement)+1, self.stride2) ) + windex = torch.tensor( range(j,j+(2*self.max_displacement)+1, self.stride2) ) + # Advanced indexing logic. Ref: https://github.com/pytorch/pytorch/issues/1080 + # the way advance indexing works: + # ---> f2pad[:, :, hindex] chose value at f2pad at hindex location, then + # ---> appending [:, :, :, windex] to it only choses values at windex. + # ---> Thus it choses value at the alternative location of f2pad + # win = f2pad[:,:, i:i+(2*self.max_displacement)+1, j:j+(2*self.max_displacement)+1] + + win = f2pad[:, :, hindex][:, :, :, windex] + return win + + def forward(self, f1, f2): + self.save_for_backward(f1, f2) + f1b = f1.shape[0] #batch + f1c = f1.shape[1] #channel + f1h = f1.shape[2] #height + f1w = f1.shape[3] #width + + f2b = f2.shape[0] #batch + f2c = f2.shape[1] #channel + f2h = f2.shape[2] #height + f2w = f2.shape[3] #width + + # generate padded f2 + padder = ZeroPad2d(self.pad_size) + f2pad = padder(f2) + + # Define output shape and initialize it + outc = (2*(self.max_displacement/self.stride2)+1) * (2*(self.max_displacement/self.stride2)+1) + outc = int(outc) # number of output channel + outb = f1b # size of output batch + outh = f1h # size of output height + outw = f1w # size of output width + output = torch.ones((outb, outc, outh, outw)) + # this gives device type + output = output.to(f1.device) + + for i in range(f1h): + for j in range(f1w): + # Extract window W around i,j from f2pad of size (1X256X21X21) + win = self.extractwindow(f2pad, i, j) + # Extract kernel: size [1, 256, 1, 1] + k = f1[:, :, i, j].unsqueeze(2).unsqueeze(3) + # boradcasting multiplication along channel dimension + # it multiplies all the 256 element of k to win and keep the result as it is + # size of mult: 1, 256, 21, 21 + mult = win * k + # Sum along channel dimension to get dot product. 
+                # Result size: (b, 21, 21)
+                inner_prod = torch.sum(mult, dim=1)
+
+                # Flatten the last two dimensions (h, w) into one dimension of
+                # h*w = number of output channels. Result size: (b, 441)
+                inner_prod = inner_prod.flatten(-2, -1)
+                output[:, :, i, j] = inner_prod
+        # Return the average over the channel dimension
+        return output / f1c
+
+    def backward(self, grad_output):
+        """
+        Not Implemented!
+        """
+        # The CUDA backward is unavailable here (correlation_cuda is not
+        # imported); return zero gradients so inference still works.
+        input1, input2 = self.saved_tensors
+        grad_input1 = input1.new(input1.size()).zero_()
+        grad_input2 = input2.new(input2.size()).zero_()
+        return grad_input1, grad_input2
+
+
+class Correlation(Module):
+    def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
+        super(Correlation, self).__init__()
+        self.pad_size = pad_size
+        self.kernel_size = kernel_size
+        self.max_displacement = max_displacement
+        self.stride1 = stride1
+        self.stride2 = stride2
+        self.corr_multiply = corr_multiply
+
+    def forward(self, input1, input2):
+
+        result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,
+                                     self.stride1, self.stride2, self.corr_multiply)(input1, input2)
+
+        return result
+
diff --git a/networks/correlation_cpp_package/.gitignore b/networks/correlation_cpp_package/.gitignore
new file mode 100644
index 0000000..50180ed
--- /dev/null
+++ b/networks/correlation_cpp_package/.gitignore
@@ -0,0 +1,145 @@
+*.pyc
+.torch
+_ext
+*.o
+work
+work/*
+_ext/
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
\ No newline at end of file
diff --git a/networks/correlation_cpp_package/correlation.cpp b/networks/correlation_cpp_package/correlation.cpp
new file mode 100644
index 0000000..965a78f
--- /dev/null
+++ b/networks/correlation_cpp_package/correlation.cpp
@@ -0,0 +1,125 @@
+/*
+Copyright 2020 Samim Taray
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+*/
+#include <torch/extension.h>
+using namespace torch;
+
+#include <iostream>
+
+#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < W && y >= 0 && y < H)
+#define WITHIN_BOUNDS3(val1, val2, bound1, bound2) (val1 >= 0 && val1 < bound1 && val2 >= 0 && val2 < bound2)
+#define WITHIN_BOUNDS2(x, bound) (x >= 0 && x < bound)
+
+
+template <typename scalar_t>
+static void correlate_patch(
+    TensorAccessor<scalar_t, 3> f1_acc,
+    TensorAccessor<scalar_t, 3> f2_acc,
+    TensorAccessor<scalar_t, 3> out_acc,
+    int n, /* batch index */
+    int h, /* height coordinate */
+    int w, /* width coordinate */
+    int pad_size,
+    int kernel_size,
+    int max_displacement,
+    int stride1,
+    int stride2
+){
+    /*
+    Algorithm: we are at position (h, w) of both feature maps. Correlate the
+    feature vector of f1 at (h, w) with every stride2-th position of the
+    search window around (h, w) in f2.
+    */
+    int f1c = f1_acc.size(0);
+    int f1h = f1_acc.size(1);
+    int f1w = f1_acc.size(2);
+
+    /* Indices that define the extents of the window. */
+    int win_starth, win_endh, win_startw, win_endw;
+    win_starth = h - max_displacement;
+    win_endh = h + max_displacement + 1;
+    win_startw = w - max_displacement;
+    win_endw = w + max_displacement + 1;
+
+    int c, ph, pw, outpc = 0;
+
+    for ( ph = win_starth; ph < win_endh; ph += stride2){
+        for ( pw = win_startw; pw < win_endw; pw += stride2){
+            if ( WITHIN_BOUNDS3(ph, pw, f1h, f1w /* f1 and f2 have the same extents */)){
+                // We are in the window now.
+                scalar_t outval = 0.0;
+                for (c = 0; c < f1c; c++){
+                    outval += f1_acc[c][h][w] * f2_acc[c][ph][pw];
+                }
+                // TODO: Optimization: outpc can be computed directly from ph and pw:
+                // outpc = ((ph - win_starth) / stride2) * (2 * (max_displacement / stride2) + 1) + (pw - win_startw) / stride2; // Output channel index.
+ // std::cout<<"outpc = "<(); + auto f2_acc = f2.accessor(); + auto out_acc = output.accessor(); + + correlate_patch( + f1_acc[n], + f2_acc[n], + out_acc[n], + n, + h, + w, + pad_size, + kernel_size, + max_displacement, + stride1, + stride2); + })); + } + } + } + return output; + } + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &correlation_cpp_forward, "Spatial Correlation Sampler Forward"); +// m.def("backward", &correlation_cpp_backward, "Spatial Correlation Sampler backward"); +} \ No newline at end of file diff --git a/networks/correlation_cpp_package/correlation.py b/networks/correlation_cpp_package/correlation.py new file mode 100644 index 0000000..289f4f1 --- /dev/null +++ b/networks/correlation_cpp_package/correlation.py @@ -0,0 +1,31 @@ +import torch +import correlation + +class CorrelationFunction(torch.autograd.Function): + + @staticmethod + def forward(ctx, f1, f2, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2): + output = correlation.forward(f1, f2, pad_size, kernel_size, max_displacement, stride1, stride2) + # ctx.save_for_backward(output) + return output + + @staticmethod + def backward(ctx): + """ + Not Implemented! + """ + output = None + +class Correlation(torch.nn.Module): + def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): + super(Correlation, self).__init__() + self.pad_size = pad_size + self.kernel_size = kernel_size + self.max_displacement = max_displacement + self.stride1 = stride1 + self.stride2 = stride2 + self.corr_multiply = corr_multiply + + def forward(self, input1, input2): + result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2) + return result diff --git a/networks/correlation_cpp_package/setup.py b/networks/correlation_cpp_package/setup.py new file mode 100644 index 0000000..cbefe98 --- /dev/null +++ b/networks/correlation_cpp_package/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup, Extension +from torch.utils import cpp_extension + +setup( + name='correlation', + ext_modules=[ + cpp_extension.CppExtension( + 'correlation', # Name of the module used in pybind + ['correlation.cpp'], # source files + extra_compile_args={'cxx': ['-fopenmp']}, + extra_link_args=['-lgomp']) + ], + author='Samim Zahoor Taray', + author_email='zsameem@gmail.com', + install_requires=['torch>=1.1', 'numpy'], + cmdclass={ + 'build_ext': cpp_extension.BuildExtension + } +) \ No newline at end of file diff --git a/networks/resample2d.py b/networks/resample2d.py new file mode 100755 index 0000000..3ea4e60 --- /dev/null +++ b/networks/resample2d.py @@ -0,0 +1,135 @@ +""" +Copyright 2020 Samim Taray + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+"""
+
+import torch
+from torch.nn.modules.module import Module
+from torch.autograd import Function, Variable
+
+
+class Resample2dFunction(Function):
+
+    @staticmethod
+    def forward(ctx, input1, input2, kernel_size=1):
+        assert input1.is_contiguous()
+        assert input2.is_contiguous()
+        ctx.save_for_backward(input1, input2)
+        ctx.kernel_size = kernel_size
+        _, d, _, _ = input1.size()
+        b, _, h, w = input2.size()
+        output = input1.clone().detach()
+        # The naive loop implementation from the original FlowNet is kept
+        # below for reference; a vectorized version follows it.
+        image_data = input1
+        warped_data = output
+
+        # for x in range(w):
+        #     for y in range(h):
+
+        #         fx = input2[0, 0, y, x]
+        #         fy = input2[0, 1, y, x]
+
+        #         x2 = x + fx
+        #         y2 = y + fy
+
+        #         if x2 >= 0 and y2 >= 0 and x2 < w and y2 < h:
+
+        #             ix2_L = int(x2)
+        #             iy2_T = int(y2)
+        #             ix2_R = min(ix2_L+1, w-1)
+        #             iy2_B = min(iy2_T+1, h-1)
+
+        #             alpha = x2-ix2_L
+        #             beta = y2-iy2_T
+
+        #             for c in range(3):
+        #                 TL = image_data[:, c, iy2_T, ix2_L]
+        #                 TR = image_data[:, c, iy2_T, ix2_R]
+        #                 BL = image_data[:, c, iy2_B, ix2_L]
+        #                 BR = image_data[:, c, iy2_B, ix2_R]
+
+        #                 warped_data[:, c, y, x] = \
+        #                     (1-alpha)*(1-beta)*TL + \
+        #                     alpha*(1-beta)*TR + \
+        #                     (1-alpha)*beta*BL + \
+        #                     alpha*beta*BR
+
+        #         else:
+        #             for c in range(3):
+        #                 warped_data[:, c, y, x] = 0.0
+
+        # Vectorized implementation
+        for batch in range(b):
+            y, x = torch.meshgrid(torch.arange(h), torch.arange(w))
+            x = x.to(input1.device)
+            y = y.to(input1.device)
+            fx = input2[batch, 0, y, x]
+            fy = input2[batch, 1, y, x]
+
+            x2 = x.float() + fx
+            y2 = y.float() + fy
+
+            ix2_L = x2.long()
+            iy2_T = y2.long()
+            ix2_L = torch.clamp(ix2_L, 0, w-1)
+            iy2_T = torch.clamp(iy2_T, 0, h-1)
+            ix2_R = torch.clamp(ix2_L + 1, 0, w-1)
+            iy2_B = torch.clamp(iy2_T + 1, 0, h-1)
+
+            alpha = x2 - ix2_L.float()
+            beta = y2 - iy2_T.float()
+            # Per-channel version, kept for reference:
+            # for c in range(3):
+            #     TL = image_data[:, c, iy2_T, ix2_L]
+            #     TR = image_data[:, c, iy2_T, ix2_R]
+            #     BL = image_data[:, c, iy2_B, ix2_L]
+            #     BR = image_data[:, c, iy2_B, ix2_R]
+
+            #     warped_data[:, c, :, :] = \
+            #         (1-alpha)*(1-beta)*TL + \
+            #         alpha*(1-beta)*TR + \
+            #         (1-alpha)*beta*BL + \
+            #         alpha*beta*BR
+
+            TL = image_data[batch, :, iy2_T, ix2_L]
+            TR = image_data[batch, :, iy2_T, ix2_R]
+            BL = image_data[batch, :, iy2_B, ix2_L]
+            BR = image_data[batch, :, iy2_B, ix2_R]
+            # Bilinear interpolation
+            warped_data[batch, :, :, :] = (1-alpha) * (1-beta) * TL + \
+                                          alpha * (1-beta) * TR + \
+                                          (1-alpha) * beta * BL + \
+                                          alpha * beta * BR
+
+        return warped_data
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """
+        Not Implemented!
+        """
+        assert grad_output.is_contiguous()
+        input1, input2 = ctx.saved_tensors
+        # No CPU backward kernel is available (resample2d_cuda is not
+        # imported here); return zero gradients so the module can still be
+        # used for inference.
+        grad_input1 = Variable(input1.new(input1.size()).zero_())
+        grad_input2 = Variable(input1.new(input2.size()).zero_())
+        return grad_input1, grad_input2, None
+
+
+class Resample2d(Module):
+
+    def __init__(self, kernel_size=1):
+        super(Resample2d, self).__init__()
+        self.kernel_size = kernel_size
+
+    def forward(self, input1, input2):
+        input1_c = input1.contiguous()
+        return Resample2dFunction.apply(input1_c, input2, self.kernel_size)
diff --git a/run_a_pair.py b/run_a_pair.py
deleted file mode 100644
index 0e6aea2..0000000
--- a/run_a_pair.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import torch
-import numpy as np
-import argparse
-
-from Networks.FlowNet2 import FlowNet2 # the path is depended on where you create this module
-from frame_utils import read_gen # the path is depended on where you create this module
-
-if __name__ == '__main__':
-    # obtain the necessary args for construct the flownet framework
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
-    parser.add_argument("--rgb_max", type=float, default=255.)
-
-    args = parser.parse_args()
-
-    # initial a Net
-    net = FlowNet2(args).cuda()
-    # load the state_dict
-    dict = torch.load("/home/hjj/PycharmProjects/flownet2_pytorch/FlowNet2_checkpoint.pth.tar")
-    net.load_state_dict(dict["state_dict"])
-
-    # load the image pair, you can find this operation in dataset.py
-    pim1 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img0.ppm")
-    pim2 = read_gen("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img1.ppm")
-    images = [pim1, pim2]
-    images = np.array(images).transpose(3, 0, 1, 2)
-    im = torch.from_numpy(images.astype(np.float32)).unsqueeze(0).cuda()
-
-    # process the image pair to obtian the flow
-    result = net(im).squeeze()
-
-
-    # save flow, I reference the code in scripts/run-flownet.py in flownet2-caffe project
-    def writeFlow(name, flow):
-        f = open(name, 'wb')
-        f.write('PIEH'.encode('utf-8'))
-        np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
-        flow = flow.astype(np.float32)
-        flow.tofile(f)
-        f.flush()
-        f.close()
-
-
-    data = result.data.cpu().numpy().transpose(1, 2, 0)
-    writeFlow("/home/hjj/flownet2-master/data/FlyingChairs_examples/0000007-img.flo", data)
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..fc90110
--- /dev/null
+++ b/test.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import os
+import models
+import time
+import argparse, sys, subprocess
+import matplotlib.pyplot as plt
+from PIL import Image
+from utils import flow_utils
+
+"""
+Test functionality by running FlowNet2 on the pair of images provided in
+./test_images. Download the FlowNet2 weight file and put it in the
+./model_weights dir (or create a new dir and pass its path as an argument).
+"""
+def load_sequence(img1, img2):
+    # Load images
+    leftimage = Image.open(img1)
+    rightimage = Image.open(img2)
+
+    images = [leftimage, rightimage]
+
+    frame_size = images[0].size[:2]
+    # Resize images so width and height are each a multiple of 64
+    render_size = list(frame_size)
+    render_size[0] = (frame_size[0] // 64) * 64
+    render_size[1] = (frame_size[1] // 64) * 64
+    for i in range(len(images)):
+        images[i] = images[i].resize((render_size[0], render_size[1]))
+        images[i] = np.asarray(images[i]).astype(np.float32)
+
+    images_tensor = np.array(images).transpose(3, 0, 1, 2)
+    images_tensor = torch.from_numpy(images_tensor.astype(np.float32))
+    # Add a dimension at the beginning for the batch size
+    images_tensor = images_tensor.unsqueeze(0)
+    return images_tensor, images
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--rgb_max", type=float, default=255.)
+    parser.add_argument('--weight_file', default='model_weights/FlowNet2_checkpoint.pth.tar', type=str,
+                        metavar='PATH', help='path to latest checkpoint (default: none)')
+    parser.add_argument('--img1', default='test_images/frame_0003.png', type=str,
+                        metavar='PATH', help='path to the first image of the pair')
+    parser.add_argument('--img2', default='test_images/frame_0006.png', type=str,
+                        metavar='PATH', help='path to the second image of the pair')
+    parser.add_argument('--mode',
+                        default='torch',
+                        const='torch',
+                        nargs='?',
+                        choices=['cuda', 'cpp', 'torch'],
+                        help='mode cuda, cpp, or torch (default: %(default)s). Note: cpp and '
+                             'torch both run on the CPU; which CPU implementation is used is '
+                             'selected by the imports in models.py and networks/FlowNetC.py.')
+
+    # These arguments are used within the nvidia flownet models.
+    parser.add_argument('--fp16', action='store_true', help='Run model in pseudo-fp16 mode (fp16 storage fp32 math).')
+    parser.add_argument('--fp16_scale', type=float, default=1024.,
+                        help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
+
+
+    args = parser.parse_args()
+    print("Running in {} mode".format(args.mode))
+
+    net = models.FlowNet2(args)
+    checkpoint = torch.load(args.weight_file)
+    net.load_state_dict(checkpoint['state_dict'])
+    net.eval()
+    if args.mode == 'cuda':
+        net.cuda()
+
+    images, image_list = load_sequence(args.img1, args.img2)
+    start_time = time.time()
+    if args.mode == 'cuda':
+        images = images.cuda()
+    output = net(images)
+    flow_data = output[0].data.cpu().numpy().transpose(1, 2, 0)
+    end_time = time.time()
+    print("Inference took {:0.4f} seconds".format(end_time - start_time))
+    # Save the visualization of the predicted flow
+    flow_img = flow_utils.flow2img(flow_data)
+    plt.imshow(flow_img)
+    plt.show()
+
diff --git a/test_images/frame_0003.png b/test_images/frame_0003.png
new file mode 100755
index 0000000..8bc7c37
Binary files /dev/null and b/test_images/frame_0003.png differ
diff --git a/test_images/frame_0006.png b/test_images/frame_0006.png
new file mode 100755
index 0000000..30963ba
Binary files /dev/null and b/test_images/frame_0006.png differ
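Usage sketch for the CPU ops added above — a minimal sanity check, not a definitive test. It assumes the repository root is on PYTHONPATH, a PyTorch version that still accepts the legacy autograd.Function call style used by networks/correlation.py, and the correlation settings FlowNetC is typically built with (pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2), which give (2*(20/2)+1)**2 = 441 output channels.

    import torch
    from networks.correlation import Correlation
    from networks.channelnorm import ChannelNorm
    from networks.resample2d import Resample2d

    # Correlation: expect (2*(max_displacement/stride2)+1)**2 = 441 channels.
    corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                       stride1=1, stride2=2, corr_multiply=1)
    f1 = torch.randn(1, 256, 48, 64)
    f2 = torch.randn(1, 256, 48, 64)
    print(corr(f1, f2).shape)                    # torch.Size([1, 441, 48, 64])

    # ChannelNorm: L2 norm over the channel dimension -> (b, 1, h, w).
    cn = ChannelNorm()
    print(cn(torch.randn(2, 64, 48, 64)).shape)  # torch.Size([2, 1, 48, 64])

    # Resample2d: warping with an all-zero flow must return the input image.
    warp = Resample2d()
    img = torch.randn(1, 3, 48, 64).contiguous()
    flow = torch.zeros(1, 2, 48, 64).contiguous()
    print((warp(img, flow) - img).abs().max())   # tensor(0.)

Note that the pure-PyTorch correlation loops over every spatial position, so even this small input takes noticeably longer than the CUDA or C++ versions.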