forked from huaaaliu/RGBX_Semantic_Segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
120 lines (102 loc) · 4.76 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import cv2
import argparse
import numpy as np
import torch
import torch.nn as nn
# from config import config
from utils.config_utils import get_config_by_file
from utils.pyt_utils import ensure_dir, link_file, load_model, parse_devices
from utils.visualize import print_iou, show_img, get_class_colors
from engine.evaluator import Evaluator
from engine.logger import get_logger
from utils.metric import hist_info, compute_score
from dataloader.RGBXDataset import RGBXDataset
from models.builder import EncoderDecoder as segmodel
from dataloader.dataloader import ValPre
from PIL import Image
# Module-level logger obtained from engine.logger; SegEvaluator uses it to
# report each prediction image it saves.
logger = get_logger()
class SegEvaluator(Evaluator):
    """Semantic-segmentation evaluator for paired RGB + extra-modality inputs.

    NOTE: both methods read a module-level ``config`` object (crop size,
    stride rate, class count, background index). In this script that global
    is created by the ``__main__`` block, so it must exist before the
    evaluator runs.
    """

    def func_per_iteration(self, data, device):
        """Predict one sample with sliding-window inference and collect statistics.

        Args:
            data: sample dict providing ``'data'`` (RGB input), ``'label'``
                (ground truth), ``'modal_x'`` (the extra modality) and
                ``'fn'`` (file-name stem used when saving predictions).
            device: forwarded unchanged to ``self.sliding_eval_rgbX``.

        Returns:
            dict with the ``'hist'``, ``'labeled'`` and ``'correct'`` values
            returned by ``hist_info`` for this sample.
        """
        img = data['data']
        label = data['label']
        modal_x = data['modal_x']
        name = data['fn']
        pred = self.sliding_eval_rgbX(img, modal_x, config.eval_crop_size,
                                      config.eval_stride_rate, device)
        hist_tmp, labeled_tmp, correct_tmp = hist_info(config.num_classes, pred, label)
        results_dict = {'hist': hist_tmp, 'labeled': labeled_tmp, 'correct': correct_tmp}

        if self.save_path is not None:
            ensure_dir(self.save_path)
            ensure_dir(self.save_path + '_color')
            fn = name + '.png'
            # Both the palettised PNG and cv2.imwrite need an 8-bit image;
            # cast once so the raw and colored outputs hold identical indices.
            pred_u8 = pred.astype(np.uint8)

            # save colored result as a 'P' (palette) image
            result_img = Image.fromarray(pred_u8, mode='P')
            class_colors = get_class_colors(config.num_classes)
            palette_list = [int(v) for v in np.array(class_colors).flat]
            # A PIL palette holds at most 256 RGB triplets (768 values):
            # pad short palettes with black and drop any excess entries.
            palette_list = (palette_list + [0] * 768)[:768]
            result_img.putpalette(palette_list)
            result_img.save(os.path.join(self.save_path + '_color', fn))

            # save raw class-index result
            cv2.imwrite(os.path.join(self.save_path, fn), pred_u8)
            logger.info('Save the image ' + fn)

        if self.show_image:
            colors = self.dataset.get_class_colors
            # The original passed this attribute straight through; if the
            # dataset exposes it as a method, call it to get the palette itself.
            if callable(colors):
                colors = colors()
            image = img
            clean = np.zeros(label.shape)
            comp_img = show_img(colors, config.background, image, clean,
                                label,
                                pred)
            cv2.imshow('comp_image', comp_img)
            cv2.waitKey(0)
        return results_dict

    def compute_metric(self, results):
        """Aggregate per-sample statistics and build the IoU report.

        Args:
            results: iterable of dicts as returned by ``func_per_iteration``.

        Returns:
            The result string produced by ``print_iou``.
        """
        hist = np.zeros((config.num_classes, config.num_classes))
        correct = 0
        labeled = 0
        for d in results:
            hist += d['hist']
            correct += d['correct']
            labeled += d['labeled']
        iou, _mean_iou, _, freq_IoU, mean_pixel_acc, pixel_acc = compute_score(
            hist, correct, labeled)
        # Read class names from the evaluator's own dataset rather than a
        # script-level global, so the class also works when imported.
        result_line = print_iou(iou, freq_IoU, mean_pixel_acc, pixel_acc,
                                self.dataset.class_names, show_no_back=False)
        return result_line
if __name__ == "__main__":
    # Command-line entry point: build the model and validation dataset from a
    # config file, then evaluate the selected checkpoint(s).
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--config_file", default=None, type=str,
                        help="plz input your experiment description file",)
    parser.add_argument('-e', '--epochs', default='last', type=str)
    parser.add_argument('-d', '--devices', default='0', type=str)
    parser.add_argument('-v', '--verbose', default=False, action='store_true')
    parser.add_argument('--show_image', '-s', default=False,
                        action='store_true')
    parser.add_argument('--save_path', '-p', default=None)
    args = parser.parse_args()
    all_dev = parse_devices(args.devices)

    # Assigned at module level on purpose: SegEvaluator reads `config` as a
    # global, so this must not be moved into a function.
    config = get_config_by_file(args.config_file)
    network = segmodel(cfg=config, criterion=None, norm_layer=nn.BatchNorm2d)

    # Each key appears exactly once (a repeated key in a dict literal
    # silently overrides the earlier entry).
    data_setting = {'rgb_root': config.rgb_root_folder,
                    'rgb_format': config.rgb_format,
                    'gt_root': config.gt_root_folder,
                    'gt_format': config.gt_format,
                    'transform_gt': config.gt_transform,
                    'x_root': config.x_root_folder,
                    'x_format': config.x_format,
                    'x_single_channel': config.x_is_single_channel,
                    'class_names': config.class_names,
                    'train_source': config.train_source,
                    'eval_source': config.eval_source}
    val_pre = ValPre()
    dataset = RGBXDataset(data_setting, 'val', val_pre)

    # Inference only: disable autograd for the whole evaluation run.
    with torch.no_grad():
        segmentor = SegEvaluator(dataset, config.num_classes, config.norm_mean,
                                 config.norm_std, network,
                                 config.eval_scale_array, config.eval_flip,
                                 all_dev, args.verbose, args.save_path,
                                 args.show_image)
        segmentor.run(config.checkpoint_dir, args.epochs, config.val_log_file,
                      config.link_val_log_file)