-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain_test_split.py
53 lines (43 loc) · 2.31 KB
/
train_test_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import shutil
import numpy as np
def split_dataset_into_test_and_train_sets(all_data_dir, training_data_dir, testing_data_dir, testing_data_pct):
    """Copy every file under ``all_data_dir`` into a training or a testing tree.

    Both output directories are first emptied and recreated, unless their
    path is considered too shallow to delete safely (see
    ``_recreate_directory``).  ``all_data_dir`` is then walked recursively
    and each file is copied to ``<output_dir>/<category>/<file name>``,
    where ``<category>`` is the basename of the directory holding the file.

    Consequences of using the basename as the category:
      * nested directories that share a basename are merged into a single
        category, and files with the same name overwrite one another;
      * the top directory is a category too; if ``all_data_dir`` ends with
        ``'/'`` its basename is empty and its files land directly in the
        output directory.

    Args:
        all_data_dir: Root directory of the source files.
        training_data_dir: Output directory for the training copies.
        testing_data_dir: Output directory for the testing copies.
        testing_data_pct: Threshold compared against a uniform draw from
            ``[0, 1)`` for each file; a file goes to the testing tree when
            the draw is strictly below it (``0.0`` sends everything to
            training, ``1.0`` everything to testing).

    Returns:
        A ``(num_training_files, num_testing_files)`` tuple with the number
        of files copied to each tree.

    Raises:
        OSError: Propagated unchanged from ``shutil.rmtree``,
            ``os.makedirs`` or ``shutil.copy``.
    """
    _recreate_directory(testing_data_dir, "testing")
    _recreate_directory(training_data_dir, "training")

    num_training_files = 0
    num_testing_files = 0

    for subdir, dirs, files in os.walk(all_data_dir):
        # Sorting in place fixes the traversal order, so a run seeded with
        # np.random.seed() gives the same split on every filesystem.
        dirs.sort()

        category_name = os.path.basename(subdir)
        training_data_category_dir = training_data_dir + '/' + category_name
        testing_data_category_dir = testing_data_dir + '/' + category_name

        # makedirs also creates the output root when the clean-up step
        # refused to touch it and it does not exist yet.
        os.makedirs(training_data_category_dir, exist_ok=True)
        os.makedirs(testing_data_category_dir, exist_ok=True)

        for file in sorted(files):
            input_file = os.path.join(subdir, file)
            # Scalar draw: consumes the same random stream as rand(1) but
            # yields a plain float instead of a one-element array.
            if np.random.rand() < testing_data_pct:
                shutil.copy(input_file, os.path.join(testing_data_category_dir, file))
                num_testing_files += 1
            else:
                shutil.copy(input_file, os.path.join(training_data_category_dir, file))
                num_training_files += 1

    print("Processed " + str(num_training_files) + " training files.")
    print("Processed " + str(num_testing_files) + " testing files.")
    return num_training_files, num_testing_files


def _recreate_directory(directory, label):
    """Empty ``directory`` by deleting and recreating it, if the path is deep enough.

    Paths containing at most one ``'/'`` are never deleted; a message is
    printed instead and the directory is left as it is.  ``label`` names the
    directory's role ("testing" or "training") in that message.
    """
    if directory.count('/') > 1:
        # On a first run the directory does not exist yet; rmtree would
        # raise FileNotFoundError, so only remove what is actually there.
        if os.path.exists(directory):
            shutil.rmtree(directory, ignore_errors=False)
        os.makedirs(directory)
        print("Successfully cleaned directory " + directory)
    else:
        print("Refusing to delete " + label + " data directory " + directory + " as we prevent you from doing stupid things!")
# Script configuration: every path below lives under this root folder.
datapath = '/notebooks/'

# Threshold handed to the split; files whose random draw falls below it
# are copied to the testing tree.
testing_data_pct = 0.2

# Source tree, followed by the two output trees the split writes into.
all_data_dir = os.path.join(datapath, "imgs_1m/")
testing_data_dir = os.path.join(datapath, "testing")
training_data_dir = os.path.join(datapath, "training")

# Run the split as soon as this module is executed.
split_dataset_into_test_and_train_sets(
    all_data_dir,
    training_data_dir,
    testing_data_dir,
    testing_data_pct,
)