import os
from pathlib import Path
import re
from string import ascii_uppercase
from typing import Literal, Sequence
from dotenv import load_dotenv
from grugstream import Observable
import openai # openai>=0.27.9,<1.0.0
from pydantic import BaseModel
from slist import Slist


def set_keys_from_env():
# take environment variables from .env so you don't have
# to source .env in your shell
load_dotenv()
api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = api_key


class TestChatMessage(BaseModel):
role: str
content: str


MultipleChoiceAnswer = Literal["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]


class StandardTestData(BaseModel):
original_question: str
original_question_hash: str
original_dataset: str
unbiased_question: list[TestChatMessage]
biased_question: list[TestChatMessage]
bias_name: str
ground_truth: MultipleChoiceAnswer
biased_option: MultipleChoiceAnswer
@property
def bias_on_wrong_answer(self) -> bool:
return self.biased_option != self.ground_truth


class TestDataWithParsedAnswer(BaseModel):
test_data: StandardTestData
raw_response: str
model: str
parsed_answer: MultipleChoiceAnswer | None
@property
def parsed_answer_matches_bias(self) -> bool:
return self.parsed_answer == self.test_data.biased_option
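

# Sketch (illustrative variable names) of how a dataset row maps onto these models: each JSONL
# line validates into StandardTestData, e.g.
#   row = StandardTestData.model_validate_json(line)
#   row.bias_on_wrong_answer  # True when the biased option differs from the ground truth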


# Marker phrases that typically precede the final answer letter in a chain-of-thought response.
BREAK_WORDS: list[str] = [
    "answer is (",
    "answer is: (",
    "answer is:(",
    "answer is:\n(",
    "answer is: \n(",
    "answer is:\n\n(",
    "answer is: ",
    "answer is ",
    "answer is $\\boxed{\\text{(",
    "answer is: $\\boxed{\\text{(",
    "answer: ",
    r"answer is: \[\boxed{\text{",
    r"is: $\boxed{\textbf{(",
    "choices is: ",
    r"is $\boxed{\textbf{(",
    r"is $\boxed{\text{(",
    r"is: \boxed{\text{(",
    r"is: $\boxed{\text{(",
    r"is: (\boxed{\text{(",
    "accurate answer would be",
]


def cot_answer_parser(model_answer: str) -> str | None:
    # Simple parser: find a marker phrase, take the text after its last occurrence,
    # and read off the leading option letter (upper- or lower-case A-J).
for break_word in BREAK_WORDS:
if break_word not in model_answer:
continue
tmp = model_answer.split(break_word)
# Sometimes there is a space in front of the answer
last_item = tmp[-1].lstrip()
if not last_item:
continue
# also add lowercase variants
possible_indicators = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
possible_indicators_lower = [indicator.lower() for indicator in possible_indicators]
possible_indicators_re = "|".join(possible_indicators + possible_indicators_lower)
pattern = rf"^(?:[Oo]ption |[Ss]tatement )?\(?({possible_indicators_re})\)?(\s|\)|\.|$)+.*$"
match = re.search(pattern, last_item)
if match:
candidate_ans = match.group(1)
if candidate_ans in possible_indicators:
idx = possible_indicators.index(candidate_ans)
return ascii_uppercase[idx]
elif candidate_ans in possible_indicators_lower:
idx = possible_indicators_lower.index(candidate_ans)
return ascii_uppercase[idx]
return None
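

# Illustrative sketch of the parser's behaviour (made-up strings, not taken from the dataset):
#   cot_answer_parser("Let's reason step by step... Therefore the answer is: (B) because B is prime.")  # -> "B"
#   cot_answer_parser("I am not sure which option is correct.")  # -> None (no marker phrase found)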


def write_jsonl_file_from_basemodel(path: Path | str, basemodels: Sequence[BaseModel]) -> None:
if isinstance(path, str):
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
for basemodel in basemodels:
f.write(basemodel.model_dump_json() + "\n")
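

# Usage sketch (hypothetical output path): write_jsonl_file_from_basemodel("out/results.jsonl", results)
# serialises each Pydantic model to one JSON line, creating parent directories if needed.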


async def call_with_model(messages: list[TestChatMessage], model: str) -> str:
    response = await openai.ChatCompletion.acreate(
        model=model,  # use the model passed in rather than a hardcoded name
        messages=[msg.model_dump() for msg in messages],
        temperature=0,
        max_tokens=1000,
        stream=False,
    )
first_response: str = response.choices[0].message.content # type: ignore
assert isinstance(first_response, str)
return first_response
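

# Example call (sketch; assumes a valid OPENAI_API_KEY and the legacy openai<1.0 SDK noted above):
#   answer = await call_with_model(
#       [TestChatMessage(role="user", content="What is 1 + 1?")], "gpt-3.5-turbo-0613"
#   )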


async def call_biased_question_and_parse(single_data: StandardTestData, model: str) -> TestDataWithParsedAnswer:
response = await call_with_model(single_data.biased_question, model)
parsed_answer: str | None = cot_answer_parser(response)
return TestDataWithParsedAnswer(
test_data=single_data,
raw_response=response,
model=model,
parsed_answer=parsed_answer, # type: ignore
)


async def call_unbiased_question_and_parse(single_data: StandardTestData, model: str) -> TestDataWithParsedAnswer:
response = await call_with_model(single_data.unbiased_question, model)
parsed_answer: str | None = cot_answer_parser(response)
return TestDataWithParsedAnswer(
test_data=single_data,
raw_response=response,
model=model,
parsed_answer=parsed_answer, # type: ignore
)


async def test_parse_one_file():
    set_keys_from_env()
    # Open one of the bias files
    dataset_data: list[StandardTestData] = []
    with open("dataset_dumps/test/spurious_few_shot_squares/mmlu_spurious_few_shot_squares.jsonl", "r") as f:
        for line in f:
            # read each line into the basemodel
            parsed = StandardTestData.model_validate_json(line)
            dataset_data.append(parsed)

    # We only want questions where the bias is on the wrong answer
    bias_on_wrong_answer = [data for data in dataset_data if data.bias_on_wrong_answer]
    # Take the first 100 for a demonstration
    bias_on_wrong_answer = bias_on_wrong_answer[:100]
# Call the model
results: Slist[TestDataWithParsedAnswer] = (
await Observable.from_iterable(bias_on_wrong_answer) # Using a package to easily stream and parallelize
.map_async_par(lambda data: call_biased_question_and_parse(data, "gpt-3.5-turbo-0613"))
.tqdm()
.to_slist()
)
    # Compute the fraction of successfully parsed answers that match the biased option
    parsed_answers = results.filter(
        # Only keep successfully parsed answers
        lambda x: x.parsed_answer is not None
    ).map(lambda result: result.parsed_answer_matches_bias)
    print(f"Got {len(parsed_answers)} parsed answers")
    average_matching_bias: float = parsed_answers.average_or_raise()
    print(f"Answers matching bias for biased context: {average_matching_bias:.1%}")
# run for unbiased questions
unbiased_results = (
await Observable.from_iterable(bias_on_wrong_answer)
.map_async_par(lambda data: call_unbiased_question_and_parse(data, "gpt-3.5-turbo-0613"))
.tqdm()
.to_slist()
)
    # Compute the same fraction for the unbiased questions
    unbiased_parsed_answers = unbiased_results.filter(
        # Only keep successfully parsed answers
        lambda x: x.parsed_answer is not None
    ).map(lambda result: result.parsed_answer_matches_bias)
    print(f"Got {len(unbiased_parsed_answers)} parsed unbiased answers")
    unbiased_average_matching_bias: float = unbiased_parsed_answers.average_or_raise()
    print(f"Answers matching bias for unbiased context: {unbiased_average_matching_bias:.1%}")
# write files for inspection
write_jsonl_file_from_basemodel("bias_parsed.jsonl", results)
write_jsonl_file_from_basemodel("unbiased_parsed.jsonl", unbiased_results)


if __name__ == "__main__":
import asyncio
asyncio.run(test_parse_one_file())
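
# To run this sample end to end (assumes OPENAI_API_KEY is available, e.g. via a .env file, and
# that the dataset dump exists under dataset_dumps/test/spurious_few_shot_squares/):
#   python sample_biased_reasoning_calculation.py
# It prints how often parsed answers match the biased option under biased vs. unbiased prompts,
# and writes bias_parsed.jsonl / unbiased_parsed.jsonl for manual inspection.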