-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinput_fn.py
154 lines (140 loc) · 6.83 KB
/
input_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import tensorflow as tf
def parser(
        single_image_filename_tensor,  # Scalar string tensor: image filename
        single_mask_filename_tensor,   # Scalar string tensor: mask filename
        data_format,                   # If 'channels_first', stack image as CHW
        initial_image_shape=(1280, 1918),  # (height, width) of images on disk
        target_image_shape=(1280, 1920),   # (height, width) after zero-padding
        num_channels=3):               # Number of channels in the input image
    """Decode one (image, mask) filename pair into a training example.

    Returns a tuple ``(image, mask)``:
      * ``image``: tf.float32, values scaled to [0, 1]; shape
        ``[H, W, C]`` (channels_last) or ``[C, H, W]`` (channels_first).
      * ``mask``: tf.int32, shape ``target_image_shape``; values land in
        {0, 1} assuming the mask GIF is pure black/white (0 or 255) —
        TODO confirm against the dataset.

    Note: defaults are tuples (not lists) so no shared mutable default
    can leak between calls; padding offsets use integer division so no
    float rounding is involved.
    """
    # TODO: Consider examining performance if TF operations are replaced with
    # pillow (or opencv) operations, such as Image.open or Image.resize

    # Parse the original image into a uint8 tensor of shape
    # [*initial_image_shape, num_channels].
    image = tf.read_file(single_image_filename_tensor)
    image = tf.image.decode_jpeg(image, channels=num_channels)

    # Symmetric zero-padding amounts (integer division, exact).
    top_padding = (target_image_shape[0] - initial_image_shape[0]) // 2
    left_padding = (target_image_shape[1] - initial_image_shape[1]) // 2
    image = tf.image.pad_to_bounding_box(
        # Pads the image to the desired size by adding zeros.
        image,
        # Adds `offset_height` rows of zeros to the top
        offset_height=top_padding,
        # Adds `offset_width` columns of zeros to the left
        offset_width=left_padding,
        # Pad the image on the bottom until `target_height`
        target_height=target_image_shape[0],
        # Pad the image on the right until `target_width`
        target_width=target_image_shape[1])

    if data_format == 'channels_first':
        # Extract each channel in an explicit way to prevent
        # mangling the structure of the image, then restack as [C, H, W].
        c0 = image[:, :, 0]
        c1 = image[:, :, 1]
        c2 = image[:, :, 2]
        image = tf.stack([c0, c1, c2], axis=0)

    image = tf.cast(image, tf.float32) / 255  # Convert and scale to [0, 1]
    # TODO: Does scaling provide any benefit? Can we test this?

    # Parse the segmented image (mask). GIF files are assumed to contain
    # several 'frames', so after decoding mask.shape == [1, 1280, 1918, 3].
    mask = tf.read_file(single_mask_filename_tensor)
    mask = tf.image.decode_gif(mask)
    mask = tf.image.rgb_to_grayscale(mask)
    # After grayscaling, mask.shape == [1, 1280, 1918, 1]
    mask = tf.image.pad_to_bounding_box(  # Works with 4-D tensors
        mask,
        offset_height=top_padding,
        offset_width=left_padding,
        target_height=target_image_shape[0],
        target_width=target_image_shape[1])
    # Drop the frame and channel dimensions: shape becomes [1280, 1920].
    mask = tf.reshape(mask, target_image_shape)
    mask = tf.cast(mask / 255, tf.int32)  # Map {0, 255} -> {0, 1}
    return image, mask
def input_fn(
        image_filenames,          # List of image filenames
        mask_filenames,           # List of mask filenames
        training,                 # Determines whether or not images should be shuffled
        data_format,              # Passed to the parser so that RGB images can be reshaped
        num_repeats=1,            # Number of times to repeat the set of images
        batch_size=1,             # Number of images in each batch
        num_parallel_batches=8,   # How many processing cores are available?
        num_prefetch=None):       # Number of images to prefetch (None: let TF decide)
    """Build a tf.data pipeline that yields (image, mask) batches.

    Pairs each image filename with its mask filename, shuffles and
    repeats when ``training`` is true, decodes pairs via ``parser``,
    then batches and prefetches.
    """
    # Pair each image path with its corresponding mask path.
    dataset = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(image_filenames),
        tf.data.Dataset.from_tensor_slices(mask_filenames)))

    if training:
        # Buffer covers every repeated example, giving a full shuffle.
        shuffle_buffer = len(image_filenames) * num_repeats
        dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(
            shuffle_buffer, num_repeats))

    # Decode and batch in parallel in a single fused transformation.
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        lambda image, mask: parser(image, mask, data_format),
        batch_size=batch_size,
        num_parallel_batches=num_parallel_batches))

    return dataset.prefetch(num_prefetch)
## Old versions:
## These versions have been kept only for reference. They read data from TFRecord files.
# def parser(record, image_shape, data_format):
# '''
# Defines how to convert each record in a TFRecords dataset into
# its original form.
#
# For our network, each record contains an original image with shape
# (n, n, 3) and a segmented image (mask) of shape (n, n).
#
# Returns a tuple: (image, mask)
# '''
#
# keys_to_features = {
# "image": tf.FixedLenFeature([], tf.string),
# "mask": tf.FixedLenFeature([], tf.string)
# }
# parsed = tf.parse_single_example(record, keys_to_features)
# original = tf.decode_raw(parsed["image"], tf.uint8)
# original = tf.cast(original, tf.float32)
#
# original = tf.reshape(original, shape=[*image_shape, 3])
#
# if data_format == 'channels_first':
# # Channels first should improve performance when training on GPU
# # original = tf.reshape(original, shape=[3, *image_shape])
# red = original[:, :, 0]
# green = original[:, :, 1]
# blue = original[:, :, 2]
# original = tf.stack([red, green, blue], axis=0)
# # TODO: Experiment with reshaping to determine if this is valid.
# # These reshape commands may be causing problems for training. The images were
# # saved in HWC format; reshaping them may destroy the image.
#
# segmented = tf.decode_raw(parsed["mask"], tf.uint8)
# segmented = tf.cast(segmented, tf.int32)
# segmented = tf.reshape(segmented, shape=image_shape)
# return original, segmented
# def input_fn(filename, image_shape, data_format, train, num_repeat=1, batch_size=1):
# # Training Performance: A user's guide to converge faster (TF Dev Summit 2018)
# # https://www.youtube.com/watch?v=SxOsJPaxHME&t=1529s
# dataset = tf.data.TFRecordDataset(
# filenames=filename,
# compression_type="GZIP", # Full resolution images have been compressed.
# num_parallel_reads=8)
#
# if train: # TODO: Create and examine a profile trace
# dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(10000, num_repeat))
# # Does this actually load the data into memory? Should we not store our data
# # in a TFRecord file? We could just point TF to the data directory.
# else:
# dataset = dataset.repeat(num_repeat)
#
# dataset = dataset.apply(tf.contrib.data.map_and_batch(
# lambda record: parser(record, image_shape, data_format),
# batch_size=batch_size,
# num_parallel_batches=8))
#
# dataset = dataset.prefetch(4)
# # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/distribute/README.md
# return dataset # Expected by estimator's `train` method.