-
Notifications
You must be signed in to change notification settings - Fork 134
/
data_utils.py
124 lines (102 loc) · 4.1 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import absolute_import
import os
import re
import numpy as np
def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.

    Looks in data_dir for the file whose name contains 'qa<task_id>_' and
    'train', and the one whose name contains 'qa<task_id>_' and 'test', and
    parses each of them with get_stories.

    Args:
        data_dir: directory holding the task files.
        task_id: task number, from 1 to 20 inclusive.
        only_supporting: forwarded unchanged to get_stories.

    Returns:
        A tuple containing the training and testing data for the task.

    Raises:
        ValueError: if task_id is outside 1..20.
        IOError: if data_dir has no matching train or test file.
    '''
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not 0 < task_id < 21:
        raise ValueError(
            'task_id must be between 1 and 20, got {!r}'.format(task_id))

    prefix = 'qa{}_'.format(task_id)
    # Match against the bare file name, not the joined path: otherwise a
    # data_dir that itself contains 'train', 'test' or 'qaN_' would make every
    # file look like a match. Sorting makes the pick independent of the
    # arbitrary order returned by os.listdir.
    candidates = sorted(
        name for name in os.listdir(data_dir) if prefix in name)

    def find_split(split):
        # Return the path of the first candidate belonging to the given split.
        for name in candidates:
            if split in name:
                return os.path.join(data_dir, name)
        raise IOError('no {} file for task {} found in {!r}'.format(
            split, task_id, data_dir))

    train_file = find_split('train')
    test_file = find_split('test')
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data
def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    The sentence is split on runs of non-word characters; those runs are kept
    as tokens themselves, and every token is stripped of surrounding
    whitespace. Tokens that are empty after stripping are dropped.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    # The separator group must not be optional: a pattern that can match the
    # empty string makes re.split (Python 3.7+) split between every character
    # and yield None for the unmatched group, which breaks the strip() below.
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]
def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format

    Every non-blank line is lower-cased and must start with an integer line id
    followed by a space. An id of 1 starts a new story. A line containing tabs
    is a question of the form "question<TAB>answer<TAB>supporting ids";
    any other line is a regular sentence of the current story.

    If only_supporting is true, only the sentences that support the answer are
    kept.

    Args:
        lines: iterable of text lines.
        only_supporting: keep only the sentences named by the supporting ids
            instead of every sentence seen so far in the story.

    Returns:
        A list of (substory, q, a) tuples, one per question, where substory is
        a list of token lists, q is the token list of the question without its
        trailing "?", and a is a one-element list holding the answer string.
    '''
    data = []
    story = []
    for line in lines:
        # Blank lines carry no line id; skip them rather than fail to unpack.
        if not line.strip():
            continue
        line = line.lower()
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if '\t' in line:  # question
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            # answer is one vocab word even if it's actually multiple words
            a = [a]
            # remove question marks
            if q and q[-1] == "?":
                q = q[:-1]
            if only_supporting:
                # Only select the related substory; the ids are 1-based
                # positions within the current story.
                supporting = [int(i) for i in supporting.split()]
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories, leaving out the empty
                # placeholders that stand in for earlier questions.
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder so story positions keep matching the line ids.
            story.append('')
        else:  # regular sentence
            # remove periods
            sent = tokenize(line)
            if sent and sent[-1] == ".":
                sent = sent[:-1]
            story.append(sent)
    return data
def get_stories(f, only_supporting=False):
    '''Read the file named f and return the stories parsed from it.

    All lines of the file are read and handed to parse_stories together with
    only_supporting; whatever parse_stories returns is returned unchanged.
    '''
    with open(f) as handle:
        lines = handle.readlines()
    return parse_stories(lines, only_supporting=only_supporting)
def vectorize_data(data, word_idx, sentence_size, memory_size):
    """
    Convert (story, query, answer) triples into numeric arrays.

    Every sentence and every query is mapped through word_idx and right-padded
    with 0's when it is shorter than sentence_size.
    Only the most recent memory_size sentences of a story are kept; the last
    slot of each kept sentence is overwritten with its time 'word'.
    A story with fewer than memory_size sentences is padded with empty
    memories, i.e. lists of sentence_size 0's.
    Each answer becomes a vector of length len(word_idx) + 1 holding a 1 at
    the index of every answer word.

    Returns the stories, queries and answers as three numpy arrays.
    """
    def encode(tokens):
        # Look up each token, then right-pad with 0 (the nil word).
        ids = [word_idx[tok] for tok in tokens]
        return ids + [0] * max(0, sentence_size - len(tokens))

    stories, queries, answers = [], [], []
    for story, query, answer in data:
        memory = [encode(sentence) for sentence in story]

        # Keep only the newest sentences that fit in memory, oldest first.
        memory = memory[::-1][:memory_size]
        memory.reverse()

        # The final slot of every kept sentence holds its time 'word', an
        # index counting down from the first kept sentence to the last.
        kept = len(memory)
        time_base = len(word_idx) - memory_size + kept
        for offset, row in enumerate(memory):
            row[-1] = time_base - offset

        # Fill the remaining memory slots with empty memories.
        missing = max(0, memory_size - kept)
        memory.extend([0] * sentence_size for _ in range(missing))

        encoded_query = encode(query)

        # Index 0 is reserved for the nil word, hence the extra slot.
        target = np.zeros(len(word_idx) + 1)
        for word in answer:
            target[word_idx[word]] = 1

        stories.append(memory)
        queries.append(encoded_query)
        answers.append(target)
    return np.array(stories), np.array(queries), np.array(answers)