forked from SawanKumar28/alc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmisc_utils.py
26 lines (23 loc) · 857 Bytes
/
misc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from collections import defaultdict
import numpy as np
import ipdb as pdb
#Following https://github.com/jzbjyb/lm-calibration/blob/master/cal.py
def compute_ece(eval_results, conf_values, num_bins=20):
    """Compute the Expected Calibration Error (ECE) of a set of predictions.

    Confidences are grouped into ``num_bins`` equal-width bins over [0, 1]:
    bin ``k`` covers ``[k/num_bins, (k+1)/num_bins)``, and the last bin also
    absorbs confidences >= 1. ECE is the count-weighted average, over the
    non-empty bins, of ``|mean confidence - mean accuracy|``.

    Args:
        eval_results: per-example correctness values (e.g. 0/1 or bools),
            aligned element-wise with ``conf_values``.
        conf_values: per-example confidence scores. Negative values are
            clipped to 0 before binning and averaging.
        num_bins: number of equal-width bins; must be >= 1.

    Returns:
        The ECE as a float (0.0 when there are no examples).

    Raises:
        ValueError: if ``num_bins`` < 1 or the two inputs differ in length.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")
    if len(eval_results) != len(conf_values):
        raise ValueError(
            "eval_results and conf_values must have the same length, "
            f"got {len(eval_results)} and {len(conf_values)}")
    if len(conf_values) == 0:
        # Nothing to measure; avoid dividing by a zero example count.
        return 0.0

    bins = defaultdict(list)
    for result, raw_conf in zip(eval_results, conf_values):
        conf = max(0, raw_conf)
        # Multiply instead of dividing by a precomputed 1.0/num_bins: the
        # extra rounding there sends boundary values to the lower bin
        # (e.g. 0.15 / 0.05 == 2.9999999999999996).
        bin_index = min(int(conf * num_bins), num_bins - 1)
        bins[bin_index].append((conf, result))

    ece = 0.0
    for vals in bins.values():
        mean_conf = np.mean([conf for conf, _ in vals])
        mean_acc = np.mean([result for _, result in vals])
        ece += len(vals) * np.abs(mean_conf - mean_acc)
    # Every example lands in exactly one bin, so the total count is the input length.
    return ece / len(conf_values)