forked from tf-encrypted/tf-encrypted
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconv_convert.py
69 lines (52 loc) · 2.25 KB
/
conv_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Data processing helpers."""
#
# Based on:
# - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/how_tos/reading_data/convert_to_records.py
# - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
#
from functools import partial
import tensorflow as tf
def encode_image(value):
    """Encode an image array as a bytes tf.train.Feature for a TFRecord.

    The array's raw buffer is stored as a single BytesList entry; its shape
    and dtype are not recorded. ``decode_image`` in this module reads the
    bytes back with ``tf.decode_raw(..., tf.uint8)``.

    NOTE(review): because the decoder assumes uint8, ``value`` should be a
    uint8 array; other dtypes will not round-trip correctly. Confirm callers.

    Args:
      value: array object exposing ``tobytes()`` (e.g. a NumPy ndarray).

    Returns:
      A tf.train.Feature wrapping a one-element BytesList.
    """
    # ndarray.tostring() was deprecated in NumPy 1.19 and removed in NumPy
    # 2.0; tobytes() returns exactly the same bytes.
    bytes_list = tf.train.BytesList(value=[value.tobytes()])
    return tf.train.Feature(bytes_list=bytes_list)
def decode_image(value, flattened):
    """Turn a serialized image feature back into a uint8 tensor.

    Args:
      value: the string tensor parsed from the record's 'image' feature.
      flattened: if truthy, return the raw 1-D byte vector unchanged;
        otherwise reshape it to (1, 28, 28).

    Returns:
      A tf.uint8 tensor.
    """
    pixels = tf.decode_raw(value, tf.uint8)
    return pixels if flattened else tf.reshape(pixels, (1, 28, 28))
def encode_label(value):
    """Wrap a single label value as an int64 tf.train.Feature.

    Args:
      value: the label, stored as the sole element of an Int64List.

    Returns:
      A tf.train.Feature holding a one-element Int64List.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def decode_label(value):
    """Cast a parsed label tensor to tf.int32.

    Args:
      value: the tensor parsed from the record's 'label' feature (int64).

    Returns:
      The same label as a tf.int32 tensor.
    """
    label_int32 = tf.cast(value, tf.int32)
    return label_int32
def encode(image, label):
    """Build a tf.train.Example holding one (image, label) pair.

    Args:
      image: image array, serialized via encode_image.
      label: label value, serialized via encode_label.

    Returns:
      A tf.train.Example with 'image' and 'label' features.
    """
    return tf.train.Example(
        features=tf.train.Features(
            feature={
                'image': encode_image(image),
                'label': encode_label(label),
            }
        )
    )
def decode(serialized_example, flattened):
    """Parse one serialized tf.train.Example into (image, label) tensors.

    Args:
      serialized_example: scalar string tensor holding one Example.
      flattened: forwarded to decode_image; controls whether the image is
        reshaped to (1, 28, 28).

    Returns:
      Tuple (image, label) as produced by decode_image and decode_label.
    """
    spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=spec)
    return (decode_image(parsed['image'], flattened),
            decode_label(parsed['label']))
def normalize(image, label, *, mean=0.1307, std=0.3081):
    """Standardize an image as ``(x - mean) / std`` and pass the label through.

    The image is cast to float32 and divided by 255 before standardization.
    The defaults are the MNIST statistics this module originally hard-coded.
    They are keyword-only, so ``Dataset.map(normalize)`` still receives just
    (image, label), and other datasets can supply their own values.

    Args:
      image: image tensor; any numeric dtype accepted by tf.cast.
      label: returned unchanged.
      mean: mean subtracted after scaling by 1/255.
      std: standard deviation divided out after centering.

    Returns:
      Tuple (standardized float32 image, label).
    """
    x = tf.cast(image, tf.float32) / 255.
    image = (x - mean) / std
    return image, label
def get_data_from_tfrecord(filename, batch_size: int, flattened=False):
    """Build a one-shot iterator over decoded, normalized, batched records.

    Pipeline: read ``filename`` as a TFRecord file, decode each record,
    apply ``normalize``, repeat indefinitely, then batch.

    Args:
      filename: path of the TFRecord file to read.
      batch_size: number of examples per batch.
      flattened: forwarded to ``decode``; if False, images are reshaped to
        (1, 28, 28).

    Returns:
      The iterator returned by the dataset's ``make_one_shot_iterator()``
      (TF1 API).
    """
    parse_record = partial(decode, flattened=flattened)
    records = tf.data.TFRecordDataset([filename])
    examples = records.map(parse_record).map(normalize)
    batches = examples.repeat().batch(batch_size)
    return batches.make_one_shot_iterator()