-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path run_metrics_eval.py
executable file
·54 lines (43 loc) · 1.81 KB
/
run_metrics_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import json
import os
import pickle
import sys

import numpy as np
from tqdm import tqdm

from metrics import *
def read_file(fname, transform=lambda x: x, encoding="utf-8"):
    """Read a text file and return one transformed entry per line.

    Each line has leading and trailing whitespace (including the trailing
    newline) removed with ``str.strip`` before ``transform`` is applied.

    Args:
        fname: Path of the file to read.
        transform: Callable applied to each stripped line; identity by default.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the machine's locale encoding.

    Returns:
        A list with one element per line of the file, in file order.
    """
    # Decode explicitly: open() without ``encoding`` falls back to the
    # locale's preferred encoding, which varies across machines and can
    # garble or reject non-ASCII text.
    with open(fname, encoding=encoding) as f:
        return [transform(line.strip()) for line in f]
def get_outputs(predictions_file, lower=False, truncate=False, max_length=None):
    """Load system outputs, optionally lower-casing and truncating each line.

    Args:
        predictions_file: Path of a file holding one output per line; lines
            are read (and stripped) via ``read_file``.
        lower: When true, each output is lower-cased before truncation.
        truncate: When true, each output is cut to at most ``max_length``
            characters; when false, ``max_length`` is ignored.
        max_length: Either one limit applied to every line (``None`` keeps the
            whole line), or a list giving a per-line limit indexed by line
            number. A list shorter than the file raises ``IndexError`` when
            ``truncate`` is true.

    Returns:
        List of processed output strings, one per input line.
    """
    per_line_limits = isinstance(max_length, list)
    processed = []
    for idx, text in enumerate(read_file(predictions_file)):
        text = text.lower() if lower else text
        if truncate:
            limit = max_length[idx] if per_line_limits else max_length
            text = text[:limit]
        processed.append(text)
    return processed
def main():
    """Score a file of system outputs with a metric class from ``metrics``.

    Command-line entry point. Reads the source, reference (target) and output
    files, one segment per line. It truncates every output to at most
    ``--length-ratio`` times the character length of its source line. It
    then prints the metric class name followed by element ``[0]`` of
    ``metric.get_score(ref, outputs)``.

    Exits via ``ArgumentParser.error`` (status 2) in any of these cases:
    a file argument is missing; the metric class name does not resolve to a
    callable in this module; the length ratio is not positive; or the three
    files have different numbers of lines.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): --domain and --split are parsed but never read below;
    # kept so existing command lines keep working.
    parser.add_argument("-d", "--domain", type=str, default='medical')
    parser.add_argument("-s", "--split", type=str, default='test')
    parser.add_argument("--metric-class-name", type=str, default="COMETSrcMetric")
    # These three used to default to None, which was formatted into the
    # literal path "None"; require them explicitly instead.
    parser.add_argument("--output-file", type=str, required=True, help="output text file from running XGLM.")
    parser.add_argument("--target-file", type=str, required=True, help="Reference corresponding to the outputs.")
    parser.add_argument("--source-file", type=str, required=True, help="Source corresponding to the outputs to estimate length truncation.")
    parser.add_argument("--length-ratio", type=float, default=2.0,
                        help="Truncate each output to this many times the character length of its source line.")
    args = parser.parse_args()
    if args.length_ratio <= 0:
        parser.error("--length-ratio must be positive")

    # Resolve the metric class by name among this module's globals, which
    # include every name star-imported from ``metrics``.
    metric_cls = getattr(sys.modules[__name__], args.metric_class_name, None)
    if not callable(metric_cls):
        parser.error(f"unknown metric class: {args.metric_class_name!r}")

    src = read_file(args.source_file)
    ref = read_file(args.target_file)
    outputs = get_outputs(args.output_file)
    # Without this check, mismatched files either raised a bare IndexError
    # while truncating or reached the metric with differing lengths.
    if not len(src) == len(ref) == len(outputs):
        parser.error(
            f"line counts differ: source={len(src)}, "
            f"target={len(ref)}, output={len(outputs)}"
        )

    # Cap each output at length_ratio * len(source) characters; with the
    # default ratio of 2.0 this is len(source) * 2.
    outputs = [out[:int(len(s) * args.length_ratio)] for out, s in zip(outputs, src)]

    # Instantiate only after the inputs are validated, so a possibly
    # expensive metric is not built for a run that is about to fail.
    metric = metric_cls()
    # NOTE(review): src is not passed to get_score even though the default
    # metric name ends in "Src"; confirm the metric does not need it.
    print(args.metric_class_name, metric.get_score(ref, outputs)[0])


if __name__ == '__main__':
    main()