forked from ivanzzh/admm_uav_regression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
118 lines (95 loc) · 6.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
import torch
import pickle
import torchvision
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, average_precision_score
def read_pickle(filename):
    """Load and return the object stored in the pickle file ``filename``.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled Python object.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.

    Warning:
        ``pickle.load`` can execute arbitrary code while deserializing.
        Only call this on files produced by a trusted source.
    """
    # ``with`` guarantees the handle is closed even if unpickling raises.
    with open(filename, "rb") as infile:
        return pickle.load(infile)
def dump_pickle(filename, data):
    """Pickle ``data`` and write it to the file ``filename``.

    The file is opened in ``"wb"`` mode, so any existing content is replaced.

    Args:
        filename: Destination path.
        data: Any picklable Python object.
    """
    # ``with`` closes the handle even if pickling fails part-way.
    with open(filename, "wb") as outfile:
        # pickle.dump writes to a file object, not to a path string.
        pickle.dump(data, outfile)
def visualize_sum_testing_result(path, init, prediction, sub_prediction, label, batch_id, epoch, batch_size):
    """Save the test-time images of one batch as PNG files.

    For each sample ``idx`` in the batch, five images are written into
    ``<path>/epoch_<epoch>/sum/`` under the prefix
    ``n = idx + batch_id * batch_size``:

    * ``<n>_init.png``             -- ``init[idx]``
    * ``<n>_sub_prediction_1.png`` -- ``sub_prediction[idx][30]``
    * ``<n>_sub_prediction_2.png`` -- ``sub_prediction[idx][31]``
    * ``<n>_prediction.png``       -- ``prediction[idx]``
    * ``<n>_label.png``            -- ``label[idx]``

    Each tensor is moved to CPU, detached and squeezed before being passed to
    ``torchvision.utils.save_image``.

    Args:
        path: Root output directory. It and its ``epoch_<epoch>/sum``
            sub-directories are created (with any missing parents) if needed.
        init: Batched tensor, indexed along dim 0.
        prediction: Batched tensor, indexed along dim 0. Its length sets how
            many samples are saved.
        sub_prediction: Batched tensor. Only ``[idx][30]`` and ``[idx][31]``
            are saved. NOTE(review): an earlier comment recorded its shape as
            ``(4, 1, 60, 100, 100)``, under which ``[idx][30]`` would be out
            of range; confirm the actual layout.
        label: Batched tensor, indexed along dim 0.
        batch_id: Index of this batch, used to compute global sample numbers.
        epoch: Epoch number, used in the output directory name.
        batch_size: Batch size, used to compute global sample numbers.

    Raises:
        AssertionError: If ``prediction`` and ``label`` differ in dim 0.
    """
    assert prediction.shape[0] == label.shape[0], "prediction size and label size is not identical"

    out_dir = os.path.join(path, "epoch_" + str(epoch), "sum")
    # makedirs creates missing parents and, with exist_ok, tolerates a
    # directory that already exists or is created concurrently.
    os.makedirs(out_dir, exist_ok=True)

    def _to_image(tensor):
        # Detach from autograd, move to CPU and drop singleton dims.
        return torch.squeeze(tensor.cpu().detach())

    for idx, pred in enumerate(prediction):
        prefix = os.path.join(out_dir, str(idx + batch_id * batch_size))
        torchvision.utils.save_image(_to_image(init[idx]), prefix + "_init.png")
        torchvision.utils.save_image(_to_image(sub_prediction[idx][30]), prefix + "_sub_prediction_1.png")
        torchvision.utils.save_image(_to_image(sub_prediction[idx][31]), prefix + "_sub_prediction_2.png")
        torchvision.utils.save_image(_to_image(pred), prefix + "_prediction.png")
        torchvision.utils.save_image(_to_image(label[idx]), prefix + "_label.png")
def visualize_sum_testing_result_cont(path, init, prediction, sub_prediction, label, batch_id, epoch, batch_size, cont_index):
    """Save the test-time images of one batch for continuation step ``cont_index``.

    For each sample ``idx`` in the batch, five images are written into
    ``<path>/epoch_<epoch>/sum/`` under the prefix
    ``n = idx + batch_id * batch_size``:

    * ``<n>_init_<cont_index>.png``       -- ``init[idx]``
    * ``<n>_sub_prediction_1.png``        -- ``sub_prediction[idx][30]``
    * ``<n>_sub_prediction_2.png``        -- ``sub_prediction[idx][31]``
    * ``<n>_prediction_<cont_index>.png`` -- ``prediction[idx]``
    * ``<n>_label_<cont_index>.png``      -- ``label[idx]``

    NOTE(review): unlike the other three images, the sub_prediction file names
    do not include ``cont_index``. Calling this for several continuation steps
    therefore overwrites them, leaving only the last step's sub-predictions.
    The names are kept unchanged here because downstream readers may depend
    on them; confirm whether this is intended.

    Each tensor is moved to CPU, detached and squeezed before being passed to
    ``torchvision.utils.save_image``.

    Args:
        path: Root output directory. It and its ``epoch_<epoch>/sum``
            sub-directories are created (with any missing parents) if needed.
        init: Batched tensor, indexed along dim 0.
        prediction: Batched tensor, indexed along dim 0. Its length sets how
            many samples are saved.
        sub_prediction: Batched tensor. Only ``[idx][30]`` and ``[idx][31]``
            are saved. NOTE(review): confirm the layout; an earlier comment
            gave ``(4, 1, 60, 100, 100)``, which would make ``[idx][30]`` out
            of range.
        label: Batched tensor, indexed along dim 0.
        batch_id: Index of this batch, used to compute global sample numbers.
        epoch: Epoch number, used in the output directory name.
        batch_size: Batch size, used to compute global sample numbers.
        cont_index: Suffix appended (via ``str``) to the init, prediction and
            label file names.

    Raises:
        AssertionError: If ``prediction`` and ``label`` differ in dim 0.
    """
    assert prediction.shape[0] == label.shape[0], "prediction size and label size is not identical"

    out_dir = os.path.join(path, "epoch_" + str(epoch), "sum")
    # makedirs creates missing parents and, with exist_ok, tolerates a
    # directory that already exists or is created concurrently.
    os.makedirs(out_dir, exist_ok=True)

    suffix = "_" + str(cont_index) + ".png"

    def _to_image(tensor):
        # Detach from autograd, move to CPU and drop singleton dims.
        return torch.squeeze(tensor.cpu().detach())

    for idx, pred in enumerate(prediction):
        prefix = os.path.join(out_dir, str(idx + batch_id * batch_size))
        torchvision.utils.save_image(_to_image(init[idx]), prefix + "_init" + suffix)
        torchvision.utils.save_image(_to_image(sub_prediction[idx][30]), prefix + "_sub_prediction_1.png")
        torchvision.utils.save_image(_to_image(sub_prediction[idx][31]), prefix + "_sub_prediction_2.png")
        torchvision.utils.save_image(_to_image(pred), prefix + "_prediction" + suffix)
        torchvision.utils.save_image(_to_image(label[idx]), prefix + "_label" + suffix)
def visualize_sum_training_result(init, prediction, sub_prediction, label, batch_id, epoch, batch_size,
                                  path="/home/share_uav/zzh/data/uav_regression/training_result"):
    """Save the training-time images of one batch as PNG files.

    For each sample ``idx`` in the batch, four images are written into
    ``<path>/epoch_<epoch>/sum/`` under the prefix
    ``n = idx + batch_id * batch_size``:

    * ``<n>_init.png``           -- ``init[idx]``
    * ``<n>_sub_prediction.png`` -- ``sub_prediction[idx][0][1]``
    * ``<n>_prediction.png``     -- ``prediction[idx]``
    * ``<n>_label.png``          -- ``label[idx]``

    Each tensor is moved to CPU, detached and squeezed before being passed to
    ``torchvision.utils.save_image``.

    Args:
        init: Batched tensor, indexed along dim 0.
        prediction: Batched tensor, indexed along dim 0. Its length sets how
            many samples are saved.
        sub_prediction: Batched tensor. Only ``[idx][0][1]`` is saved.
        label: Batched tensor, indexed along dim 0.
        batch_id: Index of this batch, used to compute global sample numbers.
        epoch: Epoch number, used in the output directory name.
        batch_size: Batch size, used to compute global sample numbers.
        path: Root output directory. It defaults to the location that was
            previously hard-coded, so existing callers are unaffected. It and
            its ``epoch_<epoch>/sum`` sub-directories are created (with any
            missing parents) if needed.

    Raises:
        AssertionError: If ``prediction`` and ``label`` differ in dim 0.
    """
    assert prediction.shape[0] == label.shape[0], "prediction size and label size is not identical"

    out_dir = os.path.join(path, "epoch_" + str(epoch), "sum")
    # makedirs also creates the root itself when missing; a plain os.mkdir
    # on the epoch directory would raise FileNotFoundError in that case.
    os.makedirs(out_dir, exist_ok=True)

    def _to_image(tensor):
        # Detach from autograd, move to CPU and drop singleton dims.
        return torch.squeeze(tensor.cpu().detach())

    for idx, pred in enumerate(prediction):
        prefix = os.path.join(out_dir, str(idx + batch_id * batch_size))
        torchvision.utils.save_image(_to_image(init[idx]), prefix + "_init.png")
        torchvision.utils.save_image(_to_image(sub_prediction[idx][0][1]), prefix + "_sub_prediction.png")
        torchvision.utils.save_image(_to_image(pred), prefix + "_prediction.png")
        torchvision.utils.save_image(_to_image(label[idx]), prefix + "_label.png")