-
Notifications
You must be signed in to change notification settings - Fork 14
/
utils.py
145 lines (110 loc) · 3.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import sys
import errno
import torch
import random
import numbers
from sklearn.metrics import f1_score, recall_score, average_precision_score, roc_auc_score
from sklearn.preprocessing import label_binarize
from sklearn.utils.multiclass import unique_labels
from torchvision.transforms import functional as F
def accuracy(output, target):
    """Return the fraction of correct argmax predictions and the predictions.

    Args:
        output: per-sample class scores; the predicted class is the argmax
            over dimension 1 (presumably shaped (batch, num_classes) -- confirm).
        target: ground-truth labels; its size along dimension 0 is the batch size.

    Returns:
        (acc, pred): ``acc`` is a scalar tensor equal to correct / batch_size;
        ``pred`` holds the argmax class index for every sample.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        predictions = output.argmax(dim=1)
        hits = predictions.eq(target).float()
        return hits.sum().mul_(1.0 / n_samples), predictions
def calc_metrics(y_pred, y_true, y_scores):
    """Compute macro-averaged classification metrics over accumulated batches.

    Args:
        y_pred: sequence of tensors with predicted labels (concatenated).
        y_true: sequence of tensors with ground-truth labels (concatenated).
        y_scores: sequence of tensors with per-sample scores (concatenated).
            When 2-D, column ``k`` is assumed to be the score of label ``k``
            (consistent with ``argmax(dim=1)`` predictions as produced by
            ``accuracy``) -- TODO confirm with callers.

    Returns:
        dict with
          - ``'rec'``: macro recall,
          - ``'f1'``: F1 averaged over every label seen in ``y_true`` or
            ``y_pred`` (labels never predicted count as 0),
          - ``'aucpr'`` / ``'aucroc'``: macro one-vs-rest average precision
            and ROC AUC over the labels present in ``y_true``.

    Raises:
        ValueError: sklearn raises when ROC AUC is undefined, e.g. when only
            one label occurs in ``y_true``.
    """
    metrics = {}
    y_pred = torch.cat(y_pred).cpu().numpy()
    y_true = torch.cat(y_true).cpu().numpy()
    y_scores = torch.cat(y_scores).cpu().numpy()
    classes = unique_labels(y_true, y_pred)

    # recall score
    metrics['rec'] = recall_score(y_true, y_pred, average='macro')

    # f1 score: evaluate only labels that were actually predicted (avoids
    # sklearn's undefined-precision warning), then divide by the number of
    # all seen labels so never-predicted labels contribute an F1 of 0.
    f1_scores = f1_score(y_true, y_pred, average=None, labels=unique_labels(y_pred))
    metrics['f1'] = f1_scores.sum() / classes.shape[0]

    # One-vs-rest ranking metrics are only defined for labels with positive
    # samples, so binarize against the labels present in y_true and keep the
    # matching score columns; otherwise Y and y_scores disagree in width as
    # soon as some class is absent from the data.
    true_classes = unique_labels(y_true).astype(int)
    Y = label_binarize(y_true, classes=true_classes.tolist())
    scores = y_scores
    if scores.ndim == 2:
        scores = scores[:, true_classes]
        if Y.shape[1] == 1:
            # Two labels: label_binarize yields one column flagging the larger
            # label, and sklearn's binary metrics expect that label's scores
            # as a 1-D array.
            Y = Y.ravel()
            scores = scores[:, -1]

    # AUC PR
    metrics['aucpr'] = average_precision_score(Y, scores, average='macro')
    # AUC ROC
    metrics['aucroc'] = roc_auc_score(Y, scores, average='macro')
    return metrics
class ProgressMeter(object):
    """Print a tab-separated progress line: prefix, batch counter, meters.

    The counter looks like ``[  7/100]``: the current batch is right-aligned
    to the number of digits of ``num_batches``. Each meter is rendered with
    ``str()``.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the progress line for ``batch`` to stdout."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build the ``'[{:Nd}/<total>]'`` template, N = digit count of the total."""
        width = len(str(num_batches // 1))
        counter = f'{{:{width}d}}'
        return f'[{counter}/{counter.format(num_batches)}]'
class AverageMeter(object):
    """Track the most recent value and a running, count-weighted average."""

    def __init__(self):
        # Declare every attribute up front; reset() gives them their start values.
        self.val = self.avg = self.sum = self.count = None
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        # Augmented assignment on purpose: once the accumulators hold tensors,
        # these update them in place.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class Logger(object):
    """Mirror everything written to it onto the console and an optional file.

    Exposes ``write``/``flush`` like a text stream, presumably so an instance
    can be installed as ``sys.stdout`` (confirm with callers). ``console`` is
    a class attribute bound to ``sys.stdout`` when the class is defined, so it
    keeps referring to that original stream even if ``sys.stdout`` is
    reassigned afterwards.
    """

    console = sys.stdout

    def __init__(self, fpath=None):
        """Optionally open ``fpath`` (truncating it) as the mirror file.

        Args:
            fpath: log file path, or ``None`` to write to the console only.
                A missing parent directory is created first.
        """
        self.file = None
        if fpath is not None:
            parent = os.path.dirname(fpath)
            # A bare file name has an empty dirname (the current directory);
            # only create a parent when there actually is one.
            if parent:
                mkdir_if_missing(parent)
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(path) as log:`` binds the instance.
        return self

    def __exit__(self, *args):
        # Close the file; returning None lets any exception propagate.
        self.close()

    def write(self, msg):
        """Write ``msg`` to the console and, if one is open, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console and the log file, then fsync the file to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file if one was opened; the console stays open."""
        if self.file is not None:
            self.file.close()
def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    An empty string -- what ``os.path.dirname`` returns for a bare file
    name -- denotes the current directory and is treated as present. If the
    path already exists (as anything) nothing is done. An "already exists"
    error raised by ``makedirs`` itself (e.g. the directory appeared after
    the existence check) is ignored.

    Raises:
        OSError: if creation fails for any other reason.
    """
    if not directory or os.path.exists(directory):
        return
    try:
        os.makedirs(directory)
    except FileExistsError:
        # EEXIST: something created the path between exists() and makedirs().
        pass
class RandomFiveCrop(object):
    """Transform that returns one of the five ``five_crop`` crops, chosen at random.

    ``size`` is either a single number (square crop) or a two-element
    ``(h, w)`` sequence, stored as given.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (int(size), int(size))
        else:
            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
        self.size = size

    def __call__(self, img):
        # five_crop builds all five crops; one is selected uniformly at random.
        crops = F.five_crop(img, self.size)
        return crops[random.randint(0, 4)]

    def __repr__(self):
        return f'{type(self).__name__}(size={self.size})'