-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist_dataset_generator.py
58 lines (44 loc) · 2.47 KB
/
mnist_dataset_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import sys
import tensorflow as tf
from model.filesystem import persist_dataset
# Human-readable class names, indexed by integer label. MNIST labels are the
# handwritten digits 0-9; the earlier values here were the CIFAR-10 category
# names ('airplane', 'automobile', ...), which do not describe this dataset.
MNIST_class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
# Shape of one example after the channel axis is appended in
# get_MNIST_dataset: (height, width, channels).
MNIST_IMAGE_SIZE = (28, 28, 1)
# Number of distinct labels (one per entry of MNIST_class_names).
MNIST_LABELS = 10
def generate_MNIST_dataset(dataset_folder):
    """Build the MNIST splits and persist each one under ``dataset_folder``.

    The three datasets returned by ``get_MNIST_dataset`` are handed to
    ``persist_dataset`` with the targets ``<dataset_folder>/train``,
    ``<dataset_folder>/validation`` and ``<dataset_folder>/test``.

    Args:
        dataset_folder: Base path used as the prefix of every target path.
    """
    train_split, validation_split, test_split = get_MNIST_dataset()

    print("Persisting datasets")
    targets = (
        ('train', train_split),
        ('validation', validation_split),
        ('test', test_split),
    )
    for subfolder, split in targets:
        persist_dataset(split, f'{dataset_folder}/{subfolder}')
def get_MNIST_dataset(train_per=0.8, number_training_examples=-1, number_test_examples=-1):
    """Load MNIST from Keras and return train, validation and test datasets.

    Pixel values are scaled exactly once from the raw 0-255 range to [0, 1],
    a trailing channel axis is added, and the images are cast to float32.

    Args:
        train_per: Fraction of the selected training examples that goes to
            the training split; the remainder becomes the validation split.
        number_training_examples: Number of leading examples of the Keras
            training set to use. ``-1`` means all of them.
        number_test_examples: Number of leading examples of the Keras test
            set to use. ``-1`` means all of them.

    Returns:
        A ``(train_dataset, val_dataset, test_dataset)`` tuple of
        ``tf.data.Dataset`` objects, each built from (images, labels) slices.
    """
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

    # Normalize once. A second division by 255 used to follow further down,
    # which squeezed every pixel into [0, 1/255].
    train_images, test_images = train_images / 255.0, test_images / 255.0

    # Add a channels dimension
    train_images = train_images[..., tf.newaxis].astype("float32")
    test_images = test_images[..., tf.newaxis].astype("float32")

    # -1 is the "use everything" sentinel; resolve it to the real size so the
    # slices below do not drop the last example.
    if number_training_examples == -1:
        number_training_examples = train_images.shape[0]
    if number_test_examples == -1:
        number_test_examples = test_images.shape[0]

    train_dataset, val_dataset = _generate_train_and_val_datasets(
        train_images[:number_training_examples],
        train_labels[:number_training_examples],
        train_per=train_per,
    )
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_images[:number_test_examples], test_labels[:number_test_examples])
    )

    return train_dataset, val_dataset, test_dataset
def _generate_train_and_val_datasets(train_images, train_labels, train_per=0.8):
    """Split (images, labels) into a leading training part and a trailing validation part.

    Args:
        train_images: Array of images; its first dimension is the example count.
        train_labels: Labels matching ``train_images``.
        train_per: Fraction of the examples assigned to the training dataset.

    Returns:
        A ``(train_dataset, val_dataset)`` tuple: the first
        ``round(train_per * count)`` elements and everything after them.
    """
    split_index = round(train_per * train_images.shape[0])
    combined = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    # No shuffling happens here: take/skip cut the dataset at a fixed position.
    return combined.take(split_index), combined.skip(split_index)
if __name__ == '__main__':
    # Script entry point: the first command-line argument is the folder the
    # datasets are written to; without it only a usage hint is printed.
    if len(sys.argv) >= 2:
        generate_MNIST_dataset(os.path.join(sys.argv[1]))
    else:
        print('You have to introduce the path to the folder containing the dataset generated by the script'
              ' mnist_dataset_generator.py')