-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmesh_refiner.py
101 lines (83 loc) · 4.59 KB
/
mesh_refiner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.optim as optim
from pytorch3d.renderer import look_at_view_transform
from tqdm.autonotebook import tqdm
import pandas as pd
from models.deformation_net_graph_convolutional_lite import DeformationNetworkGraphConvolutionalLite
from utils.forward_pass import batched_forward_pass
class MeshRefiner():
    """Refines a mesh by optimizing a deformation network against a target RGBA image."""

    def __init__(self, cfg, device):
        """Class to facilitate refining meshes.

        Args:
            cfg (dict): Config dictionary with "refinement" and "training" sections.
            device (torch.device): PyTorch device to perform computations on.
        """
        self.cfg = cfg
        self.device = device
        # optimization hyperparameters
        self.num_iterations = self.cfg["refinement"]["num_iterations"]
        self.lr = self.cfg["refinement"]["learning_rate"]
        # loss weights (lambdas), shared with the training config
        self.img_sym_num_azim = self.cfg["training"]["img_sym_num_azim"]
        self.sil_lam = self.cfg["training"]["sil_lam"]
        self.l2_lam = self.cfg["training"]["l2_lam"]
        self.lap_lam = self.cfg["training"]["lap_smoothness_lam"]
        self.normals_lam = self.cfg["training"]["normal_consistency_lam"]
        self.img_sym_lam = self.cfg["training"]["img_sym_lam"]
        self.vertex_sym_lam = self.cfg["training"]["vertex_sym_lam"]

    def refine_mesh(self, mesh, rgba_image, R, T, record_debug=False):
        """Performs refinement on a mesh.

        Args:
            mesh (Mesh): PyTorch3D mesh object.
            rgba_image (array-like): Corresponding H x W x 4 RGBA image for the mesh.
                RGB is divided by 255 and alpha > 0 is treated as the foreground
                mask — assumes 8-bit channel values; TODO confirm with callers.
            R (torch.tensor): Rotation matrix for camera.
            T (torch.tensor): Translation matrix for camera.
            record_debug (bool, optional): If intermediate results should be saved
                for debugging. Defaults to False.

        Returns:
            mesh: Refined mesh (the deformed mesh with the lowest total loss seen).
            dict: Refinement info (loss history, forward-pass debug data); populated
                only when record_debug is True.
        """
        # prep inputs used during training
        image = rgba_image[:, :, :3]
        image_in = torch.unsqueeze(torch.tensor(image / 255, dtype=torch.float).permute(2, 0, 1), 0).to(self.device)
        mask = rgba_image[:, :, 3] > 0
        mask_gt = torch.unsqueeze(torch.tensor(mask, dtype=torch.float), 0).to(self.device)
        verts_in = torch.unsqueeze(mesh.verts_packed(), 0).to(self.device)
        R = R.to(self.device)
        T = T.to(self.device)
        deform_net_input = {"mesh_verts": verts_in, "image": image_in, "R": R, "T": T, "mesh": mesh, "mask": mask_gt}

        # setting up
        deform_net = DeformationNetworkGraphConvolutionalLite(self.cfg, self.device)
        deform_net.to(self.device)
        optimizer = optim.Adam(deform_net.parameters(), lr=self.lr)
        loss_rows = []  # one dict of scalar losses per iteration
        lowest_loss = None
        best_deformed_mesh = None
        best_refinement_info = {}

        # starting REFINEment
        for i in tqdm(range(self.num_iterations)):
            # forward pass
            deform_net.train()
            optimizer.zero_grad()
            loss_dict, deformed_mesh, forward_pass_info = batched_forward_pass(self.cfg, self.device, deform_net, deform_net_input)

            # optimization step on weighted losses: each "<x>_loss" entry is
            # weighted by the matching "<x>_lam" value from the training config
            total_loss = sum([loss_dict[loss_name] * self.cfg['training'][loss_name.replace("loss", "lam")] for loss_name in loss_dict])
            total_loss.backward()
            optimizer.step()

            # saving info. DataFrame.append was removed in pandas 2.0 (and was
            # quadratic anyway), so rows are accumulated in a list and a frame
            # is materialized only when a snapshot is needed.
            curr_train_info = {"iteration": i, "total_loss": total_loss.item()}
            curr_train_info.update({loss_name: loss_dict[loss_name].item() for loss_name in loss_dict})
            loss_rows.append(curr_train_info)

            if lowest_loss is None or total_loss.item() < lowest_loss:
                lowest_loss = total_loss.item()
                best_deformed_mesh = deformed_mesh
                if record_debug:
                    best_refinement_info = forward_pass_info
                    # snapshot of losses up to and including the best iteration
                    # (preserves the original append-based snapshot semantics)
                    best_refinement_info["loss_info"] = pd.DataFrame(loss_rows)

        # moving refinement info to cpu
        if "asym_conf_scores" in best_refinement_info:
            best_refinement_info["asym_conf_scores"] = best_refinement_info["asym_conf_scores"].detach().cpu()
        if "img_sym_loss_debug_imgs" in best_refinement_info:
            # nested lists of tensors; detach each so the returned dict holds
            # no references to the autograd graph or GPU memory
            debug_imgs = best_refinement_info["img_sym_loss_debug_imgs"]
            for a in range(len(debug_imgs)):
                for b in range(len(debug_imgs[a])):
                    for c in range(len(debug_imgs[a][b])):
                        debug_imgs[a][b][c] = debug_imgs[a][b][c].detach().cpu()

        return best_deformed_mesh, best_refinement_info