-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
236 lines (206 loc) · 10.4 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
from askwikidata import AskWikidata
from typing import Callable, List, Dict, Optional
from dataclasses import dataclass, field, asdict
from pprint import pprint
import json
import datetime
def kai_wegner_date_correct(text: str):
    """Predicate for the quiz: does *text* mention 27 April 2023?

    Used as the expected-answer check for the question about when Kai
    Wegner became mayor of Berlin. Matching is case-insensitive and
    accepts the month either as "4" or as the word "april".
    """
    lowered = text.lower()
    mentions_day = "27" in lowered
    mentions_year = "2023" in lowered
    mentions_month = "4" in lowered or "april" in lowered
    return mentions_day and mentions_year and mentions_month
def correct(answer: str, expected: str | list | Callable) -> bool:
answer_lower = answer.lower()
expected_lower = None
if isinstance(expected, str):
expected_lower = expected.lower()
if isinstance(expected, list):
expected_lower = [answer_part.lower() for answer_part in expected]
return (
(isinstance(expected_lower, str) and expected_lower in answer_lower)
or (
isinstance(expected_lower, list)
and all((answer_part) in answer_lower for answer_part in expected_lower)
)
or (isinstance(expected, Callable) and expected(answer_lower) == True)
)
# fmt: off
# Evaluation question set. Each entry maps "q" (the question posed to the
# system) to "a", the expected-answer spec consumed by correct():
#   * a single substring that must appear in the answer (case-insensitive),
#   * a list of substrings that must ALL appear, or
#   * a predicate called with the lower-cased answer text.
quiz = [
    {"q": "What is Berlin?", "a": ["capital", "germany", "city"]},
    {"q": "Mayor of Berlin", "a": "Kai Wegner"},
    {"q": "mayor berlin", "a": "Kai Wegner"},
    {"q": "Who is the current mayor of Berlin?", "a": "Kai Wegner"},
    {"q": "since when is kay wegner mayor of berlin?", "a": kai_wegner_date_correct},
    {"q": "Who was the mayor of Berlin in 2001?", "a": "Klaus Wowereit"},
    {"q": "What is the population of Berlin?", "a": "3755251"},
    {"q": "Can you name all twin cities of Berlin?", "a": [ "Los Angeles", "Paris", "Madrid", "Istanbul", "Warsaw", "Moscow", "City of Brussels", "Budapest", "Tashkent", "Mexico City", "Beijing", "Jakarta", "Tokyo", "Buenos Aires", "Prague", "Windhoek", "London", "Sofia", "Tehran", "Seville", "Copenhagen", "Kyiv", "Brasília", "Santo Domingo", "Algiers", "Rio de Janeiro", ], },
    {"q": "Which River runs through Berlin?", "a": "Spree"},
    {"q": "Who is the current mayor of Paris?", "a": "Anne Hidalgo"},
    {"q": "What is the population of Paris?", "a": "2145906"},
    {"q": "Can you name all twin cities of Paris?", "a": [ "Rome", "Tokyo", "Kyoto", "Berlin", "Ramallah", "Seoul", "Cairo", "Chicago", "Torreón", "San Francisco", "Kyiv", "Washington, D.C.", "Marrakesh", "Porto Alegre", "Dubai", "Beijing", "Mexico City", "Saint Petersburg", ], },
    {"q": "Which River runs through Paris?", "a": "Seine"},
    {"q": "Who is the current mayor of London?", "a": "Sadiq Khan"},
    {"q": "What is the population of London?", "a": "8799728"},
    {"q": "Can you name all twin cities of London?", "a": [ "Berlin", "Mumbai", "New York City", "Algiers", "Sofia", "Moscow", "Tokyo", "Beijing", "Karachi", "Zagreb", "Tehran", "Arequipa", "Delhi", "Bogotá", "Johannesburg", "Kuala Lumpur", "Oslo", "Sylhet", "Shanghai", "Baku", "Buenos Aires", "Istanbul", "Los Angeles", "Podgorica", "New Delhi", "Phnom Penh", "Jakarta", "Amsterdam", "Bucharest", "Santo Domingo", "La Paz", "Mexico City", ], },
    {"q": "Which River runs through London?", "a": "River Thames"},
    {"q": "Who is the current mayor of Prague?", "a": "Bohuslav Svoboda"},
    {"q": "What is the population of Prague?", "a": "1357326"},
    {"q": "Can you name all twin cities of Prague?", "a": [ "Berlin", "Copenhagen", "Miami-Dade County", "Nuremberg", "Luxembourg", "Guangzhou", "Hamburg", "Helsinki", "Nîmes", "Prešov", "Rosh HaAyin", "Teramo", "Bamberg", "City of Brussels", "Frankfurt", "Jerusalem", "Moscow", "Saint Petersburg", "Chicago", "Taipei", "Terni", "Ferrara", "Trento", "Monza", "Lecce", "Naples", "Vilnius", "Istanbul", "Sofia", "Buenos Aires", "Athens", "Bratislava", "Madrid", "Tunis", "Brussels metropolitan area", "Amsterdam", "Phoenix", "Tirana", "Kyoto", "Cali", "Drancy", "Beijing", "Shanghai", "Tbilisi", ], },
    {"q": "Can you name all twinned administrative bodies of Prague?", "a": [ "Berlin", "Copenhagen", "Miami-Dade County", "Nuremberg", "Luxembourg", "Guangzhou", "Hamburg", "Helsinki", "Nîmes", "Prešov", "Rosh HaAyin", "Teramo", "Bamberg", "City of Brussels", "Frankfurt", "Jerusalem", "Moscow", "Saint Petersburg", "Chicago", "Taipei", "Terni", "Ferrara", "Trento", "Monza", "Lecce", "Naples", "Vilnius", "Istanbul", "Sofia", "Buenos Aires", "Athens", "Bratislava", "Madrid", "Tunis", "Brussels metropolitan area", "Amsterdam", "Phoenix", "Tirana", "Kyoto", "Cali", "Drancy", "Beijing", "Shanghai", "Tbilisi", ], },
    {"q": "Which River runs through Prague?", "a": "Vltava"},
    {"q": "What is the elevation of Bobo-Dioulasso?", "a": "445"},
]
# fmt: on
def _make_config(embedding_model_name: str) -> dict:
    """Build one evaluation configuration for the given embedding model.

    All other hyper-parameters are held constant across the sweep; the
    commented-out lines document alternative models that can be swapped in.
    Key order matches the original hand-written dicts so pprint/JSON output
    is unchanged.
    """
    return {
        "chunk_size": 1280,
        "chunk_overlap": 0,
        "index_trees": 10,
        "retrieval_chunks": 16,
        "context_chunks": 5,
        "embedding_model_name": embedding_model_name,
        "reranker_model_name": "BAAI/bge-reranker-base",
        # "reranker_model_name": "BAAI/bge-reranker-large",
        # "qa_model_url": "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1",
        "qa_model_url": "https://api-inference.huggingface.co/models/meta-llama/Llama-2-7b-chat-hf",
    }


# FIX: the original repeated the same 10-line dict three times, differing
# only in the embedding model — factored into _make_config so shared
# settings are edited in one place.
configurations = [
    _make_config(name)
    for name in (
        "BAAI/bge-small-en-v1.5",
        "BAAI/bge-base-en-v1.5",
        "BAAI/bge-large-en-v1.5",
    )
]
@dataclass
class QERCA:
    """Record of a failed evaluation step.

    The name abbreviates the fields: Question, Expected answer,
    Retrieved context, reranked Context, Answer. Later fields are None
    when the pipeline was aborted before that stage produced them.
    """

    # The quiz question that failed.
    question: str
    # The expected-answer spec ("<function>" placeholder when it was a callable).
    expected_answer: str | List[str]
    # Context assembled from the retrieval stage.
    retrieved_context: str
    # Context after reranking; None if retrieval already failed.
    reranked_context: Optional[str] = None
    # Final generated answer; None if an earlier stage failed.
    answer: Optional[str] = None
@dataclass
class EvalResult:
    """Aggregated evaluation outcome for a single configuration run."""

    # The configuration dict the run was executed with.
    config: Dict[str, str | int]
    # Number of chunks in the index built for this configuration.
    total_chunks: int
    # Number of quiz questions evaluated.
    total_questions: int
    # Per-stage success counters.
    correct_retrievals: int = 0
    correct_reranks: int = 0
    correct_answers: int = 0
    # Questions the bare LLM answered correctly WITHOUT retrieved context.
    correct_answers_plain: int = 0
    # Per-stage failure records.
    failed_retrieval_questions: List[QERCA] = field(default_factory=list)
    failed_rerank_questions: List[QERCA] = field(default_factory=list)
    failed_answer_questions: List[QERCA] = field(default_factory=list)
    # BUG FIX: the original used `= datetime.datetime.now().isoformat()`,
    # which is evaluated once at class-definition time, so every instance
    # shared the same timestamp. default_factory evaluates per instance.
    # (Field name "datatime" — sic — is kept: it is a key in the persisted
    # JSON output.)
    datatime: str = field(
        default_factory=lambda: datetime.datetime.now().isoformat()
    )
# Run the evaluation: for each configuration, build an AskWikidata instance
# and score every quiz question in stages — plain LLM answer (no context),
# retrieved context, reranked context, final RAG answer. A stage failure
# records a QERCA entry and skips the remaining stages for that question.
eval_results: List[EvalResult] = []
for config in configurations:
    pprint(config)
    askwikidata = AskWikidata(**config)
    askwikidata.setup()
    # askwikidata.print_data()
    eval_result = EvalResult(
        config, total_questions=len(quiz), total_chunks=len(askwikidata.df)
    )
    eval_results.append(eval_result)
    for q in quiz:  # FIX: dropped the unused enumerate() index
        question = q["q"]
        expected_answer = q["a"]
        print("")
        print("Question:", question)
        print("Expected answer:", expected_answer)
        # Baseline: the bare LLM without any retrieved context. A correct
        # plain answer (🙈) means the question does not actually exercise
        # retrieval, hence the inverted emoji sentiment.
        answer_plain = askwikidata.llm_generate_plain(question)
        if correct(answer_plain, expected_answer):
            eval_result.correct_answers_plain += 1
            print("🙈 Plain Answer correct:", answer_plain)
        else:
            print("👍 Plain Answer wrong:", answer_plain)
        # Stage 1: retrieval. If the expected answer is not present in the
        # retrieved context, later stages cannot succeed — record and skip.
        # (FIX: removed unused local `retrieved_context_lower`; correct()
        # lower-cases internally.)
        retrieved = askwikidata.retrieve(question)
        retrieved_context = askwikidata.context(retrieved)
        if correct(retrieved_context, expected_answer):
            eval_result.correct_retrievals += 1
            print("✅ Retrieved Context")
        else:
            print("‼️ WRONG Retrieved Context")
            eval_result.failed_retrieval_questions.append(
                QERCA(
                    question,
                    "<function>"
                    if isinstance(expected_answer, Callable)
                    else expected_answer,
                    retrieved_context,
                )
            )
            continue
        # Stage 2: rerank the retrieved chunks and re-check the context.
        reranked, rerank_time = askwikidata.rerank(question, retrieved)
        print(f" {int(rerank_time)} seconds.")
        reranked_context = askwikidata.context(reranked)
        if correct(reranked_context, expected_answer):
            eval_result.correct_reranks += 1
            print("✅ Reranked Context")
        else:
            print("‼️ WRONG Reranked Context")
            eval_result.failed_rerank_questions.append(
                QERCA(
                    question,
                    "<function>"
                    if isinstance(expected_answer, Callable)
                    else expected_answer,
                    retrieved_context,
                    reranked_context,
                )
            )
            continue
        # Stage 3: the final answer generated from the reranked context.
        answer = askwikidata.llm_generate(question, reranked)
        if correct(answer, expected_answer):
            eval_result.correct_answers += 1
            print("✅ Answer:", answer)
        else:
            print("‼️ WRONG Answer:", answer)
            # FIX: removed a dead no-op expression statement
            # (`eval_result.failed_answer_questions` on its own line).
            eval_result.failed_answer_questions.append(
                QERCA(
                    question,
                    "<function>"
                    if isinstance(expected_answer, Callable)
                    else expected_answer,
                    retrieved_context,
                    reranked_context,
                    answer,
                )
            )
# Print a summary banner for each configuration run: the configuration
# itself followed by a shallow pprint of the aggregated EvalResult.
BANNER = "***************************************"
for eval_result in eval_results:
    print("")
    print(BANNER)
    print(" 🔍 Results 🔎\n")
    print("\n")
    pprint(eval_result.config)
    print("\n")
    # depth=1 keeps the failure lists collapsed so the summary stays short.
    pprint(eval_result, width=120, depth=1)
    print("\n")
    print(BANNER)
    print("\n")
# Persist results as JSON Lines, appending so repeated runs accumulate in
# one file. (FIX: dropped the pointless f-string prefix on the constant
# filename — ruff F541.)
with open("eval_results.json", "a") as file:
    for eval_result in eval_results:
        file.write(json.dumps(asdict(eval_result)) + "\n")