-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
53 lines (44 loc) · 1.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import torch
lr_base = 0.0003


def adjust_learning_rate(optimizer, epoch, base_lr=None, decay=0.8, step=30):
    """Apply a step-decay learning-rate schedule to every param group.

    The learning rate is ``base_lr * decay ** (epoch // step)``, i.e. it is
    multiplied by ``decay`` once every ``step`` epochs. The defaults
    reproduce the original fixed schedule (x0.8 every 30 epochs starting
    from ``lr_base``).

    Args:
        optimizer: object exposing ``param_groups``, an iterable of dicts
            with an ``'lr'`` key (e.g. a ``torch.optim.Optimizer``).
        epoch: current epoch index.
        base_lr: starting learning rate. ``None`` means "use the module-level
            ``lr_base``", read at call time so rebinding ``lr_base`` still
            takes effect.
        decay: multiplicative factor applied every ``step`` epochs.
        step: number of epochs between decays; must be positive.

    Returns:
        The learning rate written into every param group.

    Raises:
        ValueError: if ``step`` is not positive.
    """
    if step <= 0:
        raise ValueError("step must be positive, got %r" % (step,))
    if base_lr is None:
        base_lr = lr_base
    lr = base_lr * (decay ** (epoch // step))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr
def batch_pix_accuracy(output, target):
    """Count labeled pixels and correctly predicted labeled pixels in a batch.

    Args:
        output: score tensor; the predicted class of each pixel is the index
            of the maximum along dim 1 (the original docs describe it as 4D,
            presumably (N, C, H, W) -- confirm against callers).
        target: integer label tensor comparable element-wise with the
            argmax of ``output`` (documented as 3D, presumably (N, H, W)).
            Negative labels such as -1 are treated as unlabeled.

    Returns:
        Tuple ``(pixel_correct, pixel_labeled)`` of NumPy integer counts.
    """
    pred_labels = torch.max(output, 1)[1]
    # Shift both by one so every negative label ends up <= 0 and is
    # excluded by the "> 0" mask below.
    pred_np = pred_labels.cpu().numpy().astype('int64') + 1
    target_np = target.cpu().numpy().astype('int64') + 1

    labeled_mask = target_np > 0
    pixel_labeled = np.sum(labeled_mask)
    pixel_correct = np.sum((pred_np == target_np) & labeled_mask)

    assert pixel_correct <= pixel_labeled, \
        "Correct area should be smaller than Labeled"
    return pixel_correct, pixel_labeled
def batch_intersection_union(output, target, nclass):
    """Compute per-class intersection and union pixel areas for a batch.

    Args:
        output: score tensor; the predicted class of each pixel is the index
            of the maximum along dim 1 (documented as 4D, presumably
            (N, C, H, W) -- confirm against callers).
        target: integer label tensor comparable element-wise with the
            argmax of ``output`` (documented as 3D, presumably (N, H, W)).
            Negative labels such as -1 are treated as unlabeled.
        nclass: number of categories (int).

    Returns:
        Tuple ``(area_inter, area_union)`` of NumPy integer arrays of length
        ``nclass``; entry k covers class k.
    """
    def _class_areas(labels):
        # Labels are shifted to 1..nclass. With nclass equal-width bins over
        # [1, nclass], each integer label lands in its own bin. Values outside
        # the range are ignored, including the 0 used for masked pixels.
        counts, _ = np.histogram(labels, bins=nclass, range=(1, nclass))
        return counts

    pred_labels = torch.max(output, 1)[1]
    pred_np = pred_labels.cpu().numpy().astype('int64') + 1
    target_np = target.cpu().numpy().astype('int64') + 1

    # Zero out predictions on unlabeled pixels so they drop out of the
    # histograms.
    pred_np = pred_np * (target_np > 0).astype(pred_np.dtype)
    # Keep the predicted label only where it matches the ground truth.
    matched = pred_np * (pred_np == target_np)

    # NOTE(review): target labels >= nclass fall outside the histogram range
    # and are dropped from area_lab, but the predictions on those pixels are
    # not masked and still count in area_pred -- confirm this is intended.
    area_inter = _class_areas(matched)
    area_pred = _class_areas(pred_np)
    area_lab = _class_areas(target_np)
    area_union = area_pred + area_lab - area_inter

    assert (area_inter <= area_union).all(), \
        "Intersection area should be smaller than Union area"
    return area_inter, area_union