utils.py
import os
import sys

import imageio
import numpy as np
import torch
from torchinfo import summary

from model import make_model


# Make sure loading .exr works for imageio (requires the FreeImage backend)
try:
    imageio.plugins.freeimage.download()
except FileExistsError:
    # Backend already downloaded; ignore
    pass


def tensor_mem_size_in_bytes(x):
    """Return the size of the tensor's underlying storage in bytes."""
    return sys.getsizeof(x.storage())


def load_trained_model(model_config, weights_path, device, mesh=None):
    """Build the model from its config and load trained weights onto the given device."""
    model = make_model(model_config, mesh=mesh)
    data = torch.load(weights_path)
    # Checkpoints may store the full training state or only the state dict itself.
    if "model_state_dict" in data:
        model.load_state_dict(data["model_state_dict"])
    else:
        model.load_state_dict(data)
    return model.to(device)
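
# Illustrative usage (a minimal sketch; the checkpoint path and `model_config` below are
# placeholders, not values defined by this file):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model = load_trained_model(model_config, "checkpoints/model.pt", device)
#   model.eval()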


def load_cameras(view_path):
    """Load camCv2world and the intrinsics K of a view from its cameras.npz file."""
    cameras = np.load(os.path.join(view_path, "depth", "cameras.npz"))
    camCv2world = torch.from_numpy(cameras["world_mat_0"]).to(dtype=torch.float32)
    K = torch.from_numpy(cameras["camera_mat_0"]).to(dtype=torch.float32)
    return camCv2world, K


def model_summary(model, data):
    """Print a torchinfo summary of the model using one batch from the train split."""
    data_batch = next(iter(data["train"]))
    summary(model, input_data=[data_batch])


def load_obj_mask_as_tensor(view_path):
    """Load the object (foreground) mask of a view.

    If view_path points directly to a .npy file, the stored mask is returned as loaded.
    Otherwise the mask is derived from the view's EXR depth map (non-object pixels carry
    a sentinel depth of 1e10) or, if no depth map exists, from depth/mask.png.
    """
    if view_path.endswith(".npy"):
        return np.load(view_path)

    depth_path = os.path.join(view_path, "depth", "depth_0000.exr")
    if os.path.exists(depth_path):
        depth_map = imageio.imread(depth_path)[..., 0]
        mask_value = 1.e+10
        obj_mask = depth_map != mask_value
    else:
        mask_path = os.path.join(view_path, "depth", "mask.png")
        assert os.path.exists(mask_path), "Must have either a depth map or a mask"
        mask = imageio.imread(mask_path)
        obj_mask = mask != 0  # 0 marks invalid (background) pixels
    obj_mask = torch.from_numpy(obj_mask)
    return obj_mask


def load_depth_as_numpy(view_path):
    """Load the first channel of the view's EXR depth map as a numpy array."""
    depth_path = os.path.join(view_path, "depth", "depth_0000.exr")
    assert os.path.exists(depth_path)
    depth_map = imageio.imread(depth_path)[..., 0]
    return depth_map


def batchify_dict_data(data_dict, input_total_size, batch_size):
    """Split every entry of data_dict along its first dimension into batches of batch_size."""
    idxs = np.arange(0, input_total_size)
    batch_idxs = np.split(idxs, np.arange(batch_size, input_total_size, batch_size), axis=0)
    batches = []
    for cur_idxs in batch_idxs:
        data = {}
        for key in data_dict.keys():
            data[key] = data_dict[key][cur_idxs]
        batches.append(data)
    return batches
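
# Example (illustrative, with made-up entry names): with 10 inputs and batch_size=4 the
# indices are split as [0..3], [4..7], [8..9], so the last batch is smaller.
#
#   data = {"coords": np.random.rand(10, 3), "rgb": np.random.rand(10, 3)}
#   batches = batchify_dict_data(data, input_total_size=10, batch_size=4)
#   assert [len(b["coords"]) for b in batches] == [4, 4, 2]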


##########################################################################################
# The following is taken from:
# https://github.com/tum-vision/tandem/blob/master/cva_mvsnet/utils.py
##########################################################################################


# Convert a function into a recursive one so it can handle nested dict/list/tuple variables.
def make_recursive_func(func):
    def wrapper(vars, **kwargs):
        if isinstance(vars, list):
            return [wrapper(x, **kwargs) for x in vars]
        elif isinstance(vars, tuple):
            return tuple([wrapper(x, **kwargs) for x in vars])
        elif isinstance(vars, dict):
            return {k: wrapper(v, **kwargs) for k, v in vars.items()}
        else:
            return func(vars, **kwargs)
    return wrapper
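
# Example (illustrative): the decorator applies a per-leaf function to every element of a
# nested container while preserving the container's structure. `double` is a made-up
# helper used only for demonstration:
#
#   @make_recursive_func
#   def double(x):
#       return x * 2
#
#   double({"a": [1, 2], "b": (3,)})  # -> {"a": [2, 4], "b": (6,)}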


@make_recursive_func
def tensor2float(vars):
    if isinstance(vars, float):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.data.item()
    else:
        raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars)))


@make_recursive_func
def tensor2numpy(vars):
    if isinstance(vars, np.ndarray):
        return vars
    elif isinstance(vars, torch.Tensor):
        return vars.detach().cpu().numpy().copy()
    else:
        raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars)))


@make_recursive_func
def tocuda(vars):
    if isinstance(vars, torch.Tensor):
        return vars.to(torch.device("cuda"))
    elif isinstance(vars, str):
        return vars
    else:
        raise NotImplementedError("invalid input type {} for tocuda".format(type(vars)))


@make_recursive_func
def to_device(x, *, device):
    if torch.is_tensor(x):
        return x.to(device)
    elif isinstance(x, str):
        return x
    else:
        raise NotImplementedError(f"Invalid type for to_device: {type(x)}")


##########################################################################################
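
# Example (illustrative): because of @make_recursive_func, to_device moves a whole nested
# batch in one call; string values are passed through unchanged.
#
#   batch = {"coords": torch.rand(2, 3), "name": "view_0000"}
#   batch = to_device(batch, device=torch.device("cpu"))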