import collections
import string

import datasets
import evaluate

# Locally vendored copies of the Hugging Face `evaluate` metrics.
f1_gen = evaluate.load("./metrics/f1")
exact_match = evaluate.load("./metrics/exact_match")

# Model-specific separator tokens that may precede the answer in a raw completion.
sep_tokens = ["<unused2>", "<0x02>", "<|reserved_special_token_2|>"]
def normalize_answer(s):
    """Lowercase text and remove punctuation and extra whitespace."""

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_punc(lower(s)))


def get_tokens(s):
    """Split a string into whitespace tokens after normalization."""
    if not s:
        return []
    return normalize_answer(s).split()
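
# A minimal sketch of the expected normalization (hypothetical inputs, for illustration):
#   normalize_answer("The Quick, Brown  Fox!")  ->  "the quick brown fox"
#   get_tokens("The Quick, Brown  Fox!")        ->  ["the", "quick", "brown", "fox"]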
def process_docs_gen(dataset: datasets.Dataset) -> datasets.Dataset:
    return dataset.map(preprocess_function)


def preprocess_function(examples):
    def _extract_facts(docs):
        facts = []
        # Drop empty documents, dedupe while preserving order, and keep at most five.
        docs = [doc for doc in docs if doc.strip()]
        docs = list(dict.fromkeys(docs))
        for i, fact in enumerate(docs[:5]):
            facts.append(f"{i + 1}. {fact}")
        return facts

    facts = _extract_facts(examples["docs"])
    facts = "\n\n".join(facts)
    return {
        "question": examples["question"],
        "answer": examples["answer"],
        "facts": facts.strip(),
    }
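
# Sketch of the prompt context this produces (hypothetical example row, not from a dataset):
#   preprocess_function({"question": "Q?", "answer": ["A"],
#                        "docs": ["Paris is in France.", "", "Berlin is in Germany."]})
#   -> {"question": "Q?", "answer": ["A"],
#       "facts": "1. Paris is in France.\n\n2. Berlin is in Germany."}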
def process_results(doc, results):
    """Score one generation against every gold answer; return the best EM and F1."""
    completion = results[0]
    # Keep only the text immediately after the first model-specific separator token.
    for sep_token in sep_tokens:
        if sep_token in completion:
            completion = completion.split(sep_token)[1]
    answers = doc["answer"]
    completion_toks = get_tokens(completion)
    best_em_score = 0
    best_f1_score = 0
    for ans in answers:
        exact_score = exact_match.compute(
            references=[ans], predictions=[completion], ignore_punctuation=True
        )["exact_match"]
        ans_toks = get_tokens(ans)
        common = collections.Counter(ans_toks) & collections.Counter(completion_toks)
        num_same = sum(common.values())
        if len(ans_toks) == 0 or len(completion_toks) == 0:
            # If either side is no-answer, F1 is 1 only when both are empty.
            f1_score = int(ans_toks == completion_toks)
        elif num_same == 0:
            f1_score = 0
        else:
            precision = 1.0 * num_same / len(completion_toks)
            recall = 1.0 * num_same / len(ans_toks)
            f1_score = (2 * precision * recall) / (precision + recall)
        best_em_score = max(best_em_score, exact_score)
        best_f1_score = max(best_f1_score, f1_score)
    return {"exact_match_score": best_em_score, "f1_score": best_f1_score}
def f1(**kwargs):
    """Token-level F1 between the first reference and the first prediction."""
    references = kwargs["references"]
    predictions = kwargs["predictions"]
    ref_toks = get_tokens(references[0])
    pred_toks = get_tokens(predictions[0][0])
    common = collections.Counter(ref_toks) & collections.Counter(pred_toks)
    num_same = sum(common.values())
    if len(ref_toks) == 0 or len(pred_toks) == 0:
        # If either is no-answer, then F1 is 1 if they agree, 0 otherwise.
        return int(ref_toks == pred_toks)
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_toks)
    recall = 1.0 * num_same / len(ref_toks)
    return (2 * precision * recall) / (precision + recall)
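
# A small smoke test, assuming SQuAD-style shapes (references: list[str],
# predictions: list[list[str]]) and that ./metrics/* are present so the module imports.
if __name__ == "__main__":
    score = f1(references=["the cat sat"], predictions=[["a cat sat down"]])
    # Overlap {"cat", "sat"}: precision 2/4, recall 2/3 -> F1 = 4/7 ≈ 0.571
    print(f"F1 = {score:.3f}")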