serialize.py
import os
import pickle
import multiprocessing

import tensorflow as tf

import constants
import generate


def get_nth_serialized_train_filepath(n):
    """Path of the n-th serialized (TFRecord) training shard."""
    return os.path.join(os.path.dirname(__file__), 'spectrogram_data', f'train{n}.tfrecord')


def to_feature_list(array_2d):
    """Wrap a 2-D array as a FeatureList with one float feature per row."""
    float_lists = [tf.train.Feature(float_list=tf.train.FloatList(value=row)) for row in array_2d]
    return tf.train.FeatureList(feature=float_lists)


def serialize(array):
    """Serialize a spectrogram array into a SequenceExample byte string."""
    feature_list = {'spec': to_feature_list(array)}
    example = tf.train.SequenceExample(feature_lists=tf.train.FeatureLists(feature_list=feature_list))
    return example.SerializeToString()


def write_tfrecord(n):
    """Convert the n-th pickled training file into a TFRecord file."""
    raw_filepath = generate.get_nth_train_filepath(n)
    serialized_filepath = get_nth_serialized_train_filepath(n)
    with open(raw_filepath, 'rb') as raw_file:
        with tf.io.TFRecordWriter(serialized_filepath) as writer:
            try:
                # The raw file holds back-to-back pickled arrays; read until EOF.
                while True:
                    data = pickle.load(raw_file)
                    writer.write(serialize(data))
            except EOFError:
                pass


def write_all_tfrecords():
    """Serialize every training shard, one worker process per shard."""
    workers = []
    for i in range(generate.NUM_THREADS):
        workers.append(multiprocessing.Process(target=write_tfrecord, args=(i,)))
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()


# Each SequenceExample stores the spectrogram as a sequence of NUM_BINS-wide frames.
feature_description = {
    'spec': tf.io.FixedLenSequenceFeature(constants.NUM_BINS, tf.float32)
}


def decode(example):
    """Parse a serialized SequenceExample back into its 'spec' tensor."""
    return tf.io.parse_single_sequence_example(example, sequence_features=feature_description)[1].get('spec')
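

# Usage sketch (not part of the original module): one possible way to read the
# shards written above back into a tf.data pipeline via decode(). The helper
# name load_spectrogram_dataset is hypothetical, and the shard count is assumed
# to equal generate.NUM_THREADS, matching write_all_tfrecords() above.
def load_spectrogram_dataset():
    filepaths = [get_nth_serialized_train_filepath(i) for i in range(generate.NUM_THREADS)]
    dataset = tf.data.TFRecordDataset(filepaths)
    # Each element parses into a [num_frames, NUM_BINS] float32 spectrogram tensor.
    return dataset.map(decode)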


if __name__ == '__main__':
    write_all_tfrecords()