forked from xy-guo/GwcNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
save_disp.py
82 lines (68 loc) · 2.86 KB
/
save_disp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from __future__ import print_function, division
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import torchvision.utils as vutils
import torch.nn.functional as F
import numpy as np
import time
from tensorboardX import SummaryWriter
from datasets import __datasets__
from models import __models__
from utils import *
from torch.utils.data import DataLoader
import gc
import skimage
# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input shapes repeat across iterations.
cudnn.benchmark = True

parser = argparse.ArgumentParser(description='Group-wise Correlation Stereo Network (GwcNet)')
parser.add_argument('--model', default='gwcnet-g', help='select a model structure', choices=__models__.keys())
parser.add_argument('--maxdisp', type=int, default=192, help='maximum disparity')
parser.add_argument('--dataset', default='kitti', help='dataset name', choices=__datasets__.keys())
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--testlist', required=True, help='testing list')
parser.add_argument('--loadckpt', required=True, help='load the weights from a specific checkpoint')

# parse arguments
args = parser.parse_args()

# dataset, dataloader: batch size 1 and shuffle=False, so predictions are
# produced in test-list order.
StereoDataset = __datasets__[args.dataset]
# Third positional argument is presumably the dataset's `training` flag — confirm in datasets.
test_dataset = StereoDataset(args.datapath, args.testlist, False)
TestImgLoader = DataLoader(test_dataset, 1, shuffle=False, num_workers=4, drop_last=False)

# model (inference only — no optimizer is created here), wrapped in
# DataParallel and moved to the GPU.
model = __models__[args.model](args.maxdisp)
model = nn.DataParallel(model)
model.cuda()

# load parameters
print("loading model {}".format(args.loadckpt))
# map_location='cpu' lets a checkpoint saved on any GPU load on this machine;
# load_state_dict then copies the tensors into the CUDA parameters.
# NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load(args.loadckpt, map_location='cpu')
model.load_state_dict(state_dict['model'])
def test():
    """Predict a disparity map for every test sample and save it as a 16-bit image.

    Iterates over the module-level ``TestImgLoader``. For each sample:

    1. Take the prediction from ``test_sample`` and convert it with ``tensor2numpy``.
    2. Crop away the padding reported by the sample: ``top_pad`` rows at the top
       and ``right_pad`` columns at the right.
    3. Write ``round(disp * 256)`` as uint16 to
       ``./predictions/<basename of left_filename>``. This is presumably the
       KITTI disparity-PNG convention — confirm.

    Raises:
        ValueError: if a per-sample prediction is not a 2-D array.
    """
    # Older scikit-image releases do not load the ``io`` submodule on a bare
    # ``import skimage``, so import it explicitly before use.
    import skimage.io

    os.makedirs('./predictions', exist_ok=True)
    for batch_idx, sample in enumerate(TestImgLoader):
        start_time = time.time()
        disp_est_np = tensor2numpy(test_sample(sample))
        top_pad_np = tensor2numpy(sample["top_pad"])
        right_pad_np = tensor2numpy(sample["right_pad"])
        left_filenames = sample["left_filename"]
        # Elapsed seconds with three decimal places.
        print('Iter {}/{}, time = {:.3f}'.format(batch_idx, len(TestImgLoader),
                                                  time.time() - start_time))
        for disp_est, top_pad, right_pad, fn in zip(disp_est_np, top_pad_np, right_pad_np, left_filenames):
            if disp_est.ndim != 2:
                raise ValueError("expected a 2-D disparity map, got shape {}".format(disp_est.shape))
            top_pad, right_pad = int(top_pad), int(right_pad)
            # Use an explicit end column: ``[:, :-right_pad]`` would produce an
            # empty array when right_pad == 0.
            width = disp_est.shape[1]
            disp_est = np.array(disp_est[top_pad:, :width - right_pad], dtype=np.float32)
            fn = os.path.join("predictions", os.path.basename(fn))
            print("saving to", fn, disp_est.shape)
            # Clip before casting so out-of-range values saturate instead of
            # wrapping around in uint16.
            disp_est_uint = np.clip(np.round(disp_est * 256), 0,
                                    np.iinfo(np.uint16).max).astype(np.uint16)
            skimage.io.imsave(fn, disp_est_uint)
# Inference on a single sample.
@make_nograd_func
def test_sample(sample):
    """Run the model on one sample's stereo pair and return its final output.

    ``make_nograd_func`` comes from ``utils`` and presumably disables autograd
    for this call — confirm there.
    """
    model.eval()
    left_img = sample['left'].cuda()
    right_img = sample['right'].cuda()
    outputs = model(left_img, right_img)
    # The model's result is indexable; only its last entry is used downstream.
    return outputs[-1]
if __name__ == "__main__":
    # Script entry point: predict and save disparities for the whole test list.
    test()