-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path val_dst.py
184 lines (164 loc) · 6.9 KB
/
val_dst.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
import copy
from tqdm import tqdm
import numpy as np
import hdf5plugin # resolve a weird h5py error
import torch
from torch.backends import cuda, cudnn
cuda.matmul.allow_tf32 = True
cudnn.allow_tf32 = True
torch.multiprocessing.set_sharing_strategy('file_system')
import hydra
from omegaconf import DictConfig, OmegaConf
from config.modifier import dynamically_modify_train_config
from modules.utils.fetch import fetch_data_module
from modules.utils.ssod import evaluate_label, filter_w_thresh
from data.genx_utils.labels import ObjectLabels
from data.utils.types import DataType
from nerv.utils import AverageMeter, glob_all
# Placeholder label container with zero boxes (a 0 x 8 float32 tensor), used
# when a timestep has no pseudo label so that it still counts as an (empty)
# prediction.
# NOTE(review): input_size_hw=(360, 640) is hard-coded for every dataset,
# including 'gen1' -- presumably harmless because there are no boxes; confirm.
empty_box = ObjectLabels(
    torch.zeros((0, 8), dtype=torch.float32),
    input_size_hw=(360, 640),
)
def filter_bbox(pred, obj_thresh=0.9, cls_thresh=0.9, ignore_label=1024):
    """Drop low-confidence and ignored boxes from a label container.

    `pred` is an ObjectLabels-like object exposing the per-box attributes
    `objectness`, `class_confidence` and `class_id`, plus a writable
    `object_labels` tensor indexed by box along its first dimension.

    A box is kept only when all of the following hold:
      * `filter_w_thresh` accepts its objectness for `obj_thresh`;
      * `filter_w_thresh` accepts its class confidence for `cls_thresh`;
      * its class id is not `ignore_label`.

    The surviving rows are written back into `pred.object_labels` (in place),
    and the same `pred` object is returned for convenience.
    """
    objectness = pred.objectness
    class_conf = pred.class_confidence
    class_ids = pred.class_id

    passes_obj = filter_w_thresh(objectness, class_ids, obj_thresh)
    passes_cls = filter_w_thresh(class_conf, class_ids, cls_thresh)
    not_ignored = class_ids != ignore_label
    keep = passes_obj & passes_cls & not_ignored

    pred.object_labels = pred.object_labels[keep]
    return pred
@torch.inference_mode()
def eval_one_seq(full_cfg, pse_batch, batch, skip_gt=False):
    """Score one sequence's pseudo labels against its skipped GT labels.

    The two batches must hold the same event sequence: one is loaded from the
    pseudo-label dataset, the other from the original (sub-sampled) dataset.
    Only the first sample of each batch is used. At every timestep that has a
    skipped (held-out) GT label, the pseudo label is moved to CUDA, filtered
    with ``filter_bbox`` using the ``model.pseudo_label`` settings, and paired
    with that GT label. The pairs are then scored by ``evaluate_label``.

    Args:
        full_cfg: full config; reads ``dataset.num_classes`` and
            ``model.pseudo_label.{obj_thresh, cls_thresh, ignore_label}``.
        pse_batch: batch from the pseudo-label dataset loader.
        batch: batch from the original dataset loader.
        skip_gt: True when the pseudo dataset was generated without GT labels,
            so loaded GT labels are not expected to reappear in it.

    Returns:
        dict: the result of ``evaluate_label`` (called with
        ``prefix='ssod/'``), or ``{}`` when the sequence has no skipped GT
        label at all.

    Raises:
        RuntimeError: if the two batches come from different sequences; if a
            loaded GT label differs from the pseudo dataset's label (only
            checked when ``skip_gt`` is False); or if a pseudo label at a
            timestep without a checked GT label contains non-pseudo boxes.
    """
    dst_cfg, mdl_cfg = full_cfg.dataset, full_cfg.model
    pl_cfg = mdl_cfg.pseudo_label
    obj_thresh = pl_cfg.obj_thresh
    cls_thresh = pl_cfg.cls_thresh
    ignore_label = pl_cfg.ignore_label
    pse_data, data = pse_batch['data'], batch['data']

    # Sanity check: both loaders must be at the same sequence. Only the file
    # names are compared because the two datasets are read from different
    # root paths. Explicit raises (not `assert`) survive `python -O`.
    pse_name = os.path.basename(pse_data[DataType.PATH][0])
    gt_name = os.path.basename(data[DataType.PATH][0])
    if pse_name != gt_name:
        raise RuntimeError(f'Sequence mismatch: {pse_name} vs {gt_name}')

    # Each field is a per-timestep list; `[0]` takes the first (only) sample
    # of the batch. Entries are ObjectLabels or None.
    pse_obj_labels = [lbl[0] for lbl in pse_data[DataType.OBJLABELS_SEQ]]
    loaded_obj_labels = [lbl[0] for lbl in data[DataType.OBJLABELS_SEQ]]
    skipped_obj_labels = [
        lbl[0] for lbl in data[DataType.SKIPPED_OBJLABELS_SEQ]]

    # `skipped_labels[i]` and `pse_labels[i]` always refer to the same
    # timestep (assumes evaluate_label pairs the lists index-by-index).
    pse_labels, skipped_labels = [], []
    for pse_lbl, loaded_lbl, skipped_lbl in zip(
            pse_obj_labels, loaded_obj_labels, skipped_obj_labels):
        if loaded_lbl is not None and not skip_gt:
            # The pseudo dataset must carry the loaded GT unchanged.
            if not loaded_lbl == pse_lbl:
                raise RuntimeError('GT labels mismatch')
        elif pse_lbl is not None:
            if not pse_lbl.is_pseudo_label().all():
                raise RuntimeError('Contain GT labels')

        if skipped_lbl is None:
            continue  # no held-out GT here, nothing to score
        skipped_labels.append(skipped_lbl.to('cuda'))

        # A missing prediction counts as an empty one.
        # NOTE(review): if ObjectLabels.to() mutates in place, this reuses the
        # module-level `empty_box`; harmless only while it stays empty.
        if pse_lbl is None:
            pse_lbl = empty_box.to('cuda')
        else:
            pse_lbl = pse_lbl.to('cuda')
        pse_labels.append(filter_bbox(
            pse_lbl,
            obj_thresh=obj_thresh,
            cls_thresh=cls_thresh,
            ignore_label=ignore_label))

    # Some sequences have no skipped GT at all.
    if not skipped_labels:
        return {}

    # Precision & recall over every paired timestep.
    pred_mask = np.ones(len(skipped_labels), dtype=bool)
    return evaluate_label(
        skipped_labels,
        pse_labels,
        pred_mask,
        num_cls=dst_cfg.num_classes,
        prefix='ssod/')
def _predict_loader(config):
    """Build the data module for `config`; return it with its predict loader."""
    data_module = fetch_data_module(config=config)
    data_module.setup(stage='predict')
    return data_module, data_module.predict_dataloader()


def _accumulate_metrics(metrics_dict, metrics):
    """Fold one sequence's metrics into running AverageMeters (in place).

    `num_<cls>` entries are not averaged themselves. They are the weights for
    every other metric whose key ends in `_<cls>`.
    """
    for key, value in metrics.items():
        if key.startswith('num_'):
            continue
        if key not in metrics_dict:
            metrics_dict[key] = AverageMeter()
        cls_name = key.rpartition('_')[2]  # e.g. 'xxx_car' -> 'car'
        metrics_dict[key].update(value, n=metrics[f'num_{cls_name}'])


def _print_report(metrics_dict, pse_path):
    """Print every averaged metric followed by the evaluated dataset path."""
    print('------------ Evaluation ------------')
    for name, meter in metrics_dict.items():
        print(f'\t{name}: {meter.avg:.4f}')
    print('------------------------------------')
    print('Dataset path:', pse_path)


def eval_one_dataset(config: DictConfig):
    """Compare a pseudo-labelled dataset against the GT it was derived from.

    `config.dataset.path` points at a pseudo-label dataset (generated by
    `predict.py`). The matching original dataset is read from
    `./datasets/<config.dataset.name>`.

    Evaluation mode:
      * If `config.dataset.ratio == -1`, the dataset is treated as WSOD and
        `config.dataset.train_ratio` must lie in (0, 1).
      * Otherwise it is treated as SSOD and `ratio` itself must lie in (0, 1).

    Both datasets are iterated sequence by sequence. Each sequence is scored
    with `eval_one_seq`, and the class-weighted averages are printed.

    NOTE: `config` is modified in place (batch size, workers, sequence
    length, ratios, path); pass a copy if it is reused afterwards.
    """
    dst_cfg = config.dataset
    dst_name = dst_cfg.name

    # Label-only loading, one sequence per batch; the eval worker budget is
    # split between the pseudo and the original loader.
    config.batch_size.eval = 1
    workers = config.hardware.num_workers
    workers.eval //= 2
    dst_cfg.sequence_length = 128 if dst_name != 'gen1' else 320
    dst_cfg.only_load_labels = True
    dst_cfg.data_augmentation.stream.start_from_zero = True

    sparse_ratio = dst_cfg.ratio
    subseq_ratio = dst_cfg.train_ratio
    pse_path = dst_cfg.path
    is_wsod = sparse_ratio == -1

    # Step 1: the pseudo dataset generated by `predict.py`, with sub-sampling
    # disabled (-1) so that every stored label is loaded.
    if is_wsod:
        assert 0. < subseq_ratio < 1.
        dst_cfg.train_ratio = -1
        print('Evaluating WSOD dataset')
    else:
        assert 0. < sparse_ratio < 1.
        dst_cfg.ratio = -1
        print('Evaluating SSOD dataset')
    skip_gt = 'all_pse' in pse_path
    if skip_gt:
        print('Pseudo dataset that does not take GT labels')
    # Data modules stay referenced until evaluation finishes.
    pse_module, pse_loader = _predict_loader(config)

    # Step 2: the original dataset, with its sub-sampling ratio restored.
    if is_wsod:
        dst_cfg.train_ratio = subseq_ratio
    else:
        dst_cfg.ratio = sparse_ratio
    dst_cfg.path = f'./datasets/{dst_name}'
    gt_module, loader = _predict_loader(config)

    pl_cfg = config.model.pseudo_label
    print('\nobj_thresh:', pl_cfg.obj_thresh)
    print('cls_thresh:', pl_cfg.cls_thresh, '\n')

    # Step 3: evaluate pseudo labels vs. GT labels, sequence by sequence.
    metrics_dict = {}
    paired_batches = zip(pse_loader, loader)
    for pse_batch, batch in tqdm(paired_batches, desc='Event sequence'):
        seq_metrics = eval_one_seq(config, pse_batch, batch, skip_gt=skip_gt)
        _accumulate_metrics(metrics_dict, seq_metrics)

    _print_report(metrics_dict, pse_path)
@hydra.main(config_path='config', config_name='val', version_base='1.2')
def main(config: DictConfig):
    """Evaluate one pseudo-label dataset, or every dataset inside a folder.

    If `config.dataset.path` itself contains a `train` sub-folder, it is
    evaluated directly. Otherwise each directory returned by
    `glob_all(config.dataset.path, only_dir=True)` that contains `train` is
    evaluated with its own deep copy of the config, because
    `eval_one_dataset` mutates the config it receives.
    """
    dynamically_modify_train_config(config)
    # Fail fast if any interpolation or mandatory value cannot be resolved.
    OmegaConf.to_container(config, resolve=True, throw_on_missing=True)

    dst_path = config.dataset.path
    if os.path.exists(os.path.join(dst_path, 'train')):
        eval_one_dataset(config)
        # Plain return: exiting with -1 here would report a successful run
        # as a failure (exit status 255).
        return

    # Otherwise, evaluate every dataset folder under `dst_path`.
    n_evaluated = 0
    for sub_path in glob_all(dst_path, only_dir=True):
        if not os.path.exists(os.path.join(sub_path, 'train')):
            continue
        one_config = copy.deepcopy(config)
        one_config.dataset.path = sub_path
        eval_one_dataset(one_config)
        n_evaluated += 1
    if n_evaluated == 0:
        print(f'No dataset with a `train` folder found under {dst_path}')


if __name__ == '__main__':
    # @hydra.main composes `config` from config/val plus CLI overrides.
    main()