eval_bm25.py
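"""Evaluate BM25-retrieved answers on ConferenceQA with a language-model judge.

For each question split, the script loads the BM25-generated answers and the
gold QA pairs, builds one judging prompt per question, sends the prompts via
utils.api_request_parallel_processor.run, caches the raw responses to
<split>.json, and appends the per-split accuracy to bm25.txt.
"""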
import argparse
import json
import os

import tqdm

from conferenceqa import ConferenceQA
from utils.api_request_parallel_processor import run
def evaluate_answer(
    file_name: str,
    results,
    gold,
):
    """Judge generated answers against gold answers and record the accuracy."""
    prompts = []

    # Map query -> generated answer; result records use either
    # {"question": ..., "answer": ...} or {"query": ..., "output": ...}.
    d1 = {}
    for item in results:
        try:
            d1[item["question"]] = item["answer"]
        except KeyError:
            d1[item["query"]] = item["output"]

    # Map query -> reference (gold) answer, same key convention.
    d2 = {}
    for item in gold:
        try:
            d2[item["question"]] = item["answer"]
        except KeyError:
            d2[item["query"]] = item["output"]
    num = 0

    # Build one judging prompt per question.
    for item in tqdm.tqdm(results):
        try:
            query = item["question"]
        except KeyError:
            query = item["query"]
        prompt = """
You are a judge and need to assess the quality of a generated answer.
[task definition]
You are given two sentences, both answering the same query: the first is generated by a language model, and the second is the reference answer.
Identify the key information in the reference answer according to the query, then judge whether the generated answer contains that key information.
Return true if the generated answer contains most of the key information, and false if the generated answer is wrong.
As long as the generated answer contains the key information, it is correct, regardless of any other uncertain content it may include.
[note]
You only output true or false.
You should judge fairly and not let the order of the answers affect your judgment.
"""
        prompt += "\n\n" + "query: " + query
        prompt += "\n\n" + f"generated answer: {d1[query]}"
        prompt += "\n\n" + f"reference answer: {d2[query]}"
        prompts.append({"prompt": prompt, "query": ""})
    # Send all judging prompts in parallel and cache the raw responses.
    resps = run(prompts)
    with open(f"{file_name}.json", "w") as f:
        json.dump(resps, f, ensure_ascii=False)

    # A response counts as correct if the judge answered "true".
    for resp in resps:
        if "true" in resp["output"].lower():
            num += 1

    # Record the total before appending the score to results, so the
    # denominator is the number of questions rather than len(results) + 1.
    total = len(results)
    results.append((num, total))
    with open("bm25.txt", "a") as f:
        f.write(f"{file_name}: ({num}/{total})" + "\n")
    print(f"{file_name}: ({num}/{total})")
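# Note: run() is assumed to return one record per prompt with the judge's text
# under the "output" key, which is how the responses are consumed above; the
# exact record shape is inferred from this usage, not from a documented API.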
parser = argparse.ArgumentParser()
parser.add_argument("--cfe_name", default="WWW", type=str, help="")
parser.add_argument(
    "--encoder",
    default="text-embedding-002",
    type=str,
    choices=["text-embedding-002", "SentenceBERT", "ANCE"],
    help="",
)
parser.add_argument(
    "--retrieve_method",
    default="desc_leaf",
    type=str,
    choices=[
        "desc_leaf",
        "desc_value",
        "path",
        "path_and_value",
        "desc",
        "desc_and_value",
        "desc_and_path_value",
    ],
    help="",
)
parser.add_argument("--dicts_path", default="dataset/WWW/dicts", type=str, help="")
parser.add_argument("--embedding_bs", default=200, type=int, help="")
parser.add_argument(
    "--distance",
    default="cosine",
    type=str,
    choices=["cosine", "l2", "ip"],
    help="",
)
parser.add_argument(
    "--persist_chroma_path", default="embeddings/WWW", type=str, help=""
)
parser.add_argument("--persist_csv_path", default="embeddings/WWW", type=str, help="")
args = parser.parse_args()
qa = ConferenceQA.read_cef(args)

# Directory holding the BM25 answer files for each question split.
bm25_dir = "/Users/hzw/Desktop/desktop/code/ConferenceQA/dataset/WWW/BM25"
for file_name in [
    "extraction_atomic",
    "extraction_complex",
    "reasoning_atomic",
    "reasoning_complex",
]:
    with open(os.path.join(bm25_dir, f"{file_name}.json")) as f:
        answer_results = json.load(f)
    gold = qa.qas[file_name]
    evaluate_answer(file_name=file_name, results=answer_results, gold=gold)
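# Example invocation (a sketch; the paths below are just the argparse defaults
# and may need adjusting to your local checkout, and the BM25 directory above
# is hardcoded):
#   python eval_bm25.py --cfe_name WWW --retrieve_method desc_leaf \
#       --dicts_path dataset/WWW/dicts --persist_chroma_path embeddings/WWW \
#       --persist_csv_path embeddings/WWW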