-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
32 lines (22 loc) · 899 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import argparse
from typing import Dict
from datasets import DatasetDict, load_dataset
def load_dataset_from_path(path: str) -> DatasetDict:
    """Load a JSON file as an 80/10/10 train/validation/test ``DatasetDict``.

    The file is read with the ``json`` builder of ``datasets``. Each split is
    a consecutive percentage slice of the builder's ``train`` split: the first
    80% for training, the next 10% for validation and the final 10% for test.

    Args:
        path: Path to the JSON data file, passed through as ``data_files``.

    Returns:
        A ``DatasetDict`` with the keys ``"train"``, ``"validation"`` and
        ``"test"``.
    """
    # Output split name -> slice expression over the builder's "train" split.
    split_slices = {
        "train": "train[:80%]",
        "validation": "train[80%:90%]",
        "test": "train[90%:]",
    }
    return DatasetDict({
        name: load_dataset("json", data_files=path, split=slice_spec)
        for name, slice_spec in split_slices.items()
    })
if __name__ == "__main__":
    # Manual smoke check: load the QA JSON file and print each split's size.
    dataset = load_dataset_from_path("../data/qa.en.c.json")
    # (printed label, split key) pairs; note the first label has no colon.
    for label, split_name in (
        ("train size", "train"),
        ("dev size:", "validation"),
        ("test size:", "test"),
    ):
        print(label, len(dataset[split_name]))