load_data.py
# -*- coding: utf-8 -*-
import csv
import torch
import torch.utils.data as tud
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence
TRAIN_DATA_PATH = '../data/train.tsv'
DEV_DATA_PATH = '../data/dev.tsv'
TOKENIZER_PATH = './bert-base-chinese'
MAX_LEN = 512
BATCH_SIZE = 32
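# The prompt template '[MASK]满意。' ("[MASK] satisfied.") is prepended to every
# review, turning binary sentiment classification into a cloze task in which the
# model fills the [MASK] slot with '很' ("very", positive) or '不' ("not", negative).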
PREFIX = '[MASK]满意。'
def collate_fn(batch_data):
    """
    collate_fn for the DataLoader: turns a batch of samples into padded tensors.
    Args:
        batch_data: a list of samples, each a dict with "input_ids", "mask" and "labels"
    Returns:
        a dict of batched tensors: "input_ids", "attention_mask" and "labels"
    """
    input_ids_list, attention_mask_list, labels_list = [], [], []
    for instance in batch_data:
        # Collect each field; padding to the longest sequence in the batch happens below.
        input_ids_temp = instance["input_ids"]
        attention_mask_temp = instance["mask"]
        labels_temp = instance["labels"]
        # Convert to tensors and append to the corresponding lists.
        input_ids_list.append(torch.tensor(input_ids_temp, dtype=torch.long))
        attention_mask_list.append(torch.tensor(attention_mask_temp, dtype=torch.long))
        labels_list.append(torch.tensor(labels_temp, dtype=torch.long))
    # pad_sequence pads every tensor in the list to the length of the longest one
    # in the batch, filling the extra positions with padding_value.
    return {"input_ids": pad_sequence(input_ids_list, batch_first=True, padding_value=0),
            "attention_mask": pad_sequence(attention_mask_list, batch_first=True, padding_value=0),
            "labels": pad_sequence(labels_list, batch_first=True, padding_value=-100)}
class BinarySentiDataset(tud.Dataset):
    def __init__(self, data_path, tokenizer_path, max_len, prefix):
        super(BinarySentiDataset, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        self.max_len = max_len
        self.prefix = prefix
        # Verbalizer tokens: the model should predict '很' at the [MASK] position
        # for positive reviews and '不' for negative ones.
        self.pos_id = self.tokenizer.convert_tokens_to_ids('很')
        self.neg_id = self.tokenizer.convert_tokens_to_ids('不')
        self.mask_id = self.tokenizer.mask_token_id  # 103 for bert-base-chinese
        self.data_set = []
        with open(data_path, 'r', encoding='utf8') as rf:
            r = csv.reader(rf, delimiter='\t')
            next(r)  # skip the header row
            for row in r:
                # Prepend the prompt to the review text and tokenize.
                text = self.prefix + row[2]
                input_ids = self.tokenizer.encode(text)
                if len(input_ids) > self.max_len:
                    input_ids = input_ids[:self.max_len]
                target = int(row[1])
                # Supervise only the [MASK] position; -100 is ignored by the loss.
                if target == 0:
                    labels = [self.neg_id if idx == self.mask_id else -100 for idx in input_ids]
                else:
                    labels = [self.pos_id if idx == self.mask_id else -100 for idx in input_ids]
                mask = [1] * len(input_ids)
                self.data_set.append({"input_ids": input_ids, "mask": mask, "labels": labels})

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, idx):
        return self.data_set[idx]
traindataset = BinarySentiDataset(TRAIN_DATA_PATH, TOKENIZER_PATH, MAX_LEN, PREFIX)
traindataloader = tud.DataLoader(traindataset, BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
valdataset = BinarySentiDataset(DEV_DATA_PATH, TOKENIZER_PATH, MAX_LEN, PREFIX)
valdataloader = tud.DataLoader(valdataset, BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
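
# Usage sketch (illustrative): pull one batch from the training dataloader and
# check that padding and label masking behave as expected.
if __name__ == '__main__':
    batch = next(iter(traindataloader))
    print(batch["input_ids"].shape)       # (batch_size, longest sequence in the batch)
    print(batch["attention_mask"].shape)  # same shape; 1 for real tokens, 0 for padding
    # Each sequence should have exactly one supervised position: the [MASK] slot.
    print((batch["labels"] != -100).sum(dim=1))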