-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdesignSeq.py
100 lines (88 loc) · 2.99 KB
/
designSeq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: tensor whose last dimension holds the four values
            ``(cx, cy, w, h)``; any leading dimensions are preserved.

    Returns:
        Tensor of the same leading shape whose last dimension is
        ``(cx - w/2, cy - h/2, cx + w/2, cy + h/2)``.
    """
    center_x, center_y, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=-1)
def box_iou(boxes1, boxes2):
    """Compute the pairwise intersection-over-union of two sets of boxes.

    Args:
        boxes1: tensor of N boxes; columns ``[:2]`` are read as the minimum
            corner and columns ``[2:]`` as the maximum corner.
        boxes2: tensor of M boxes in the same layout.

    Returns:
        NumPy array of shape [N, M] where entry ``[i, j]`` is the IoU of
        ``boxes1[i]`` and ``boxes2[j]``. Pairs whose union is zero yield NaN
        (0 / 0) rather than 0; the NumPy warning for that case is suppressed.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Disjoint boxes give a negative extent; clamp it to an empty overlap.
    wh = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
    union = area1[:, None] + area2 - inter

    # np.array(tensor) raises for tensors that require grad or live on another
    # device, so detach and move to host memory explicitly before converting.
    inter_np = inter.detach().cpu().numpy()
    union_np = union.detach().cpu().numpy()
    with np.errstate(divide='ignore', invalid='ignore'):
        iou = inter_np / union_np
    return iou
def reorder(cls, box, o="xyxy", max_elem=None):
    """Arrange layout elements into a sequence and return their indices.

    Logos (class 2) are placed first in index order, then texts (class 1)
    from largest to smallest area. A logo or text that overlaps a deco
    (class 3) is placed together with every element overlapping that deco,
    followed by the deco itself. Class 0 elements fill the tail when there
    is room left.

    Args:
        cls: per-element class ids; anything ``np.array`` accepts.
        box: box tensor with one row per element, in the format named by
            ``o``.
        o: ``"cxcywh"`` converts ``box`` with ``box_cxcywh_to_xyxy`` first;
            any other value uses ``box`` unchanged.
        max_elem: maximum number of indices to return; defaults to
            ``len(cls)``.

    Returns:
        List of Python ``int`` indices, at most ``min(len(cls), max_elem)``
        long. A deco is only placed through a logo or text attached to it or
        to a deco it overlaps, so the list may be shorter than ``len(cls)``.
    """
    if o == "cxcywh":
        box = box_cxcywh_to_xyxy(box)
    if max_elem is None:
        max_elem = len(cls)

    cls = np.array(cls)
    # Keep areas in their own list of plain numbers. Packing (index, area)
    # pairs into one NumPy array promotes the indices to float whenever the
    # areas are float, and float indices cannot index the IoU matrix.
    areas = box_area(box).tolist()
    iou = box_iou(box, box)

    def indices_of(label):
        # Ascending positions of every element carrying this class id.
        return [int(i) for i in np.where(cls == label)[0]]

    def area_of(i):
        return areas[i]

    text = indices_of(1)
    logo = indices_of(2)
    deco = indices_of(3)
    order_text = sorted(text, key=area_of, reverse=True)
    order_deco = sorted(deco, key=area_of)

    # deco connection
    owner = {}  # logo/text index -> the deco it is attached to
    group = {}  # deco index -> every element overlapping that deco
    for d in order_deco:
        members = []
        # Decos are visited smallest first, so when several decos overlap the
        # same logo/text the one visited last (the larger) ends up as owner.
        for i in logo + text:
            if iou[d, i]:
                owner[i] = d
                members.append(i)
        for other in deco:
            if other != d and iou[d, other]:
                members.append(other)
        group[d] = members

    # reorder
    order = []
    placed = set()

    def place(i):
        # Append each index at most once.
        if i not in placed:
            placed.add(i)
            order.append(i)

    def place_with_group(i):
        # Place a logo/text alone, or with its deco's whole group then the deco.
        d = owner.get(i)
        if d is None:
            place(i)
            return
        for member in group[d]:
            place(member)
        place(d)

    for i in logo:
        place_with_group(i)
    for i in order_text:
        # A group may overshoot the limit; the final slice trims the excess.
        if len(order) >= max_elem:
            break
        place_with_group(i)

    if len(order) < max_elem:
        order.extend(indices_of(0))
    return order[:min(len(cls), max_elem)]