-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutil.py
87 lines (77 loc) · 2.91 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import torchvision
import time
import torch.nn.functional as F
import torch.nn as nn
import torch
import numbers
import math
def save_batch_img(imgs, p, t=8):
    """Tile a batch of images into a single grid image and write it to ``p``.

    Args:
        imgs: image batch handed to ``torchvision.utils.make_grid``; the
            original note reads "t c h w", i.e. presumably (N, C, H, W).
        p: destination passed straight to ``torchvision.utils.save_image``.
        t: images per grid row (forwarded as ``nrow``).
    """
    tv_utils = torchvision.utils
    # normalize=True asks make_grid to rescale pixel values into [0, 1]
    # before saving, so raw activations / probabilities can be dumped as-is.
    grid = tv_utils.make_grid(imgs, nrow=t, normalize=True)
    tv_utils.save_image(grid, p)
def save_attention_mask(src_input, vis_probs, path, time_inter=10000):
    """Occasionally save an input clip and its attention maps as one image grid.

    The grid has one row of ``t`` frames for ``src_input`` followed by one row
    per tensor in ``vis_probs``; it is written to ``path`` via
    ``save_batch_img`` with ``t`` images per row.

    Args:
        src_input: 4-D tensor; the original note says (C, T, H, W). It is
            transposed to (T, C, H, W) so every frame becomes one grid cell.
            A tensor that is not 4-D raises ValueError on unpacking.
        vis_probs: optional sequence of 4-D tensors (original note: N, T, H, W).
            A tensor whose first dimension is 1 is repeated to 3 channels
            before being transposed like ``src_input``. Channel counts must
            agree with ``src_input`` for ``torch.cat`` to succeed — presumably
            RGB input; TODO confirm with callers.
        path: output image path.
        time_inter: save only when the integer Unix time is divisible by this
            value; ``1`` saves on every call, ``0`` raises ZeroDivisionError.

    Only tensors on the CPU or on a device with index 0 are saved, presumably
    so that a single process in a multi-GPU job writes the file.
    """
    # Unpack all four dims (not just size(1)) so non-4-D input still fails fast.
    _, t, _, _ = src_input.size()

    device = src_input.device
    # Previously this was `"0" in str(device)`, a substring test that also
    # matched "cuda:10", "cuda:20", ... and let several ranks write the file.
    if not (device.type == "cpu" or device.index == 0):
        return
    # Crude rate limiting keyed on wall-clock seconds.
    if int(time.time()) % time_inter != 0:
        return

    rows = [src_input.transpose(0, 1)]  # (C, T, H, W) -> (T, C, H, W)
    if vis_probs:
        for vis_prob in vis_probs:
            if vis_prob.size(0) == 1:
                # Single-channel map: replicate to 3 channels for the grid.
                vis_prob = vis_prob.repeat(3, 1, 1, 1)
            rows.append(vis_prob.transpose(0, 1))
    grid = torch.cat(rows, dim=0)
    save_batch_img(grid, path, t)
from prettytable import PrettyTable
def _is_selected(name, black_key, only_key):
    """Return True if ``name`` contains ``only_key`` and none of ``black_key``."""
    return only_key in name and not any(key in name for key in black_key)


def _count_trainable(model, tag):
    """Sum ``numel()`` over trainable parameters whose name contains ``tag``."""
    return sum(
        parameter.numel()
        for name, parameter in model.named_parameters()
        if tag in name and parameter.requires_grad
    )


def count_parameters(model, black_key=(), only_key=""):
    """Build a text report of trainable-parameter counts for ``model``.

    Args:
        model: object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs whose parameters have ``requires_grad``
            and ``numel()`` (typically a ``torch.nn.Module``).
        black_key: iterable of substrings (or None); a parameter whose name
            contains any of them is left out of the table and the total.
        only_key: substring a parameter name must contain to be included;
            the default ``""`` matches every name. Previously this filter was
            only applied when ``black_key`` was non-empty.

    Returns:
        A string with a PrettyTable of the selected trainable parameters
        (name, element count, share of the selected total), the selected
        total, and two extra totals for trainable parameters whose names
        contain ``".f_net."`` (labelled "Decoder") and ``".sem_net."``.
        Those two extra totals ignore ``black_key`` and ``only_key``.
    """
    black_key = () if black_key is None else tuple(black_key)

    selected = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad and _is_selected(name, black_key, only_key)
    ]
    total_params = sum(count for _, count in selected)

    table = PrettyTable(["Modules", "Parameters", "p-percetage"])
    for name, count in selected:
        # Guard: every selected tensor may be empty (numel() == 0).
        share = count / total_params if total_params else 0.0
        table.add_row([name, count, "{:.1%}".format(share)])

    report = str(table)
    report += "\n"
    report += f"Total Trainable Params: {total_params} \n"
    report += f"Decoder Total Trainable Params: {_count_trainable(model, '.f_net.')} \n"
    report += f"sem_net Total Trainable Params: {_count_trainable(model, '.sem_net.')} \n"
    return report