dl_utils.py
import cv2
import matplotlib.pyplot as plt
import numpy as np

from data_handler import DataHandler
def preprocess_images(
image_data, res, show_resized_image=False, flatten=True, normalize=True
):
"""
    Accepts a 3D array (single image) or 4D array (multiple images) of shape
    (n_imgs, vertical_pixels, horizontal_pixels, rgb_data).
    Returns the image array with the following optional processing applied:
    - resizing to the specified res if it does not already match
    - normalizing pixel values from 0-255 to 0-1
    - flattening each image to a vector
    If flattening: returns an image array of shape (n_imgs, n_subpixels),
    else: returns an image array of shape (n_imgs, vertical_pixels, horizontal_pixels, rgb_data).
    Note that the printouts are only output for the first image.

    Parameters
    ----------
    image_data: 3D or 4D array of floats
        a single rgb image, or an array of rgb images, of shape
        (n_imgs, vertical_pixels, horizontal_pixels, 3)
    res: list of 2 ints
        the desired resolution of the output images, as (vertical_pixels, horizontal_pixels)
    show_resized_image: boolean, Optional (Default: False)
        plots the image before and after rescaling
    flatten: boolean, Optional (Default: True)
        flattens each image to a vector;
        if an array of images is passed in, the output has shape (n_imgs, n_subpixels)
    normalize: boolean, Optional (Default: True)
        normalize data from 0-255 to 0-1
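
    Example
    -------
    A hedged usage sketch (not from this repo): assumes `raw_frames` is an
    (n_imgs, 64, 64, 3) uint8 array of rgb images provided by the caller.

        processed = preprocess_images(raw_frames, res=[64, 64])
        # processed.shape -> (n_imgs, 64 * 64 * 3)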
"""
scaled_image_data = []
# single image, append 1 dimension so we can loop through the same way
image_data = np.asarray(image_data)
shape = image_data.shape
    if image_data.ndim == 3:
        image_data = image_data.reshape((1, shape[0], shape[1], shape[2]))
        # keep shape in sync with the reshaped array for the resolution check below
        shape = image_data.shape
# expect rgb image data
assert image_data.shape[3] == 3
for count, data in enumerate(image_data):
rgb = np.asarray(data)
# normalize
if normalize:
if np.mean(data) > 1:
if count == 0: # only print for the first image
print("Image passed in 0-255, normalizing to 0-1")
rgb = rgb / 255
else:
if count == 0: # only print for the first image
print("Image passed in 0-1, skipping normalizing")
# resize image resolution
if shape[1] != res[0] or shape[2] != res[1]:
if count == 0: # only print for the first image
print("Resolution does not match desired value, resizing...")
print("Desired Res: ", res)
print("Input Res: ", [shape[1], shape[2]])
rgb = cv2.resize(rgb, dsize=(res[1], res[0]), interpolation=cv2.INTER_CUBIC)
else:
if count == 0: # only print for the first image
print("Resolution already at desired value, skipping resizing...")
# visualize scaling for debugging
if show_resized_image:
            plt.figure()
a = plt.subplot(121)
a.set_title("Original")
a.imshow(data, origin="lower")
b = plt.subplot(122)
b.set_title("Scaled")
b.imshow(rgb, origin="lower")
plt.show()
# flatten to 1D
if flatten:
rgb = np.ravel(rgb)
scaled_image_data.append(np.copy(rgb))
scaled_image_data = np.asarray(scaled_image_data)
return scaled_image_data
def repeat_data(data, batch_data=False, n_steps=1):
"""
    Accepts flattened data of shape (number of images / targets, flattened data length).
    Repeats each entry n_steps times and optionally batches the result (see batch_data).
Parameters
----------
data: array of floats
inputs data of shape (number of imgs / targets, flattened data dimensionality)
batch_data: boolean, Optional (Default: False)
True: output shape (number imgs / targets, n_steps, flattened dimensionality)
False: output shape (1, number imgs / targets * n_steps, flattened dimensionality)
n_steps: int, Optional (Default: 1)
number of times to repeat each input target / image
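
    Example
    -------
    A hedged usage sketch with hypothetical shapes (not from this repo):

        data = np.ones((10, 12288))
        batched = repeat_data(data, batch_data=True, n_steps=30)
        # batched.shape -> (10, 30, 12288)
        stacked = repeat_data(data, batch_data=False, n_steps=30)
        # stacked.shape -> (1, 300, 12288)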
"""
print("Data pre_tile: ", data.shape)
if batch_data:
# batch our images for training
data = np.tile(data[:, None, :], (1, n_steps, 1))
else:
# run like nengo sim.run without batching
data = np.repeat(data, n_steps, 0)
data = data[np.newaxis, :]
print("Data post_tile: ", data.shape)
return data
def load_data(
db_name, label="training_0000", n_imgs=None, thresh=1e5, step_size=1, db_dir=None
):
"""
    Loads rgb images and targets from an hdf5 database and returns them as np arrays.
    Expects data to be saved in the following group structure:
        Training data: training_0000/data/0000, using %04d to increment the data name
        Validation data: validation_0000/data/0000, using %04d to increment the data name
    Each entry contains the rgb image saved under the 'rgb' key
    and the target saved under the 'target' key.
Parameters
----------
db_name: string
name of database to load from
label: string, Optional (Default: 'training_0000')
location in database to load from
    n_imgs: int, Optional (Default: None)
        how many images to load; if None, defaults to the highest key in the dataset
    thresh: float, Optional (Default: 1e5)
        targets with a norm greater than or equal to this value are skipped
    step_size: int, Optional (Default: 1)
        step size to use when stepping through the dataset keys
    db_dir: string, Optional (Default: None)
        directory the database is stored in, passed to DataHandler
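
    Example
    -------
    A hedged usage sketch; the database name below is a placeholder, not a file
    shipped with this repo:

        images, targets = load_data(
            db_name="example_db", label="training_0000", n_imgs=1000, thresh=3.5
        )
        # images: (n_kept, vertical_pixels, horizontal_pixels, 3)
        # targets: (n_kept, target_dimensionality)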
"""
# TODO: specify the data format expected in the comment above
dat = DataHandler(db_dir=db_dir, db_name=db_name)
# load training images
images = []
targets = []
skip_list = ["datestamp", "timestamp"]
keys = np.array(
[int(val) for val in dat.get_keys("%s" % label) if val not in skip_list]
)
n_imgs = max(keys) if n_imgs is None else n_imgs
print("Total number of images in dataset: ", max(keys))
for nn in range(0, n_imgs, step_size):
data = dat.load(
parameters=["rgb", "target"], save_location="%s/%04d" % (label, nn)
)
if np.linalg.norm(data["target"]) < thresh:
images.append(data["rgb"])
targets.append(data["target"])
images = np.asarray(images)
targets = np.asarray(targets)
print("Total number of images within threshold: ", images.shape[0])
return images, targets
def plot_data(db_name, label="training_0000", n_imgs=None, db_dir=None):
"""
    Loads rgb images and targets from an hdf5 database, plots the images, and prints
    the targets.
    Expects data to be saved in the following group structure:
        Training data: training_0000/data/0000, using %04d to increment the data name
        Validation data: validation_0000/data/0000, using %04d to increment the data name
    Each entry contains the rgb image saved under the 'rgb' key
    and the target saved under the 'target' key.
Parameters
----------
db_name: string
name of database to load from
label: string, Optional (Default: 'training_0000')
location in database to load from
    n_imgs: int, Optional (Default: None)
        how many images to plot; if None, defaults to the highest key in the dataset
    db_dir: string, Optional (Default: None)
        directory the database is stored in, passed to DataHandler
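
    Example
    -------
    A hedged usage sketch; the database name below is a placeholder:

        plot_data(db_name="example_db", label="training_0000", n_imgs=5)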
"""
# TODO: specify the data format expected in the comment above
dat = DataHandler(db_name=db_name, db_dir=db_dir)
keys = np.array([int(val) for val in dat.get_keys("%s/data" % label)])
print("Total number of images in dataset: ", max(keys))
    # default to the full dataset, mirroring load_data's handling of n_imgs
    n_imgs = max(keys) if n_imgs is None else n_imgs
    for nn in range(n_imgs):
data = dat.load(
parameters=["rgb", "target"], save_location="%s/data/%04d" % (label, nn)
)
print("Target: ", data["target"])
plt.figure()
a = plt.subplot(1, 1, 1)
a.imshow(data["rgb"] / 255)
plt.show()
def plot_prediction_error(
predictions,
target_vals,
save_folder=".",
save_name="prediction_results",
show_plot=False,
):
"""
Accepts predictions and targets, plots the x and y error, along with the target location
Parameters
----------
predictions: array of floats
nengo sim data[output] array
    target_vals: array of floats
        flattened target data that was passed in during inference
    save_folder: string, Optional (Default: '.')
        location to save figures to
    save_name: string, Optional (Default: 'prediction_results')
        name to save the plot under
    show_plot: boolean, Optional (Default: False)
        if True, displays the figure in addition to saving it
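
    Example
    -------
    A self-contained sketch using random data in place of real nengo sim output:

        predictions = np.random.uniform(-1, 1, size=(100, 2))
        target_vals = np.random.uniform(-1, 1, size=(100, 2))
        plot_prediction_error(predictions, target_vals, save_folder=".", show_plot=False)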
"""
print("targets shape: ", np.asarray(target_vals).shape)
print("prediction shape: ", np.asarray(predictions).shape)
if predictions.ndim > 2:
shape = np.asarray(predictions).shape
predictions = np.asarray(predictions).reshape(shape[0] * shape[1], shape[2])
print("pred reshape: ", predictions.shape)
if target_vals.ndim > 2:
shape = np.asarray(target_vals).shape
target_vals = np.asarray(target_vals).reshape(shape[0] * shape[1], shape[2])
print("targets reshape: ", target_vals.shape)
# calculate our error to target val
x_err = np.linalg.norm(target_vals[:, 0] - predictions[:, 0])
y_err = np.linalg.norm(target_vals[:, 1] - predictions[:, 1])
    plt.figure()
x = np.arange(predictions.shape[0])
# plot our X predictions
plt.subplot(311)
plt.title("X: %.3f" % x_err)
plt.plot(x, predictions[:, 0], label="predictions", color="k", linestyle="--")
plt.plot(x, target_vals[:, 0], label="target", color="r")
plt.legend()
# plot our Y predictions
plt.subplot(312)
plt.title("Y: %.3f" % y_err)
plt.plot(x, predictions[:, 1], label="predictions", color="k", linestyle="--")
plt.plot(x, target_vals[:, 1], label="target", color="r")
plt.legend()
# plot our targets in the xy plane to get an idea of their coverage range
plt.subplot(313)
plt.title("Target XY")
plt.xlim([-3, 3])
plt.ylim([-3, 3])
plt.scatter(0, 0, color="r", label="rover", s=2)
plt.scatter(target_vals[:, 0], target_vals[:, 1], label="target", s=1)
plt.gca().set_aspect("equal")
plt.tight_layout()
plt.savefig("%s/%s.png" % (save_folder, save_name))
print("Saving prediction results to %s/%s.png" % (save_folder, save_name))
if show_plot:
plt.show()
plt.close()
def consolidate_data(db_name, label_list, thresh=3.5, step_size=1, db_dir=None):
"""
    Loads rgb images and targets from multiple locations in an hdf5 database and
    consolidates them into single np arrays of images and targets, which are returned.
Parameters
----------
db_name: string
name of database to load from
    label_list: list
        list of locations in the database to load from
    thresh: float, Optional (Default: 3.5)
        targets with a norm greater than or equal to this value are skipped
    step_size: int, Optional (Default: 1)
        step size to use when stepping through the dataset keys
    db_dir: string, Optional (Default: None)
        directory the database is stored in, passed to DataHandler
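
    Example
    -------
    A hedged usage sketch; the database name and labels below are placeholders:

        images, targets = consolidate_data(
            db_name="example_db",
            label_list=["training_0000", "training_0001"],
            thresh=3.5,
        )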
"""
all_images = []
all_targets = []
    for label in label_list:
print("db_name: ", db_name)
print("label: ", label)
images, targets = load_data(
db_name=db_name,
db_dir=db_dir,
label=label,
thresh=thresh,
step_size=step_size,
)
all_images.append(images)
all_targets.append(targets)
all_images = np.vstack(all_images)
all_targets = np.vstack(all_targets)
print("Total images shape: ", all_images.shape)
print("Total targets shape: ", all_targets.shape)
return all_images, all_targets