Skip to content

Commit fcdec95

Browse files
authored
Create infer.py
1 parent fd32fa2 commit fcdec95

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

infer.py

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
import os
3+
import sys
4+
import numpy
5+
import cv2
6+
7+
import torch
8+
9+
import models
10+
import utils
11+
import exporters
12+
13+
def parse_arguments(args):
14+
usage_text = (
15+
"Semi-supervised Spherical Depth Estimation Testing."
16+
)
17+
parser = argparse.ArgumentParser(description=usage_text)
18+
parser.add_argument("--input_path", type=str, help="Path to the input spherical panorama image.")
19+
parser.add_argument('--weights', type=str, help='Path to the trained weights file.')
20+
parser.add_argument('-g','--gpu', type=str, default='0', help='The ids of the GPU(s) that will be utilized. (e.g. 0 or 0,1, or 0,2). Use -1 for CPU.')
21+
return parser.parse_known_args(args)
22+
23+
if __name__ == "__main__":
24+
args, unknown = parse_arguments(sys.argv)
25+
gpus = [int(id) for id in args.gpu.split(',') if int(id) >= 0]
26+
# device & visualizers
27+
device = torch.device("cuda:{}" .format(gpus[0])\
28+
if torch.cuda.is_available() and len(gpus) > 0 and gpus[0] >= 0\
29+
else "cpu")
30+
# model
31+
model = models.get_model("resnet_coord", {})
32+
utils.init.initialize_weights(model, args.weights, pred_bias=None)
33+
model = model.to(device)
34+
# test data
35+
width, height = 512, 256
36+
if not os.path.exists(args.input_path):
37+
print("Input image path does not exist (%s)." % args.input_path)
38+
exit(-1)
39+
img = cv2.imread(args.input_path)
40+
h, w, _ = img.shape
41+
if h != height and w != width:
42+
img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
43+
img = img.transpose(2, 0, 1) / 255.0
44+
img = torch.from_numpy(img).float().expand(1, -1, -1, -1)
45+
model.eval()
46+
with torch.no_grad():
47+
left_rgb = img.to(device)
48+
''' Prediction '''
49+
left_depth_pred = torch.abs(model(left_rgb))
50+
exporters.image.save_data(os.path.join(
51+
os.path.dirname(args.input_path),
52+
os.path.splitext(os.path.basename(
53+
args.input_path))[0] + "_depth.exr"),
54+
left_depth_pred, scale=1.0)

0 commit comments

Comments
 (0)