forked from calmiLovesAI/TensorFlow2.0_ResNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
prepare_data.py
57 lines (45 loc) · 2.3 KB
/
prepare_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import tensorflow as tf
import config
import pathlib
from config import image_height, image_width, channels
def load_and_preprocess_image(img_path):
    """Load one JPEG from disk and convert it to a normalized float tensor.

    Args:
        img_path: path to the image file (str or string tensor).

    Returns:
        A float32 tensor of shape [image_height, image_width, channels]
        with pixel values scaled into [0, 1].
    """
    # Read the raw bytes and decode them as JPEG with the configured
    # number of channels (both come from config at module import time).
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=channels)
    # Resize to the spatial dimensions the model expects.
    resized = tf.image.resize(decoded, [image_height, image_width])
    # Cast, then scale to the [0, 1] range.
    return tf.cast(resized, tf.float32) / 255.0
def get_images_and_labels(data_root_dir):
    """Collect image file paths and their integer class labels.

    Expects the on-disk layout ``data_root_dir/<class_name>/<image_file>``.

    Args:
        data_root_dir: directory whose immediate subdirectories are the
            class folders.

    Returns:
        A pair ``(all_image_path, all_image_label)``: parallel lists of
        path strings and integer label indices, where classes are indexed
        in sorted directory-name order.
    """
    data_root = pathlib.Path(data_root_dir)
    # Every entry one level below a class directory counts as an image.
    all_image_path = [str(p) for p in data_root.glob('*/*')]
    # Map each class-directory name to a stable index (alphabetical order).
    label_to_index = {
        name: index
        for index, name in enumerate(sorted(item.name for item in data_root.glob('*/')))
    }
    # An image's label is derived from the name of its parent directory.
    all_image_label = [
        label_to_index[pathlib.Path(p).parent.name] for p in all_image_path
    ]
    return all_image_path, all_image_label
def get_dataset(dataset_root_dir):
    """Build an (image, label) tf.data pipeline from a class-folder tree.

    Args:
        dataset_root_dir: root directory in ``<class>/<file>`` layout.

    Returns:
        A pair ``(dataset, image_count)``: a zipped ``tf.data.Dataset`` of
        preprocessed image tensors and integer labels, and the number of
        images found under the root.
    """
    paths, labels = get_images_and_labels(data_root_dir=dataset_root_dir)
    # Decode/resize/normalize lazily: the map runs as the pipeline is consumed.
    image_dataset = tf.data.Dataset.from_tensor_slices(paths).map(load_and_preprocess_image)
    label_dataset = tf.data.Dataset.from_tensor_slices(labels)
    # Pair each preprocessed image with its label.
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))
    return dataset, len(paths)
def generate_datasets():
    """Create the batched train/validation/test datasets from config dirs.

    Reads ``config.train_dir``, ``config.valid_dir`` and ``config.test_dir``.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset,
        train_count, valid_count, test_count)`` — three batched
        ``tf.data.Dataset`` objects and their respective image counts.
    """
    batch = config.BATCH_SIZE
    train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir)
    valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)
    test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)
    # Shuffle only the training split (full-dataset buffer); every split
    # is then read in batches of the configured size.
    train_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=batch)
    valid_dataset = valid_dataset.batch(batch_size=batch)
    test_dataset = test_dataset.batch(batch_size=batch)
    return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count