Skip to content

Commit d6fd5f3

Browse files
author
zhang-can
committed
testset support
1 parent acd99eb commit d6fd5f3

File tree

4 files changed

+257
-3
lines changed

4 files changed

+257
-3
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,7 @@ ENV/
104104

105105
data/*.txt
106106

107-
*.pth.tar
107+
*.pth.tar
108+
*.pth
109+
110+
output/

gen_dataset_test_lists.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# processing the raw data of the video datasets (something-something and jester)
2+
# generate the meta files:
3+
# dataset_test.txt: each row contains [video_path num_frames video_name]
4+
#
5+
# Created by: Can Zhang
6+
# github: @zhang-can, May,28th 2018
7+
#
8+
9+
import argparse
10+
import os
11+
import pdb
12+
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument('dataset', type=str, choices=['something', 'jester'])
15+
parser.add_argument('frame_path', type=str, help="root directory holding the frames")
16+
parser.add_argument('--labels_path', type=str, default='data/dataset_labels/', help="root directory holding the csv files: labels, train & validation")
17+
parser.add_argument('--out_list_path', type=str, default='data/')
18+
19+
args = parser.parse_args()
20+
21+
dataset = args.dataset
22+
labels_path = args.labels_path
23+
frame_path = args.frame_path
24+
25+
if dataset == 'something':
26+
dataset_name = 'something-something-v1'
27+
elif dataset == 'jester':
28+
dataset_name = 'jester-v1'
29+
30+
# print('\nProcessing dataset: {}\n'.format(dataset))
31+
32+
# print('- Generating {}_category.txt ......'.format(dataset))
33+
# with open(os.path.join(labels_path, '{}-labels.csv'.format(dataset_name))) as f:
34+
# lines = f.readlines()
35+
# categories = []
36+
# for line in lines:
37+
# line = line.rstrip()
38+
# categories.append(line)
39+
# categories = sorted(categories)
40+
# open(os.path.join(args.out_list_path, '{}_category.txt'.format(dataset)),'w').write('\n'.join(categories))
41+
# print('- Saved as:', os.path.join(args.out_list_path, '{}_category.txt!\n'.format(dataset)))
42+
43+
# dict_categories = {}
44+
# for i, category in enumerate(categories):
45+
# dict_categories[category] = i
46+
47+
files_input = ['{}-test.csv'.format(dataset_name)]
48+
files_output = ['{}_test.txt'.format(dataset)]
49+
for (filename_input, filename_output) in zip(files_input, files_output):
50+
with open(os.path.join(labels_path, filename_input)) as f:
51+
lines = f.readlines()
52+
folders = []
53+
for line in lines:
54+
line = line.rstrip()
55+
items = line
56+
folders.append(items)
57+
output = []
58+
for i in range(len(folders)):
59+
curFolder = folders[i]
60+
# counting the number of frames in each video folders
61+
dir_files = os.listdir(os.path.join(frame_path, curFolder))
62+
output.append('{} {} {}'.format(os.path.join(frame_path, curFolder), len(dir_files), curFolder))
63+
if i % 1000 == 0:
64+
print('- Generating {} ({}/{})'.format(filename_output, i, len(folders)))
65+
with open(os.path.join(args.out_list_path, filename_output),'w') as f:
66+
f.write('\n'.join(output))
67+
print('- Saved as:', os.path.join(args.out_list_path, '{}!\n'.format(filename_output)))

test_models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# options
1515
parser = argparse.ArgumentParser(
1616
description="Standard video-level testing")
17-
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics'])
17+
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics', 'something'])
1818
parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff'])
1919
parser.add_argument('test_list', type=str)
2020
parser.add_argument('weights', type=str)
@@ -32,6 +32,7 @@
3232
help='number of data loading workers (default: 4)')
3333
parser.add_argument('--gpus', nargs='+', type=int, default=None)
3434
parser.add_argument('--flow_prefix', type=str, default='')
35+
parser.add_argument('--rgb_prefix', type=str, default='')
3536

3637
args = parser.parse_args()
3738

@@ -42,6 +43,8 @@
4243
num_class = 51
4344
elif args.dataset == 'kinetics':
4445
num_class = 400
46+
elif args.dataset == 'something':
47+
num_class = 174
4548
else:
4649
raise ValueError('Unknown dataset '+args.dataset)
4750

@@ -73,7 +76,7 @@
7376
TSNDataSet("", args.test_list, num_segments=args.test_segments,
7477
new_length=1 if args.modality == "RGB" else 5,
7578
modality=args.modality,
76-
image_tmpl="img_{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg",
79+
image_tmpl=args.rgb_prefix+"{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg",
7780
test_mode=True,
7881
transform=torchvision.transforms.Compose([
7982
cropping,

test_models_for_test.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
import argparse
2+
import time
3+
4+
import numpy as np
5+
import torch.nn.parallel
6+
import torch.optim
7+
from sklearn.metrics import confusion_matrix
8+
9+
from dataset import TSNDataSet
10+
from models import TSN
11+
from transforms import *
12+
from ops import ConsensusModule
13+
14+
import os
15+
16+
# options
17+
parser = argparse.ArgumentParser(
18+
description="Standard video-level testing")
19+
parser.add_argument('dataset', type=str, choices=['ucf101', 'hmdb51', 'kinetics', 'something'])
20+
parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff'])
21+
parser.add_argument('test_list', type=str)
22+
parser.add_argument('weights', type=str)
23+
parser.add_argument('result_file', type=str)
24+
parser.add_argument('--arch', type=str, default="resnet101")
25+
parser.add_argument('--save_scores', type=str, default=None)
26+
parser.add_argument('--test_segments', type=int, default=25)
27+
parser.add_argument('--max_num', type=int, default=-1)
28+
parser.add_argument('--test_crops', type=int, default=10)
29+
parser.add_argument('--input_size', type=int, default=224)
30+
parser.add_argument('--crop_fusion_type', type=str, default='avg',
31+
choices=['avg', 'max', 'topk'])
32+
parser.add_argument('--k', type=int, default=3)
33+
parser.add_argument('--dropout', type=float, default=0.7)
34+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35+
help='number of data loading workers (default: 4)')
36+
parser.add_argument('--gpus', nargs='+', type=int, default=None)
37+
parser.add_argument('--flow_prefix', type=str, default='')
38+
parser.add_argument('--rgb_prefix', type=str, default='')
39+
parser.add_argument('--out_list_path', type=str, default='data/')
40+
41+
args = parser.parse_args()
42+
43+
if args.dataset == 'ucf101':
44+
num_class = 101
45+
elif args.dataset == 'hmdb51':
46+
num_class = 51
47+
elif args.dataset == 'kinetics':
48+
num_class = 400
49+
elif args.dataset == 'something':
50+
num_class = 174
51+
else:
52+
raise ValueError('Unknown dataset '+args.dataset)
53+
54+
net = TSN(num_class, 1, args.modality,
55+
base_model=args.arch,
56+
consensus_type=args.crop_fusion_type,
57+
dropout=args.dropout)
58+
59+
checkpoint = torch.load(args.weights)
60+
print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))
61+
62+
# list element type: tuple
63+
base_dict = {'.'.join(k.split('.')[1:]): v for k,v in list(checkpoint['state_dict'].items())}
64+
net.load_state_dict(base_dict)
65+
66+
if args.test_crops == 1:
67+
cropping = torchvision.transforms.Compose([
68+
GroupScale(net.scale_size),
69+
GroupCenterCrop(net.input_size),
70+
])
71+
elif args.test_crops == 10:
72+
cropping = torchvision.transforms.Compose([
73+
GroupOverSample(net.input_size, net.scale_size)
74+
])
75+
else:
76+
raise ValueError("Only 1 and 10 crops are supported while we got {}".format(args.test_crops))
77+
78+
data_loader = torch.utils.data.DataLoader(
79+
TSNDataSet("", args.test_list, num_segments=args.test_segments,
80+
new_length=1 if args.modality == "RGB" else 5,
81+
modality=args.modality,
82+
image_tmpl=args.rgb_prefix+"{:05d}.jpg" if args.modality in ['RGB', 'RGBDiff'] else args.flow_prefix+"{}_{:05d}.jpg",
83+
test_mode=True,
84+
transform=torchvision.transforms.Compose([
85+
cropping,
86+
Stack(roll=args.arch == 'BNInception'),
87+
ToTorchFormatTensor(div=args.arch != 'BNInception'),
88+
GroupNormalize(net.input_mean, net.input_std),
89+
])),
90+
batch_size=1, shuffle=False,
91+
num_workers=args.workers * 2, pin_memory=True)
92+
93+
if args.gpus is not None:
94+
devices = [args.gpus[i] for i in range(args.workers)]
95+
else:
96+
devices = list(range(args.workers))
97+
98+
99+
net = torch.nn.DataParallel(net.cuda(devices[0]), device_ids=devices)
100+
net.eval()
101+
102+
data_gen = enumerate(data_loader)
103+
104+
total_num = len(data_loader.dataset)
105+
output = []
106+
107+
108+
def eval_video(video_data):
109+
i, data, label = video_data
110+
num_crop = args.test_crops
111+
112+
if args.modality == 'RGB':
113+
length = 3
114+
elif args.modality == 'Flow':
115+
length = 10
116+
elif args.modality == 'RGBDiff':
117+
length = 18
118+
else:
119+
raise ValueError("Unknown modality "+args.modality)
120+
121+
input_var = torch.autograd.Variable(data.view(-1, length, data.size(2), data.size(3)),
122+
volatile=True)
123+
rst = net(input_var).data.cpu().numpy().copy()
124+
return i, rst.reshape((num_crop, args.test_segments, num_class)).mean(axis=0).reshape(
125+
(args.test_segments, 1, num_class)
126+
), label[0]
127+
128+
129+
proc_start_time = time.time()
130+
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)
131+
132+
for i, (data, label) in data_gen:
133+
if i >= max_num:
134+
break
135+
rst = eval_video((i, data, label))
136+
output.append(rst[1:])
137+
cnt_time = time.time() - proc_start_time
138+
print('video {} done, total {}/{}, average {} sec/video'.format(i, i+1,
139+
total_num,
140+
float(cnt_time) / (i+1)))
141+
142+
video_pred = [np.argmax(np.mean(x[0], axis=0)) for x in output]
143+
144+
video_ids = [x[1] for x in output]
145+
146+
category_lines = open(os.path.join(args.out_list_path, '{}_category.txt'.format(args.dataset))).readlines()
147+
categories = [line.rstrip() for line in category_lines]
148+
149+
test_results = ["{};{}".format(video_ids[i], categories[int(video_pred[i])]) for i in range(len(output))]
150+
151+
open(os.path.join(args.result_file),'w').write('\n'.join(test_results))
152+
153+
# cf = confusion_matrix(video_labels, video_pred).astype(float)
154+
155+
# cls_cnt = cf.sum(axis=1)
156+
# cls_hit = np.diag(cf)
157+
158+
# cls_acc = cls_hit / cls_cnt
159+
160+
# print(cls_acc)
161+
162+
# print('Accuracy {:.02f}%'.format(np.mean(cls_acc) * 100))
163+
164+
# if args.save_scores is not None:
165+
166+
# # reorder before saving
167+
# name_list = [x.strip().split()[0] for x in open(args.test_list)]
168+
169+
# order_dict = {e:i for i, e in enumerate(sorted(name_list))}
170+
171+
# reorder_output = [None] * len(output)
172+
# reorder_label = [None] * len(output)
173+
174+
# for i in range(len(output)):
175+
# idx = order_dict[name_list[i]]
176+
# reorder_output[idx] = output[i]
177+
# reorder_label[idx] = video_labels[i]
178+
179+
# np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
180+
181+

0 commit comments

Comments
 (0)