-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathPlotPD.py
92 lines (75 loc) · 2.78 KB
/
PlotPD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Project imports stay first on purpose. The wildcard import supplies
# `device` (and used to be the only source of os/np/torch/F). Keeping it
# before the explicit imports below means those explicit bindings win on any
# name clash, exactly as before.
from datasetPDEval import PDDataset
from models.model_bg import *

import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib.patches import Polygon
# Model input resolution, presumably (height, width): the plotting code
# works at 120 x 242, i.e. a quarter of these values.
resolution = (480, 968)
# The defaults point at the original author's machine. Override them through
# the environment to run elsewhere without editing the script.
model_path = os.environ.get(
    'DIOD_MODEL_PATH',
    '/home/users/skara/check_release/checkpoints/DIODPD_500.ckpt')
data_path = os.environ.get('DIOD_DATA_PATH', '/home/data/skara/test_video')
test_set = PDDataset(split='test', root=data_path)
# Directory that receives the rendered overlay PNGs (created if missing).
output = os.environ.get('DIOD_OUTPUT_DIR', './infer')
os.makedirs(output, exist_ok=True)
# Build the slot-attention model and restore the trained weights.
# NOTE(review): the positional arguments (45, 64, 3) are presumably
# (num_slots, slot/hidden size, attention iterations). Confirm against
# models.model_bg.SlotAttentionAutoEncoder. 45 must equal the slot count
# used when post-processing the masks later in this script.
model = SlotAttentionAutoEncoder(resolution, 45, 64, 3).to(device)
# Wrap before loading. The checkpoint keys presumably carry the 'module.'
# prefix from training under DataParallel.
model = nn.DataParallel(model)
# map_location lets a checkpoint saved on a GPU load on a CPU-only host
# (or on a different GPU index).
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print('model load finished!')
# Inference only. eval() switches train/eval-dependent layers to eval
# behaviour and propagates through the DataParallel wrapper.
# requires_grad_(False) freezes every parameter, like the original loop
# over model.module.parameters().
model.eval()
model.requires_grad_(False)
# Plot the hard-assigned slot segments of a few test scenes over their
# frames, and save one PNG per (scene, frame) into `output`.
SCENES = [0, 1, 2]   # indices into test_set to visualise
NUM_FRAMES = 5       # leading frames of each scene fed to the model
NUM_SLOTS = 45       # must equal the slot count the model was built with
# Only slots 0 .. NUM_DRAWN_SLOTS-1 are drawn, so the last slot is skipped.
# NOTE(review): presumably the last slot is the background slot of
# models.model_bg. Confirm before changing.
NUM_DRAWN_SLOTS = 44
MASK_THRESHOLD = 0   # a pixel belongs to a slot when its mask value exceeds this

# One rainbow colour per slot. The palette does not depend on the scene, so
# build it once rather than inside the scene loop.
cmap = plt.get_cmap('rainbow')
colors = [cmap(ii) for ii in np.linspace(0, 1, NUM_SLOTS)]


def _largest_contour(mask):
    """Return the longest outline of ``mask > MASK_THRESHOLD``, or None.

    The outline is an (N, 2) array of (x, y) pixel coordinates as produced
    by OpenCV, which matches the data coordinates of ``ax.imshow``. Only the
    longest contour is kept, so a slot split into several regions is drawn
    as its longest piece. Returns None when no pixel exceeds the threshold.
    """
    binary = (mask > MASK_THRESHOLD).astype('uint8')
    # OpenCV 3 returns (image, contours, hierarchy), while OpenCV 4 returns
    # (contours, hierarchy). Taking [-2] works with both.
    contours = cv2.findContours(
        binary, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
    # max() keeps the first of several equally long contours.
    longest = max(contours, key=len, default=None)
    return None if longest is None else longest.reshape((-1, 2))


def _plot_frame(frame, frame_masks, path):
    """Overlay one frame's slot outlines on the frame and save it to ``path``.

    frame:       (C, h, w) tensor. NOTE(review): assumed normalised to
                 [-1, 1] (undone by ``* 0.5 + 0.5``) and stored BGR (the
                 channel axis is reversed for display). Confirm against
                 PDDataset.
    frame_masks: (num_slots, h, w) tensor of slot masks at the same (h, w).
    """
    picture = frame.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
    picture = picture[:, :, ::-1]
    slot_masks = frame_masks.cpu().numpy()

    # Borderless figure whose single axes covers the whole canvas.
    fig = plt.figure(frameon=False)
    try:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.axis('off')
        fig.add_axes(ax)
        ax.imshow(picture, alpha=1)
        for seg in range(NUM_DRAWN_SLOTS):
            outline = _largest_contour(slot_masks[seg])
            if outline is None:
                continue  # this slot owns no pixel in the frame
            ax.add_patch(Polygon(
                outline, fill=True, facecolor=colors[seg],
                edgecolor='w', linewidth=1.0, alpha=0.5))
        fig.savefig(path)
    finally:
        # Always release the figure so a failure mid-plot does not leak it.
        plt.close(fig)


for nb_scene in SCENES:
    sample = test_set[nb_scene]
    # Keep the first NUM_FRAMES frames and add a batch axis.
    image = sample['image'][:NUM_FRAMES].to(device).unsqueeze(0)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, masks, _, _ = model(image)
        # Hard assignment: keep each pixel's value only in the slot with the
        # highest mask score and zero it elsewhere. The argmax over dim 2 and
        # the permute below assume masks is (batch, frame, slot, h, w).
        winner = F.one_hot(masks.argmax(dim=2), num_classes=NUM_SLOTS)
        masks = masks * winner.permute(0, 1, 4, 2, 3)
        # image is 5-D, so interpolate resizes its last three axes. Keep the
        # channel axis unchanged and resize the frames to the mask resolution,
        # so contour coordinates (mask pixels) line up with displayed pixels.
        cur_image = F.interpolate(
            image, (image.shape[2],) + tuple(masks.shape[-2:]))
    masks = masks[0]
    cur_image = cur_image[0]
    # Iterate over the frames actually present: a scene may hold fewer than
    # NUM_FRAMES frames.
    for j in range(cur_image.shape[0]):
        _plot_frame(
            cur_image[j], masks[j],
            os.path.join(output, 'DIODPD_500_{}-{}.png'.format(nb_scene, j)))