-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrnn_data.py
124 lines (93 loc) · 3.37 KB
/
rnn_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, IterableDataset
import random
import math
from pathlib import Path
from typing import Dict, Iterator, List, Tuple
from config import config
from train import Trainer
import string
import unicodedata
all_letters = string.ascii_letters + " .,;'-"


# Turn a Unicode string to plain ASCII, thanks to http://stackoverflow.com/a/518232/2809427
def unicode_to_ascii(s):
    """Return *s* reduced to the characters listed in ``all_letters``.

    The input is NFD-normalized first, so accented letters are split into a
    base letter followed by combining marks. Combining marks (Unicode
    category ``Mn``) are dropped, as is every character that is not part of
    ``all_letters``.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        # Skip combining marks produced by the decomposition.
        if unicodedata.category(ch) == 'Mn':
            continue
        # Skip anything outside the allowed alphabet.
        if ch not in all_letters:
            continue
        kept.append(ch)
    return ''.join(kept)
class NamesDataset(Dataset):
    """Dataset of ``(name, language)`` samples read from per-language files.

    Each file in the ``data/names`` directory next to this module supplies
    one language: the file stem is the language label and every line is one
    name. Names are ASCII-folded with ``unicode_to_ascii`` and every distinct
    character receives an integer index; index 0 is reserved for ``'PAD'``.

    Indexing the dataset yields ``(sequence, lang_id)`` where ``sequence`` is
    the list of character indices of the name and ``lang_id`` is the position
    of the language in ``langs``.
    """

    samples: List[Tuple[str, str]]
    char2index: Dict[str, int]
    index2char: Dict[int, str]
    langs: List[str]

    def __init__(self):
        super().__init__()
        self.samples = []
        # Index 0 is reserved for the padding symbol.
        self.char2index = {'PAD': 0}
        self.index2char = {0: 'PAD'}
        self.langs = []
        self.read_samples()
        # The vocabulary is built before shuffling so that character indices
        # depend only on the file contents, not on the shuffle.
        self.index_samples()
        random.shuffle(self.samples)

    def read_samples(self):
        """Fill ``langs`` and ``samples`` from the files in ``data/names``."""
        dataset_path: Path = Path(__file__).parent / 'data' / 'names'

        # Autocompletion hack
        file: Path
        # iterdir() yields entries in arbitrary order; sorting keeps the
        # language ids identical across runs and machines.
        for file in sorted(dataset_path.iterdir()):
            # Sub-directories cannot be opened as text files.
            if not file.is_file():
                continue

            # Get language
            lang: str = file.stem
            self.langs.append(lang)

            # NOTE(review): the name files are assumed to be UTF-8 encoded;
            # confirm against the data set in use.
            with file.open('r', encoding='utf-8') as handle:
                # Autocompletion hack
                line: str
                # Get all names of this language
                for line in handle:
                    # Remove whitespace, then remove unicode symbols
                    name = unicode_to_ascii(line.strip())
                    # Blank lines would otherwise become empty sequences.
                    if not name:
                        continue
                    # Store sample
                    self.samples.append((name, lang))

    def index_samples(self):
        """Assign the next free index to every character not seen before."""
        for name, _lang in self.samples:
            # Walk the characters in string order (not via set(name)): set
            # iteration order of strings changes with hash randomisation,
            # which would give a different vocabulary on every run.
            for char in name:
                if char not in self.char2index:
                    index = len(self.char2index)
                    self.char2index[char] = index
                    self.index2char[index] = char

    def __getitem__(self, index: int) -> Tuple[List[int], int]:
        """Return ``(character indices of the name, language id)``."""
        name, lang = self.samples[index]
        sequence = [self.char2index[char] for char in name]
        lang_id = self.langs.index(lang)
        return (sequence, lang_id)

    def __len__(self) -> int:
        """Return the number of stored samples."""
        return len(self.samples)
# Common DataLoader utils
def pad_seq(seq, max_length):
    """Return a copy of ``seq`` right-padded with zeros to ``max_length``.

    The input is left untouched; a new list is always returned. A sequence
    that is already ``max_length`` long or longer is returned unpadded (it is
    never truncated).
    """
    # A negative multiplier yields an empty list, so long inputs pass through.
    padding = [0] * (max_length - len(seq))
    return list(seq) + padding
def transform_batch(samples: List[Tuple[List[int], int]]) -> Tuple[Tensor, Tensor]:
    """Collate ``(sequence, lang_id)`` samples into two tensors.

    Every sequence is right-padded with 0 up to the length of the longest
    sequence in the batch. Samples keep their incoming order; no sorting by
    length is performed.

    Args:
        samples: pairs of (list of character indices, language id).

    Returns:
        ``(seqs_tensor, langs_tensor)``: a long tensor holding the padded
        sequences (one row per sample) and a long tensor holding the language
        ids, both moved to ``config.device``.
    """
    # Copy each sequence so that padding never modifies the caller's lists.
    seqs = [list(sample[0]) for sample in samples]
    langs = [sample[1] for sample in samples]

    # Compute the target length once; default=0 keeps an empty batch from raising.
    max_length = max((len(s) for s in seqs), default=0)
    seqs_padded = [pad_seq(s, max_length) for s in seqs]

    # Turn padded array into (batch_size x max_len) tensor
    seqs_tensor = torch.tensor(seqs_padded, dtype=torch.long)
    langs_tensor = torch.tensor(langs, dtype=torch.long)

    # Send to device (cpu or cuda)
    seqs_tensor = seqs_tensor.to(config.device)
    langs_tensor = langs_tensor.to(config.device)

    return seqs_tensor, langs_tensor