-
Notifications
You must be signed in to change notification settings - Fork 16
/
utils.py
357 lines (295 loc) · 12.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
"""
Licence : AIT JEDDI Yassine
Objective : compute a confusion matrix for the whole test dataset
Reference : https://github.com/matterport/Mask_RCNN/
"""
"""
Note : copy this code into your original utils.py file.
"""
from pandas import DataFrame
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from matplotlib.collections import QuadMesh
import seaborn as sn
from sklearn.metrics import confusion_matrix
from pandas import DataFrame
from string import ascii_uppercase
#function 1 to be added to your utils.py
def get_iou(a, b, epsilon=1e-5):
    """
    Compute the Intersection over Union (IoU) score of two axis-aligned boxes.

    Each box is an indexable sequence of four numbers [x1, y1, x2, y2],
    where (x1, y1) is the upper-left corner and (x2, y2) is the lower-right
    corner.

    Args:
        a: first box, [x1, y1, x2, y2].
        b: second box, [x1, y1, x2, y2].
        epsilon: small value added to the denominator to prevent a
            division by zero.

    Returns:
        0.0 when the boxes do not overlap, otherwise
        overlap_area / (area_a + area_b - overlap_area + epsilon).
    """
    # Extent of the intersection rectangle along each axis.
    overlap_w = min(a[2], b[2]) - max(a[0], b[0])
    overlap_h = min(a[3], b[3]) - max(a[1], b[1])

    # A negative extent on either axis means the boxes are disjoint.
    if overlap_w < 0 or overlap_h < 0:
        return 0.0

    intersection = overlap_w * overlap_h
    box_a_area = (a[2] - a[0]) * (a[3] - a[1])
    box_b_area = (b[2] - b[0]) * (b[3] - b[1])
    union = box_a_area + box_b_area - intersection

    return intersection / (union + epsilon)
#function 2 to be added to your utils.py
def gt_pred_lists(gt_class_ids, gt_bboxes, pred_class_ids, pred_bboxes, iou_tresh = 0.5):
    """
    Pair ground-truth objects with predicted objects by box overlap and
    return two aligned class lists suitable for building a confusion matrix.

    Every (gt, prediction) pair whose IoU is at least `iou_tresh` contributes
    one entry to both lists, whether or not the two classes agree; a gt box
    overlapping several predictions therefore contributes several entries.
    Objects left without any partner are paired with the background class
    (index 0).

    Args:
        gt_class_ids: ground-truth classes, size N1.
        gt_bboxes: ground-truth boxes, [N1, (x1, y1, x2, y2)].
        pred_class_ids: predicted classes, size N2.
        pred_bboxes: predicted boxes, [N2, (x1, y1, x2, y2)].
        iou_tresh: minimum IoU for a gt box and a predicted box to be
            considered the same object.

    Returns:
        gt: list of ground-truth classes, size N.
        pred: list of predicted classes, size N (aligned with `gt`).
    """
    # Per-object flags: True once the object has been associated with a partner.
    gt_matched = [False] * len(gt_class_ids)
    pred_matched = [False] * len(pred_class_ids)
    gt_classes = list(gt_class_ids)
    pred_classes = list(pred_class_ids)

    # The two lists to be returned.
    gt = []
    pred = []

    # Record every overlapping (gt, prediction) pair.
    for gt_idx, gt_cls in enumerate(gt_classes):
        for pred_idx, pred_cls in enumerate(pred_classes):
            if get_iou(gt_bboxes[gt_idx], pred_bboxes[pred_idx]) >= iou_tresh:
                gt_matched[gt_idx] = True
                pred_matched[pred_idx] = True
                gt.append(gt_cls)
                pred.append(pred_cls)

    # Ground-truth objects nobody predicted: predicted class is background.
    for gt_cls, matched in zip(gt_classes, gt_matched):
        if not matched:
            gt.append(gt_cls)
            pred.append(0)

    # Predictions that match no ground-truth object: true class is background.
    for pred_cls, matched in zip(pred_classes, pred_matched):
        if not matched:
            gt.append(0)
            pred.append(pred_cls)

    return gt, pred
######### Print confusion matrix for the whole dataset and return tp,fp and fn ##########
######### The style of this confusion matrix is inspired from https://github.com/wcipriano/pretty-print-confusion-matrix ##########
def get_new_fig(fn, figsize=[9,9]):
    """
    Initialise the graphics: fetch (or create) the figure identified by `fn`
    with the requested size and hand back that figure together with its
    current axes, cleared of any existing plot.
    """
    figure = plt.figure(num=fn, figsize=figsize)
    axes = figure.gca()
    # Wipe whatever was drawn on these axes by a previous call.
    axes.cla()
    return figure, axes
#
def configcell_text_and_colors(array_df, lin, col, oText, facecolors, posi, fz, fmt, show_null_values=0):
    """
    Configure the text and colours of a single confusion-matrix cell.

    Regular cells have their existing text object relabelled in place
    (count plus percentage of the grand total). Cells in the last row or
    last column (the totals) are scheduled for replacement by three stacked
    labels: the count, the percentage of correct predictions and the
    percentage of errors.

    Args:
        array_df: 2-D array (numpy-style indexing) holding the matrix
            including its totals row and column; the code treats the number
            of rows as the index bound for both axes, i.e. it assumes a
            square matrix.
        lin: row index of the cell.
        col: column index of the cell.
        oText: text object of the cell (set_text / set_color are used, and
            the _x / _y attributes for totals cells).
        facecolors: indexable collection of cell colours; entry `posi` is
            overwritten for totals cells and diagonal cells.
        posi: position of this cell inside `facecolors`.
        fz: font size for the labels of totals cells.
        fmt: currently unused (see the original TODO: use fmt).
        show_null_values: how empty regular cells are labelled:
            0 -> empty string, 1 -> '0', anything else -> '0' and '0.0%'.

    Returns:
        (text_add, text_del): a list of dicts (keys x, y, text, kw)
        describing texts to create, and a list of text objects to remove.
        Both are empty for regular cells.
    """
    value = array_df[lin][col]
    grand_total = array_df[-1][-1]
    share = (float(value) / grand_total) * 100
    last = len(array_df[:, col]) - 1
    in_last_col = col == last
    in_last_row = lin == last

    # ---- regular cell: relabel the existing text object in place ----
    if not (in_last_col or in_last_row):
        if share > 0:
            label = '%s\n%.2f%%' % (value, share)
        elif show_null_values == 0:
            label = ''
        elif show_null_values == 1:
            label = '0'
        else:
            label = '0\n0.0%'
        oText.set_text(label)

        if col == lin:
            # main diagonal: white text on a dedicated background colour
            oText.set_color('w')
            facecolors[posi] = [0.35, 0.8, 0.55, 1.0]
        else:
            oText.set_color('r')
        return [], []

    # ---- totals cell (last row and/or last column) ----
    is_corner = in_last_col and in_last_row
    if value != 0:
        if is_corner:
            # overall total: correct predictions are the whole main diagonal
            correct = 0
            for k in range(array_df.shape[0] - 1):
                correct += array_df[k][k]
        elif in_last_col:
            correct = array_df[lin][lin]
        else:
            correct = array_df[col][col]
        pct_ok = (float(correct) / value) * 100
        pct_err = 100 - pct_ok
    else:
        pct_ok = pct_err = 0

    pct_ok_label = '100%' if pct_ok == 100 else '%.2f%%' % (pct_ok)

    # the original text object is dropped ...
    text_del = [oText]

    # ... and replaced by three stacked labels: count / % ok / % error
    bold_font = fm.FontProperties(weight='bold', size=fz)
    base_kw = dict(color='w', ha="center", va="center", gid='sum', fontproperties=bold_font)
    labels = ['%d' % (value), pct_ok_label, '%.2f%%' % (pct_err)]
    styles = [base_kw, dict(base_kw, color='g'), dict(base_kw, color='r')]
    spots = [(oText._x, oText._y - 0.3), (oText._x, oText._y), (oText._x, oText._y + 0.3)]
    text_add = [
        dict(x=spot[0], y=spot[1], text=label, kw=style)
        for spot, label, style in zip(spots, labels, styles)
    ]

    # background colour of totals cells; the bottom-right corner is darker
    if is_corner:
        facecolors[posi] = [0.17, 0.20, 0.17, 1.0]
    else:
        facecolors[posi] = [0.27, 0.30, 0.27, 1.0]

    return text_add, text_del
#
def insert_totals(df_cm):
    """
    Append a totals column ('sum_lin') and a totals row ('sum_col') to
    `df_cm`, as its last column and last row.

    The frame is modified in place and nothing is returned. The bottom-right
    cell holds the grand total (the sum of all row totals).
    """
    # Both sets of totals are taken before the frame is touched.
    column_totals = [df_cm[name].sum() for name in df_cm.columns]
    row_totals = [row.sum() for _, row in df_cm.iterrows()]

    df_cm['sum_lin'] = row_totals
    # The new row needs one extra entry for the freshly added 'sum_lin' column.
    column_totals.append(np.sum(row_totals))
    df_cm.loc['sum_col'] = column_totals
#
def pretty_plot_confusion_matrix(df_cm, annot=True, cmap="Oranges", fmt='.2f', fz=11,
      lw=0.5, cbar=False, figsize=[8,8], show_null_values=0, pred_val_axis='y'):
    """
    Draw a confusion matrix with a default layout (similar to matlab) and
    show it with plt.show().

    A totals row and column are added to the matrix before plotting. The
    DataFrame passed by the caller is left untouched.

    params:
      df_cm            dataframe (pandas) without totals
      annot            print text in each cell
      cmap             matplotlib colormap name: Oranges, Oranges_r, YlGnBu,
                       Blues, RdBu, ...
      fmt              number format handed to seaborn for the annotations
      fz               fontsize of the cell annotations
      lw               linewidth of the lines separating the cells
      cbar             whether seaborn draws a colour bar
      figsize          figure size handed to get_new_fig
      show_null_values how empty cells are labelled
                       (see configcell_text_and_colors)
      pred_val_axis    where to show the prediction values (x or y axis)
                        'col' or 'x': show predicted values in columns (x axis) instead of lines
                        any other value (e.g. 'lin' or 'y'): show predicted
                        values in lines (y axis); the matrix is transposed
    """
    if pred_val_axis in ('col', 'x'):
        xlbl = 'Predicted'
        ylbl = 'Actual'
        # insert_totals() below modifies its argument in place; work on a
        # copy so the caller's frame does not silently gain a totals row
        # and column.
        df_cm = df_cm.copy()
    else:
        xlbl = 'Actual'
        ylbl = 'Predicted'
        df_cm = df_cm.T

    # create "Total" column and line
    insert_totals(df_cm)

    # always draw in the same window
    _, ax1 = get_new_fig('Conf matrix default', figsize)

    # thanks for seaborn
    sn.set(font_scale=1.8)
    ax = sn.heatmap(df_cm, annot=annot, annot_kws={"size": fz}, linewidths=lw, ax=ax1,
                    cbar=cbar, cmap=cmap, linecolor='w', fmt=fmt)

    # set ticklabels rotation
    ax.set_xticklabels(ax.get_xticklabels(), rotation=75, fontsize=26)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=25, fontsize=26)

    # Turn off all the tick marks (labels stay visible). The former
    # per-tick `tick1On` / `tick2On` attributes no longer exist in current
    # Matplotlib, where assigning to them has no effect.
    ax.tick_params(axis='both', which='major',
                   bottom=False, top=False, left=False, right=False)

    # face colors list
    quadmesh = ax.findobj(QuadMesh)[0]
    facecolors = quadmesh.get_facecolors()

    # iterate over the annotation texts, from left to right, row by row
    array_df = np.array(df_cm.to_records(index=False).tolist())
    text_add = []
    text_del = []
    for posi, t in enumerate(list(ax.texts)):
        # annotations sit at the cell centres, hence the 0.5 offset
        pos = np.array(t.get_position()) - [0.5, 0.5]
        lin = int(pos[1])
        col = int(pos[0])

        # set text
        added, removed = configcell_text_and_colors(array_df, lin, col, t, facecolors, posi, fz, fmt, show_null_values)
        text_add.extend(added)
        text_del.extend(removed)

    # remove the old ones
    for item in text_del:
        item.remove()
    # append the new ones
    for item in text_add:
        ax.text(item['x'], item['y'], item['text'], **item['kw'])

    # titles and legends
    ax.set_title('Confusion matrix')
    ax.set_xlabel(xlbl)
    ax.set_ylabel(ylbl)
    plt.tight_layout()  # set layout slim
    plt.show()
#
def plot_confusion_matrix_from_data(y_test, predictions, columns=None, annot=True, cmap="Oranges",
      fmt='.2f', fz=11, lw=0.5, cbar=False, figsize=[36,36], show_null_values=0, pred_val_axis='lin'):
    """
    Build the confusion matrix of `y_test` (actual values) against
    `predictions`, plot it, and return the per-class counts.

    Args:
        y_test: actual class of each sample.
        predictions: predicted class of each sample (same length).
        columns: display name of each class, one per row of the confusion
            matrix. When missing or empty, names are generated:
            'class A', 'class B', ... for up to 26 classes, and
            'class 0', 'class 1', ... beyond that.
        annot, cmap, fmt, fz, lw, cbar, figsize, show_null_values,
        pred_val_axis: forwarded unchanged to pretty_plot_confusion_matrix.

    Returns:
        (tp, fp, fn): three lists with one entry per class, in matrix order:
        tp is the diagonal, fp the column sum minus the diagonal,
        fn the row sum minus the diagonal.
    """
    y_test = np.array(y_test)
    predictions = np.array(predictions)

    # confusion matrix (rows: actual classes, columns: predicted classes)
    confm = confusion_matrix(y_test, predictions)
    num_classes = confm.shape[0]

    # Default names are sized from the matrix itself: it covers every label
    # seen in either input, which can be more than the number of distinct
    # labels in y_test or in predictions taken separately.
    # `is None` / len() also keeps numpy arrays of names usable here.
    if columns is None or len(columns) == 0:
        if num_classes <= len(ascii_uppercase):
            columns = ['class %s' % (letter) for letter in ascii_uppercase[:num_classes]]
        else:
            # not enough letters: fall back to numbered names
            columns = ['class %s' % (k) for k in range(num_classes)]

    # compute tp fp fn
    hits = np.diag(confm)
    tp = hits.tolist()
    fp = (confm.sum(axis=0) - hits).tolist()
    fn = (confm.sum(axis=1) - hits).tolist()

    # plot
    df_cm = DataFrame(confm, index=columns, columns=columns)
    pretty_plot_confusion_matrix(df_cm, annot=annot, cmap=cmap, fmt=fmt, fz=fz, lw=lw, cbar=cbar,
                                 figsize=figsize, show_null_values=show_null_values,
                                 pred_val_axis=pred_val_axis)
    return tp, fp, fn