-
Notifications
You must be signed in to change notification settings - Fork 3
/
generate_prediction.py
97 lines (78 loc) · 3.39 KB
/
generate_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Merge Talk2Car validation and test predictions into one leaderboard file.

Tokens and predictions are keyed by the id number of the corresponding image (as a string).
To keep every sample distinguishable, ids are assigned contiguously across the splits:
0-8348 for the training set, 8349-9511 for the validation set, and 9512-11958 for the test
set. If your copy of the dataset has different split sizes, update the corresponding id
ranges before running the script, and rename the saved file and the corresponding variables
as needed.
"""
import json
import argparse
import os
def _collect_split(token_file, prediction_file, first_id, count):
    """Pair command tokens with their predictions for one contiguous id range.

    Entry ``k`` of each loaded JSON sequence is expected to be a mapping that
    contains the key ``str(first_id + k)``. In the token file that key maps to
    the command token; in the prediction file it maps to the prediction
    (e.g. a bounding box).

    Args:
        token_file: Path of the token JSON (e.g. ``token_val.json``).
        prediction_file: Path of the matching prediction JSON.
        first_id: Image id of the first entry in both files.
        count: Number of consecutive entries to read.

    Returns:
        dict mapping each token to its prediction, in file order.

    Raises:
        IndexError: If either file holds fewer than ``count`` entries.
        KeyError: If an entry is not keyed by its expected image id.
    """
    with open(token_file, encoding='utf-8') as f:
        data_token = json.load(f)
    with open(prediction_file, encoding='utf-8') as f:
        data_prediction = json.load(f)
    collected = {}
    for offset in range(count):
        # Both files are positionally aligned and keyed by the absolute image id.
        image_id = str(first_id + offset)
        token = data_token[offset][image_id]
        bbox = data_prediction[offset][image_id]
        collected[token] = bbox
    return collected


def construct_prediction_json(data_root, result_path, val_start=8349,
                              val_count=1163, test_count=2447):
    """Merge validation and test predictions into one leaderboard file.

    Reads ``token_val.json`` / ``token_test.json`` from *data_root* and
    ``val_prediction.json`` / ``test_prediction.json`` from *result_path*.
    It maps every command token to its prediction and writes the merged dict
    to ``<result_path>/predictions_for_leaderboard.json``.

    Args:
        data_root: Directory holding the token files.
        result_path: Directory holding the prediction files; the merged file
            is written here as well.
        val_start: Image id of the first validation sample.
        val_count: Number of validation samples.
        test_count: Number of test samples. Test ids start immediately after
            the validation range (``val_start + val_count``).

    Raises:
        IndexError / KeyError: If a file does not match the expected id layout.
    """
    splits = (
        ('token_val.json', 'val_prediction.json', val_start, val_count),
        ('token_test.json', 'test_prediction.json', val_start + val_count, test_count),
    )
    prediction_dict = {}
    # Validation first, then test: this keeps the original output key order.
    for token_name, prediction_name, first_id, count in splits:
        prediction_dict.update(_collect_split(
            os.path.join(data_root, token_name),
            os.path.join(result_path, prediction_name),
            first_id,
            count,
        ))
    with open(os.path.join(result_path, 'predictions_for_leaderboard.json'), 'w',
              encoding='utf-8') as f:
        json.dump(prediction_dict, f)
# NOTE: single-split variants (construct_prediction_for_val_set /
# construct_prediction_for_test_set, writing predictions_val.json /
# predictions_test.json from ./ relative paths) used to sit here inside a
# disabled string literal. They never ran and duplicated the loops of
# construct_prediction_json, so they were removed; recover them from the
# repository history if a single-split export is ever needed.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Merge validation and test predictions into a single "
                    "token-keyed predictions_for_leaderboard.json file."
    )
    # Input/output locations (the id ranges are fixed inside construct_prediction_json).
    parser.add_argument(
        "--data_root", default="data/talk2car", type=str,
        help="directory containing token_val.json and token_test.json",
    )
    parser.add_argument(
        "--result_path", required=True, type=str,
        help="directory containing val_prediction.json and test_prediction.json; "
             "predictions_for_leaderboard.json is written here",
    )
    args = parser.parse_args()
    construct_prediction_json(args.data_root, args.result_path)