# data_prep.py
# Cat/dog image dataset preparation (forked from duchenpaul/cat_dog_classify).
# Builds a labeled file index from two dataset sources, loads and normalizes
# the images, one-hot encodes the labels, and dumps the shuffled dataset.
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import matplotlib.pyplot as plt
from keras.utils import np_utils
import numpy as np
import os
from tqdm import tqdm
import config
import toolkit_file
import image_process
from pprint import pprint
# Paths pulled from the project config: two image source directories and
# the output path for the serialized NumPy dataset.
dataset_dir_1 = config.DATASET_DIR_1
dataset_dir_2 = config.DATASET_DIR_2
data_dump = config.DATA_DMP  # target of np.save() in dump_dataset()
def generate_file_list_1(dataset_dir):
    """Build labeled image records from a flat directory of .jpg files.

    The label is inferred from the file-name prefix (case-insensitive):
    names starting with 'cat' get group_id 0, 'dog' get group_id 1.
    Files with any other prefix are skipped.

    Args:
        dataset_dir: Directory scanned (via toolkit_file.get_file_list)
            for .jpg images.

    Returns:
        list[dict]: Records with keys 'pic_id', 'group_id', 'image_path'.
    """
    imgFileList = [x for x in toolkit_file.get_file_list(
        dataset_dir) if x.endswith('.jpg')]
    dataset_dict_list = []
    for file in imgFileList:
        pic_id = toolkit_file.get_basename(
            file, withExtension=False)
        # ['Cat', 'Dog']
        if pic_id.lower().startswith('cat'):
            group_id = 0
        elif pic_id.lower().startswith('dog'):
            group_id = 1
        else:
            # BUG FIX: was `break`, which aborted the whole scan on the
            # first unrecognized file name, silently dropping every
            # remaining image. Skip just the offending file instead.
            continue
        dataset_dict_list.append(
            {'pic_id': pic_id, 'group_id': group_id, 'image_path': file})
    return dataset_dict_list
def generate_file_list_2(dataset_dir):
    """Build labeled image records from a directory-per-class layout.

    The label is inferred from the image's immediate parent directory
    name (case-insensitive): 'cat' -> group_id 0, 'dog' -> group_id 1.
    Images under any other directory are skipped.

    Args:
        dataset_dir: Root directory scanned (via
            toolkit_file.get_file_list) for .jpg images.

    Returns:
        list[dict]: Records with keys 'pic_id', 'group_id', 'image_path'.
    """
    imgFileList = [x for x in toolkit_file.get_file_list(
        dataset_dir) if x.endswith('.jpg')]
    dataset_dict_list = []
    for file in imgFileList:
        pic_id = toolkit_file.get_basename(
            file, withExtension=False)
        # Parent directory carries the class name, e.g. .../Cat/1.jpg
        parent_dir = file.split(os.sep)[-2].lower()
        if parent_dir == 'cat':
            group_id = 0
        elif parent_dir == 'dog':
            group_id = 1
        else:
            # BUG FIX: was `break`, which aborted the whole scan on the
            # first file outside a cat/dog directory. Skip it instead.
            continue
        # Single append path replaces the duplicated append in each branch.
        dataset_dict_list.append(
            {'pic_id': pic_id, 'group_id': group_id, 'image_path': file})
    return dataset_dict_list
def generate_file_list():
    """Merge records from both dataset sources and report class counts.

    Returns:
        list[dict]: Combined records from generate_file_list_1 and
            generate_file_list_2 (keys 'pic_id', 'group_id', 'image_path').

    Raises:
        ValueError: If any record carries a group_id other than 0 or 1.
    """
    dataset_dict_list = generate_file_list_1(dataset_dir_1)
    # Remove the second source feed to reduce system resource consumption.
    dataset_dict_list += generate_file_list_2(dataset_dir_2)
    cat_count = sum(1 for x in dataset_dict_list if x['group_id'] == 0)
    dog_count = sum(1 for x in dataset_dict_list if x['group_id'] == 1)
    unexpected = len(dataset_dict_list) - cat_count - dog_count
    if unexpected:
        # BUG FIX: was a bare `raise` with no active exception, which
        # produces a confusing "No active exception to re-raise" error.
        raise ValueError(
            'Found {} records with an unexpected group_id'.format(unexpected))
    print('Cat count: {}'.format(cat_count))
    print('Dog count: {}'.format(dog_count))
    return dataset_dict_list
def read_img(image_path_list):
    """Load and preprocess images into a float array scaled to [0, 1].

    Images that fail preprocessing are reported and skipped, so the
    returned array may contain fewer rows than there were input paths.

    Args:
        image_path_list: Iterable of image file paths.

    Returns:
        numpy.ndarray: Stacked preprocessed images divided by 255.
    """
    x_dataset = []
    for path in tqdm(image_path_list):
        try:
            x_dataset.append(image_process.image_process(path))
        except Exception as e:
            # Best-effort load: report and skip corrupt/unreadable images
            # instead of aborting the whole run. (Dead `else: pass` removed.)
            print('Error processing: {}'.format(path))
            print(e)
    # Scale raw 0-255 pixel values into [0, 1] for training stability.
    return np.array(x_dataset) / 255
def dump_dataset(x_dataset, y_dataset):
    """Pair images with labels, shuffle the pairs, and save to data_dump.

    Args:
        x_dataset: Sequence of image arrays (indexable, same length as
            y_dataset).
        y_dataset: Sequence of one-hot label vectors.

    Side effects:
        Writes a pickled object array of (image, label) tuples to the
        config.DATA_DMP path via np.save.
    """
    print('Convert to array...')
    # BUG FIX: np.array(list_of_pairs) on ragged (image, label) tuples
    # relies on implicit object-array creation, which modern NumPy
    # rejects. Build the 1-D object array explicitly so each element
    # stays an intact (image, label) pair.
    dataset = np.empty(len(y_dataset), dtype=object)
    for idx in tqdm(range(len(y_dataset))):
        dataset[idx] = (x_dataset[idx], y_dataset[idx])
    print('Shuffle array...')
    # Shuffling the object array permutes whole pairs, keeping each
    # image aligned with its label.
    np.random.shuffle(dataset)
    print('Dump array...')
    np.save(data_dump, dataset)
if __name__ == '__main__':
    # Pipeline: index the labeled files, load pixel data, one-hot encode
    # the labels, then persist the shuffled (image, label) dataset.
    print('Processing...')
    records = generate_file_list()

    print('Reading image...')
    image_paths = [record['image_path'] for record in records]
    x_dataset = read_img(image_paths)

    group_ids = [record['group_id'] for record in records]
    y_dataset = np_utils.to_categorical(group_ids)

    print('dumping numpy...')
    print('x_dataset: {}'.format(x_dataset.shape))
    print('y_dataset: {}'.format(y_dataset.shape))
    dump_dataset(x_dataset, y_dataset)