forked from spro/practical-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path data.py
74 lines (59 loc) · 2.02 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Practical PyTorch: Generating Names with a Conditional Character-Level RNN
# https://github.com/spro/practical-pytorch
import glob
import math
import os
import random
import string
import time
import unicodedata

import torch
from torch.autograd import Variable
# Preparing the Data
# Vocabulary: every ASCII letter plus a handful of punctuation marks.
all_letters = string.ascii_letters + " .,;'-"
# The index just past the vocabulary is reserved for the end-of-sequence
# marker, so one-hot vectors are one element wider than all_letters.
EOS = len(all_letters)
n_letters = EOS + 1  # Plus EOS marker
def unicode_to_ascii(s):
    """Reduce a Unicode string to characters drawn from ``all_letters``.

    The text is NFD-decomposed first, so an accented letter splits into its
    base letter followed by combining marks. Combining marks (Unicode
    category 'Mn') are discarded, as is any character not in ``all_letters``.
    """
    kept = []
    for ch in unicodedata.normalize('NFD', s):
        if unicodedata.category(ch) == 'Mn':
            continue  # drop accents/diacritics left over from decomposition
        if ch not in all_letters:
            continue  # drop anything outside the model's vocabulary
        kept.append(ch)
    return ''.join(kept)
def read_lines(filename):
    """Read a names file and return its lines normalized with ``unicode_to_ascii``.

    The file is decoded explicitly as UTF-8 instead of relying on the
    platform's locale default. It is closed deterministically via ``with``.

    Lines that normalize to an empty string are skipped. Such a line would
    give ``make_chars_input`` a zero-length input, while ``make_target``
    still produces a one-element (EOS-only) target for it.

    Args:
        filename: path of a text file with one entry per line.

    Returns:
        list of str: the non-empty normalized lines, in file order.
    """
    with open(filename, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    normalized = [unicode_to_ascii(line) for line in lines]
    return [name for name in normalized if name]
# Build the corpus: one category per ../data/names/<Category>.txt file.
category_lines = {}  # category name -> list of normalized lines from its file
all_categories = []  # category names; list position is the category's index
# glob returns files in filesystem-dependent order. The position in
# all_categories is what make_category_input one-hot encodes, so sort to
# get the same category -> index mapping on every machine.
for filename in sorted(glob.glob('../data/names/*.txt')):
    # basename/splitext rather than split('/') and split('.'):
    # - on Windows, glob joins the final path component with '\\';
    # - a file stem containing a dot would otherwise be truncated.
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = read_lines(filename)
    category_lines[category] = lines
n_categories = len(all_categories)
# Preparing for Training
def random_training_pair():
    """Sample a ``(category, line)`` pair at random.

    A category is drawn uniformly from ``all_categories``. A line is then
    drawn uniformly from that category's entries in ``category_lines``.
    Every category is therefore equally likely, however many lines it has.
    """
    picked_category = random.choice(all_categories)
    pool = category_lines[picked_category]
    return picked_category, random.choice(pool)
def make_category_input(category):
    """One-hot encode *category* as a ``(1, n_categories)`` Variable.

    The hot column is the category's position in ``all_categories``.
    ``list.index`` raises ValueError if *category* is not in that list.
    """
    position = all_categories.index(category)
    one_hot = torch.zeros(1, n_categories)
    one_hot[0, position] = 1
    return Variable(one_hot)
def make_chars_input(chars):
    """One-hot encode *chars* as a ``(len(chars), 1, n_letters)`` Variable.

    Row ``i`` has a single 1 in the column given by ``chars[i]``'s position
    in ``all_letters``. The extra final column (the EOS slot) is never set
    here. An empty *chars* yields a ``(0, 1, n_letters)`` tensor.

    Raises:
        ValueError: if a character is not in ``all_letters``. ``str.find``
            returns -1 for such a character, and using -1 as an index would
            silently set the last column, the EOS slot.
    """
    # Allocate the final shape directly. A later view(-1, 1, n_letters)
    # cannot infer the -1 dimension when the tensor has zero elements.
    tensor = torch.zeros(len(chars), 1, n_letters)
    for ci, char in enumerate(chars):
        li = all_letters.find(char)
        if li < 0:
            raise ValueError('character %r is not in all_letters' % (char,))
        tensor[ci][0][li] = 1
    return Variable(tensor)
def make_target(line):
    """Build the next-character target LongTensor (as a Variable) for *line*.

    Element ``i`` is the ``all_letters`` index of ``line[i + 1]``, and the
    last element is ``EOS``. For a non-empty line the result has
    ``len(line)`` elements, one per row of ``make_chars_input(line)``.

    Raises:
        ValueError: if any character after the first is not in
            ``all_letters``. ``str.find`` would return -1, which is outside
            the valid index range ``0 .. n_letters - 1``.
    """
    letter_indexes = []
    for char in line[1:]:  # the target at step i is the character at i + 1
        li = all_letters.find(char)
        if li < 0:
            raise ValueError('character %r is not in all_letters' % (char,))
        letter_indexes.append(li)
    letter_indexes.append(EOS)  # every sequence ends by predicting EOS
    return Variable(torch.LongTensor(letter_indexes))
def random_training_set():
    """Sample one training example.

    Returns:
        tuple: ``(category_input, line_input, line_target)``. These are built
        from one ``random_training_pair()`` draw with ``make_category_input``,
        ``make_chars_input`` and ``make_target``, respectively.
    """
    category, line = random_training_pair()
    # Tuple elements are evaluated left to right, matching the original order.
    return (
        make_category_input(category),
        make_chars_input(line),
        make_target(line),
    )