|
| 1 | +import torch |
| 2 | +import pandas as pd |
| 3 | +import transformers |
| 4 | +import sentencepiece as spm |
| 5 | +from utils import sentencepiece_pb2 |
| 6 | + |
| 7 | + |
| 8 | +class OffsetTokenizer(): |
| 9 | + def __init__(self, path_model): |
| 10 | + self.spt = sentencepiece_pb2.SentencePieceText() |
| 11 | + self.sp = spm.SentencePieceProcessor(model_file=path_model) |
| 12 | + |
| 13 | + def encode(self, text, lower=True): |
| 14 | + if lower: |
| 15 | + text = text.lower() |
| 16 | + offset = [] |
| 17 | + ids = [] |
| 18 | + self.spt.ParseFromString(self.sp.encode_as_serialized_proto(text)) |
| 19 | + |
| 20 | + for piece in self.spt.pieces: |
| 21 | + offset.append((piece.begin, piece.end)) |
| 22 | + ids.append(piece.id) |
| 23 | + |
| 24 | + return {"token": self.sp.encode_as_pieces(text), "ids": ids, "offsets": offset} |
| 25 | + |
| 26 | + |
| 27 | +class AlbertTweetDataset(torch.utils.data.Dataset): |
| 28 | + def __init__(self, df, max_len=128): |
| 29 | + # Dataframe dữ liệu |
| 30 | + self.df = df |
| 31 | + # Độ dài tối đa của câu |
| 32 | + self.max_len = max_len |
| 33 | + # Nhãn |
| 34 | + self.labeled = 'selected_text' in df |
| 35 | + # Khởi tạo mã hóa SentencePiece cho Albert |
| 36 | + self.tokenizer = OffsetTokenizer( |
| 37 | + path_model='./albert.torch/albert-large-v2/spiece.model') |
| 38 | + |
| 39 | + # self.sp = spm.SentencePieceProcessor( |
| 40 | + # model_file='./albert.torch/albert-large-v2/spiece.model') |
| 41 | + # self.spt = sentencepiece_pb2.SentencePieceText() |
| 42 | + |
| 43 | + # ========= TEST CODE ======= |
| 44 | + # self.tokenizer = transformers.AlbertTokenizer.from_pretrained( |
| 45 | + # './albert.torch/albert-large-v2/spiece.model', |
| 46 | + # do_lower_case=True) |
| 47 | + # print(self.tokenizer.tokenize(" Nguyen Duc Thang")) |
| 48 | + # print(self.tokenizer.encode("Nguyen Duc Thang")) |
| 49 | + # print(self.tokenizer.decode([2, 20449, 13, 8484, 119, 263, 3, 0])) |
| 50 | + # print(self.spt.ParseFromString( |
| 51 | + # self.sp.encode_as_serialized_proto("Nguyen Duc Thang".lower()))) |
| 52 | + # offset = [] |
| 53 | + # ids = [] |
| 54 | + |
| 55 | + # for piece in self.spt.pieces: |
| 56 | + # offset.append((piece.begin, piece.end)) |
| 57 | + # ids.append(piece.id) |
| 58 | + # print(ids) |
| 59 | + # print(offset) |
| 60 | + # ======= END TEST ======= |
| 61 | + |
| 62 | + def __len__(self): |
| 63 | + """ Trả về độ dài của DataFrame """ |
| 64 | + return len(self.df) |
| 65 | + |
| 66 | + def get_input_data(self, row): |
| 67 | + """ |
| 68 | + Tạo sample input cho 1 dòng dữ liệu |
| 69 | + - Input: [CLS] <sentiment>[SEP]token11 token12 ... [SEP][pad][pad] |
| 70 | + """ |
| 71 | + # Thêm khoảng trắng vào đầu câu đầu vào |
| 72 | + tweet = " " + " ".join(row.text.lower().split()) |
| 73 | + # Mã hóa câu đầu vào |
| 74 | + encoding = self.tokenizer.encode(tweet) |
| 75 | + # Mã hóa sentiment |
| 76 | + sentiment_id = self.tokenizer.encode(row.sentiment)["ids"] |
| 77 | + # 2 là CLS, 3 là SEP, 0 là <pad> |
| 78 | + ids = [2] + sentiment_id + [3] + encoding["ids"] + [3] |
| 79 | + # token type ids |
| 80 | + token_type_ids = [0] * 3 + [1] * (len(encoding["ids"]) + 1) |
| 81 | + # offset là vị trí các token của câu ban đầu |
| 82 | + offsets = [(0, 0)] * 3 + encoding["offsets"] + [(0, 0)] |
| 83 | + |
| 84 | + # Thêm các token pad cho viền |
| 85 | + pad_len = self.max_len - len(ids) |
| 86 | + if pad_len > 0: |
| 87 | + ids += [0] * pad_len |
| 88 | + offsets += [(0, 0)] * pad_len |
| 89 | + token_type_ids += [0] * pad_len |
| 90 | + ids = torch.tensor(ids) |
| 91 | + # Tạo mặt nạ attention, đánh dấu 1 cho toàn bộ câu đầu vào |
| 92 | + # Trừ các phần là <pad> |
| 93 | + masks = torch.where(ids != 1, torch.tensor(1), torch.tensor(0)) |
| 94 | + offsets = torch.tensor(offsets) |
| 95 | + token_type_ids = torch.tensor(token_type_ids) |
| 96 | + |
| 97 | + return ids, masks, tweet, offsets, token_type_ids |
| 98 | + |
| 99 | + def get_target_idx(self, row, tweet, offsets): |
| 100 | + selected_text = " " + " ".join(row.selected_text.lower().split()) |
| 101 | + |
| 102 | + len_st = len(selected_text) - 1 |
| 103 | + # Vị trí bắt đầu và kết thúc của selectec_text trong tweet |
| 104 | + idx0, idx1 = None, None |
| 105 | + |
| 106 | + for ind in (i for i, e in enumerate(tweet) if e == selected_text[1]): |
| 107 | + if " " + tweet[ind:ind+len_st] == selected_text: |
| 108 | + idx0 = ind |
| 109 | + idx1 = ind + len_st - 1 |
| 110 | + |
| 111 | + # Đánh dấu những vị trí mà có ký tự của selected_text là 1 |
| 112 | + char_targets = [0] * len(tweet) |
| 113 | + if idx0 != None and idx1 != None: |
| 114 | + for ct in range(idx0, idx1 + 1): |
| 115 | + char_targets[ct] = 1 |
| 116 | + |
| 117 | + # Đánh dấu những token chứa selected_text |
| 118 | + target_idx = [] |
| 119 | + for j, (offset1, offset2) in enumerate(offsets): |
| 120 | + if sum(char_targets[offset1:offset2]) > 0: |
| 121 | + target_idx.append(j) |
| 122 | + |
| 123 | + # Token bắt đầu và token kết thúc của selected_text |
| 124 | + start_idx = target_idx[0] |
| 125 | + end_idx = target_idx[-1] |
| 126 | + |
| 127 | + return start_idx, end_idx |
| 128 | + |
| 129 | + def __getitem__(self, index): |
| 130 | + """ |
| 131 | + Chuyển đổi hàng dữ liệu thứ index trong dataFrame |
| 132 | + sang dữ liệu đầu vào của mô hình |
| 133 | + Các thuộc tính cho dữ liệu đầu vafo: |
| 134 | + - ids |
| 135 | + - masks |
| 136 | + - tweet |
| 137 | + - offsets |
| 138 | + - token_type_ids |
| 139 | + - start_idx |
| 140 | + - end_idx |
| 141 | + """ |
| 142 | + data = {} |
| 143 | + row = self.df.iloc[index] |
| 144 | + |
| 145 | + ids, masks, tweet, offsets, token_type_ids = self.get_input_data(row) |
| 146 | + data['ids'] = ids |
| 147 | + data['masks'] = masks |
| 148 | + data['tweet'] = tweet |
| 149 | + data['offsets'] = offsets |
| 150 | + data['token_type_ids'] = token_type_ids |
| 151 | + |
| 152 | + if self.labeled: |
| 153 | + start_idx, end_idx = self.get_target_idx(row, tweet, offsets) |
| 154 | + data['start_idx'] = start_idx |
| 155 | + data['end_idx'] = end_idx |
| 156 | + |
| 157 | + return data |
0 commit comments