-
Notifications
You must be signed in to change notification settings - Fork 2
/
msmarco_eval.py
199 lines (168 loc) · 7.9 KB
/
msmarco_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
"""
This module computes evaluation metrics for the MSMARCO dataset on the ranking task. Internal hard-coded eval files version. DO NOT PUBLISH!
Command line:
python msmarco_eval.py <path_to_reference_file> <path_to_candidate_file> <path_to_save_file>
Creation Date : 06/12/2018
Last Modified : 4/09/2019
Authors : Daniel Campos <[email protected]>, Rutger van Haasteren <[email protected]>
"""
import sys
from collections import Counter
def load_reference_from_stream(f):
    """Load the relevant passage ids for each query from a stream.

    Every line must hold at least two tab-separated integer fields:
    ``QUERYID<TAB>PASSAGEID``. A query id may repeat on several lines; its
    passage ids are collected in the order they appear.

    Args:
        f (stream): iterable of text lines to load.

    Returns:
        qids_to_relevant_passageids (dict): dictionary mapping from query_id
            (int) to relevant passages (list of ints).

    Raises:
        IOError: if a line has fewer than two fields or a non-integer id.
    """
    qids_to_relevant_passageids = {}
    for line in f:
        fields = line.strip().split('\t')
        try:
            qid = int(fields[0])
            pid = int(fields[1])
        except (ValueError, IndexError) as err:
            # Only parsing failures are translated; the message shows the
            # offending line itself rather than the already-split field list.
            raise IOError('\"%s\" is not valid format' % line.strip()) from err
        qids_to_relevant_passageids.setdefault(qid, []).append(pid)
    return qids_to_relevant_passageids
def load_reference(path_to_reference):
    """Read the reference file at the given path and parse it.

    Args:
        path_to_reference (str): path to the reference file to load.

    Returns:
        qids_to_relevant_passageids (dict): dictionary mapping from query_id
            (int) to relevant passages (list of ints), as produced by
            load_reference_from_stream.
    """
    with open(path_to_reference, 'r') as reference_file:
        return load_reference_from_stream(reference_file)
def load_candidate_from_stream(f):
    """Load ranked candidate passages for each query from a stream.

    Every line must hold at least three tab-separated integer fields:
    ``QUERYID<TAB>PASSAGEID<TAB>RANK`` with RANK between 1 and 1000.

    Args:
        f (stream): iterable of text lines to load.

    Returns:
        qid_to_ranked_candidate_passages (dict): dictionary mapping from
            query_id (int) to a list of 1000 passage ids (int) where index
            ``rank - 1`` holds the passage given that rank. Positions for
            which no line was supplied stay 0.

    Raises:
        IOError: if a line has fewer than three fields, a non-integer field,
            or a rank outside 1..1000.
    """
    max_rank = 1000
    qid_to_ranked_candidate_passages = {}
    for line in f:
        fields = line.strip().split('\t')
        try:
            qid = int(fields[0])
            pid = int(fields[1])
            rank = int(fields[2])
        except (ValueError, IndexError) as err:
            raise IOError('\"%s\" is not valid format' % line.strip()) from err
        # Reject out-of-range ranks explicitly: a rank of 0 or below would
        # otherwise wrap around through negative indexing and silently land
        # near the end of the list.
        if not 1 <= rank <= max_rank:
            raise IOError(
                '\"%s\" is not valid format (rank must be between 1 and %d)'
                % (line.strip(), max_rank))
        if qid not in qid_to_ranked_candidate_passages:
            # By default, all PIDs in the list are 0. Only override those
            # that are given.
            qid_to_ranked_candidate_passages[qid] = [0] * max_rank
        qid_to_ranked_candidate_passages[qid][rank - 1] = pid
    return qid_to_ranked_candidate_passages
def load_candidate(path_to_candidate):
    """Read the candidate file at the given path and parse it.

    Args:
        path_to_candidate (str): path to the candidate file to load.

    Returns:
        qid_to_ranked_candidate_passages (dict): dictionary mapping from
            query_id (int) to a list of 1000 passage ids (int) ordered by
            rank, as produced by load_candidate_from_stream.
    """
    with open(path_to_candidate, 'r') as candidate_file:
        return load_candidate_from_stream(candidate_file)
def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the candidate dictionary.

    The only check performed is that no passage id is ranked more than once
    for a single query. The value 0 is ignored because it is the filler used
    for rank positions that were not supplied.

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping,
            as read in with load_reference or load_reference_from_stream.
            Accepted for interface compatibility; not inspected by this check.
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage
            candidates.

    Returns:
        bool,str: Boolean whether allowed, message to be shown in case of a
            problem. When several queries are affected, the message describes
            the last one encountered.
    """
    message = ''
    allowed = True
    for qid, ranked_pids in qids_to_ranked_candidate_passages.items():
        pid_counts = Counter(ranked_pids)
        # Drop the 0 filler *before* picking the pid to report, so the
        # message never names the filler instead of the real duplicate.
        duplicate_pids = sorted(
            pid for pid, count in pid_counts.items() if count > 1 and pid != 0)
        if duplicate_pids:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=duplicate_pids[0])
            allowed = False
    return allowed, message
def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR and Recall at several rank cutoffs.

    For each cutoff in 10, 50, 100, 500 and 1000, only the first ``cutoff``
    candidates of every query are considered. Both metrics are averaged over
    the number of queries in the reference, so reference queries that have no
    candidates contribute zero.

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping,
            as read in with load_reference or load_reference_from_stream.
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage
            candidates, each value ordered by rank.

    Returns:
        list: list of ``(name, value)`` tuples: ``('MRR @k', score)`` and
            ``('Recall @k', score)`` for every cutoff ``k``, followed by
            ``('QueriesRanked', <number of candidate queries>)``.

    Raises:
        IOError: if no candidate query id appears in the reference.
    """
    num_reference_queries = len(qids_to_relevant_passageids)
    if not any(qid in qids_to_relevant_passageids
               for qid in qids_to_ranked_candidate_passages):
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    all_scores = []
    for max_rank in [10, 50, 100, 500, 1000]:
        mrr_total = 0
        recall_total = 0
        for qid, ranked_pids in qids_to_ranked_candidate_passages.items():
            if qid not in qids_to_relevant_passageids:
                continue
            # A set makes membership tests cheap and keeps a pid that is
            # listed twice in the reference from inflating the recall
            # denominator.
            relevant = set(qids_to_relevant_passageids[qid])
            # Slicing (rather than indexing up to max_rank) tolerates
            # candidate lists shorter than the cutoff.
            top_pids = ranked_pids[:max_rank]
            for position, pid in enumerate(top_pids, start=1):
                if pid in relevant:
                    # Only the first relevant hit counts towards MRR.
                    mrr_total += 1 / position
                    break
            if relevant:
                recall_total += len(relevant & set(top_pids)) / len(relevant)
        all_scores.append(('MRR @{}'.format(max_rank), mrr_total / num_reference_queries))
        all_scores.append(('Recall @{}'.format(max_rank), recall_total / num_reference_queries))
    all_scores.append(('QueriesRanked', len(qids_to_ranked_candidate_passages)))
    return all_scores
def compute_metrics_from_files(path_to_reference, path_to_candidate, perform_checks=True):
    """Load a reference file and a candidate file, then score the candidates.

    Args:
        path_to_reference (str): path to reference file.
            Reference file should contain lines in the following format:
                QUERYID\tPASSAGEID
            Where PASSAGEID is a relevant passage for a query. Note QUERYID
            can repeat on different lines with different PASSAGEIDs.
        path_to_candidate (str): path to candidate file.
            Candidate file should contain lines in the following format:
                QUERYID\tPASSAGEID1\tRank
            Where the values are separated by tabs.
        perform_checks (bool): when true, run quality_checks_qids on the
            loaded data and print its message if it reports a problem.
            Scoring proceeds either way.

    Returns:
        The value returned by compute_metrics for the two loaded dictionaries.
    """
    reference = load_reference(path_to_reference)
    candidates = load_candidate(path_to_candidate)
    if perform_checks:
        # The allowed flag is not acted upon; only the message is surfaced.
        _, problem = quality_checks_qids(reference, candidates)
        if problem:
            print(problem)
    return compute_metrics(reference, candidates)
def main():
    """Score a candidate file from the command line and save the report.

    Command line:
    python msmarco_eval.py <path to reference> <path_to_candidate_file> <path_to_save>

    The metrics are printed to stdout between two separator lines and also
    written, one ``name: value`` pair per line, to <path_to_save>.

    Exits with a usage message if fewer than three arguments are given.
    """
    if len(sys.argv) < 4:
        # Fail with a readable usage line instead of a bare IndexError.
        sys.exit(
            'Usage: python {} <path to reference> <path_to_candidate_file> <path_to_save>'.format(
                sys.argv[0]))
    path_to_reference, path_to_candidate, path_to_save = sys.argv[1:4]

    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
    report_lines = ['{}: {}'.format(name, value) for name, value in metrics]

    print('#####################')
    for line in report_lines:
        print(line)
    print('#####################')

    with open(path_to_save, 'w') as f:
        f.writelines(line + '\n' for line in report_lines)


if __name__ == '__main__':
    main()