-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfeedback_dataset.py
66 lines (58 loc) · 2.23 KB
/
feedback_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import re
import pandas as pd
import pyarrow as pa
from datasets import Dataset
from pylatexenc.latex2text import LatexNodes2Text
# Shared LaTeX -> plain-text converter, created once at import time and used by parse_text.
latex_parser = LatexNodes2Text()
# Matches any run of whitespace (spaces, tabs, newlines, ...) so it can be collapsed to one space.
whitespace_re = re.compile(r"\s+")
def clean_whitespace(text: str):
    """Trim both ends of *text* and collapse each inner whitespace run to one space.

    Newlines and tabs inside the text therefore become single spaces.
    """
    trimmed = text.strip()
    collapsed = whitespace_re.sub(" ", trimmed)
    return collapsed
def parse_text(text: str):
    """Render the LaTeX markup in *text* as plain Unicode text, then normalise its whitespace."""
    plain = latex_parser.latex_to_text(text)
    return clean_whitespace(plain)
def get_raw_dataset(filename: str = "data/pp_eedi_data_0912.csv"):
    """Load the question CSV and keep only the rows marked as usable.

    A row is kept when (``Script Clean`` is set, or ``Cleaned`` is set and
    ``Discuss Flag`` is not) and ``Image Needed`` is not set.

    Args:
        filename: CSV file to read. Defaults to the original hard-coded
            location, so existing zero-argument callers behave as before.

    Returns:
        A new DataFrame (an independent copy, not a slice of the full table)
        containing only the kept rows, with their original index labels.
    """
    df = pd.read_csv(filename)
    # NOTE(review): combining the flags with | / & / ~ assumes these columns
    # parse as booleans (no blank cells / NaN) — confirm against the CSV.
    valid = df["Script Clean"] | (df["Cleaned"] & ~df["Discuss Flag"])
    valid = valid & ~df["Image Needed"]
    # .copy() so later column assignments (e.g. the LaTeX conversion below, if
    # re-enabled) modify an independent frame rather than a slice of `df`.
    df = df[valid].copy()
    # Convert LaTeX to unicode (currently disabled):
    # df["question"] = df["question"].apply(parse_text)
    # for i in range(1, 5):
    #     df[f"Answer{i}"] = df[f"Answer{i}"].apply(parse_text)
    #     df[f"Explanation{i}"] = df[f"Explanation{i}"].apply(clean_whitespace)
    return df
def expand_rows(df: pd.DataFrame) -> pd.DataFrame:
    """Turn each question row into one output row per distractor (wrong option).

    Each input row is read for ``id``, ``question``, ``ConstructId``,
    ``CorrectAnswer`` (the 1-based number of the correct option) and
    ``Answer1``..``Answer4`` / ``Explanation1``..``Explanation4``.

    Rows are reported and skipped when ``CorrectAnswer`` is missing or cannot be
    converted with ``int``, names an option outside 1..4, or points at an empty
    (NaN) answer cell.

    Returns:
        A DataFrame with columns qid, question, construct_id, correct_answer,
        explanation, distractor, feedback — one row per distractor. The columns
        are present even when no rows are produced.
    """
    columns = [
        "qid",
        "question",
        "construct_id",
        "correct_answer",
        "explanation",
        "distractor",
        "feedback",
    ]
    records = []
    for row_idx, row in df.iterrows():
        try:
            correct_idx = int(row["CorrectAnswer"])
        except (KeyError, TypeError, ValueError) as e:
            # e.g. a NaN or non-numeric CorrectAnswer. Parse once per row and
            # catch only lookup/conversion errors instead of any Exception.
            print(e)
            print(f"No correct answer for {row_idx} - skipping")
            continue

        correct_answer, explanation = None, None
        distractors = []
        for i in range(1, 5):
            answer = row[f"Answer{i}"]
            answer_explanation = row[f"Explanation{i}"]
            if i == correct_idx:
                correct_answer, explanation = answer, answer_explanation
            else:
                distractors.append((answer, answer_explanation))

        # Explicit None/NaN check: a truthiness test would wrongly drop a valid
        # falsy answer such as the number 0, yet let a NaN cell through.
        if correct_answer is None or pd.isna(correct_answer):
            print(f"No correct answer for {row_idx} - skipping")
            continue

        records.extend(
            {
                "qid": row["id"],
                "question": row["question"],
                "construct_id": row["ConstructId"],
                "correct_answer": correct_answer,
                "explanation": explanation,
                "distractor": distractor,
                "feedback": feedback,
            }
            for distractor, feedback in distractors
        )
    return pd.DataFrame(records, columns=columns)
def extract_feedback(output: str) -> str:
    """Return the text following ``Feedback: `` in a model *output*, stripped.

    The original single-line rule is tried first: the marker on the last line
    of the output. If that does not match (the feedback spans several lines),
    everything after the first ``Feedback: `` marker is returned instead.

    Raises:
        AttributeError: if *output* contains no ``Feedback: `` marker. This is
            the same exception type the original ``None.group(...)`` failure
            raised, kept so existing callers' handlers still apply.
    """
    match = re.search(r"Feedback: (.*)$", output)
    if match is None:
        # Multi-line feedback: `.` must also match newlines here.
        match = re.search(r"Feedback: (.*)", output, re.DOTALL)
    if match is None:
        raise AttributeError(f"No 'Feedback: ' marker found in model output: {output!r}")
    return match.group(1).strip()
def load_pd_dataset(filename: str):
    """Read the CSV at *filename* and expand it to one row per distractor via expand_rows."""
    raw_frame = pd.read_csv(filename)
    return expand_rows(raw_frame)
def load_hf_dataset(filename: str):
    """Build a Hugging Face Dataset from the expanded rows of the CSV at *filename*.

    The pandas frame from load_pd_dataset is converted to a pyarrow Table, which
    is then passed to the Dataset constructor.
    """
    expanded = load_pd_dataset(filename)
    arrow_table = pa.Table.from_pandas(expanded)
    return Dataset(arrow_table)