SS_dataset.py
import numpy
import os, gc
import cPickle
import copy
import logging
import threading
import Queue
import collections
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.FileHandler('./log/' + __name__))
class SSFetcher(threading.Thread):
    """Background thread that shuffles the data and feeds batches of
    (data_x, data_y) pairs into the parent SSIterator's queue."""

    def __init__(self, parent):
        threading.Thread.__init__(self)
        self.parent = parent
        self.rng = numpy.random.RandomState(self.parent.seed)
        self.indexes = numpy.arange(parent.data_len)

    def run(self):
        diter = self.parent
        self.rng.shuffle(self.indexes)
        offset = 0
        while not diter.exit_flag:
            last_batch = False
            data_x_y = []
            while len(data_x_y) < diter.batch_size:
                if offset == diter.data_len:
                    if not diter.use_infinite_loop:
                        last_batch = True
                        break
                    else:
                        # Infinite loop here, we reshuffle the indexes
                        # and reset the offset
                        self.rng.shuffle(self.indexes)
                        offset = 0
                index = self.indexes[offset]
                s = diter.data[index]  # s is a pair (data_x, data_y)
                s = (map(int, s[0].split()),
                     map(int, s[1].split()))
                offset += 1
                # Append only if both sequences fit within max_len
                # (s itself is a pair, so len(s) is always 2; the check
                # has to look at the token sequences instead).
                if diter.max_len == -1 or max(len(s[0]), len(s[1])) <= diter.max_len:
                    data_x_y.append(s)
            if len(data_x_y):
                diter.queue.put(data_x_y)
            if last_batch:
                diter.queue.put(None)
                return
class SSIterator(object):
    """Iterator over (data_x, data_y) pairs read from a text file in which
    each example spans two consecutive lines of space-separated integer ids."""

    def __init__(self,
                 data_file,
                 batch_size,
                 seed=1234,
                 max_len=-1,
                 use_infinite_loop=True,
                 dtype="int32"):
        self.data_file = data_file
        self.batch_size = batch_size
        args = locals()
        args.pop("self")
        self.__dict__.update(args)
        self.load_files()
        self.exit_flag = False

    def load_files(self):
        self.data = []
        load_cnt = 0
        print("SSFetch loading data ...")
        with open(self.data_file) as f_in:
            while True:
                l1 = f_in.readline()
                if l1 == "":
                    break
                l2 = f_in.readline()
                self.data.append((l1, l2))
                load_cnt = load_cnt + 1
                if load_cnt % 100 == 0:
                    print(load_cnt, "loaded ...")
        self.data_len = len(self.data)
        logger.debug('Data len is %d' % self.data_len)

    def start(self):
        self.exit_flag = False
        self.queue = Queue.Queue(maxsize=1000)
        self.gather = SSFetcher(self)
        self.gather.daemon = True
        self.gather.start()

    def __del__(self):
        if hasattr(self, 'gather'):
            # The fetcher thread polls self.exit_flag on the parent;
            # setting a separate exitFlag attribute on the thread object
            # would never be seen by run().
            self.exit_flag = True
            self.gather.join()

    def __iter__(self):
        return self

    def next(self):
        if self.exit_flag:
            return None
        batch = self.queue.get()
        if not batch:
            self.exit_flag = True
        return batch
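
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes a
# hypothetical file "train_pairs.txt" in which each example occupies two
# consecutive lines of space-separated integer token ids (data_x first,
# data_y second), matching what load_files() expects; the path and batch
# size are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    train_iter = SSIterator('train_pairs.txt', batch_size=32,
                            use_infinite_loop=False)
    train_iter.start()
    for batch in train_iter:
        if batch is None:
            # The fetcher thread signals exhaustion by enqueueing None.
            break
        # Each batch is a list of (data_x, data_y) pairs of int lists.
        print('got a batch of %d pairs' % len(batch))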