-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_imagenet.py
109 lines (81 loc) · 2.98 KB
/
train_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import skimage.io # bug. need to import this before tensorflow
import skimage.transform # bug. need to import this before tensorflow
from resnet_train import train
import tensorflow as tf
import time
import os
import sys
import re
import numpy as np
from resnet import inference
from synset import *
from image_processing import image_preprocessing
# Command-line flags. Besides data_dir (registered here), this script also
# reads FLAGS.batch_size and FLAGS.input_size, which are not defined in this
# file -- presumably registered by an imported module such as resnet_train
# (TODO confirm).
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
    'data_dir',
    '/search/odin/gongzhenting/work/image/data/ILSVRC/ILSVRC2012_img_train/',
    'imagenet dir',
)
def file_list(data_dir):
    """Return the file paths listed in the ``train.txt`` inside ``data_dir``.

    Each relevant line of ``train.txt`` is expected to begin with a path
    relative to ``data_dir``. Anything after the first run of whitespace
    (for example a label column) is ignored. Blank lines and lines starting
    with ``'.'`` are skipped.

    Args:
        data_dir: directory that contains ``train.txt``. A trailing path
            separator is optional.

    Returns:
        A list of paths, in file order, each being ``data_dir`` joined with
        the first whitespace-separated field of a kept line.
    """
    # os.path.join rather than string concatenation, so data_dir works with
    # or without a trailing separator.
    list_path = os.path.join(data_dir, "train.txt")
    filenames = []
    with open(list_path, 'r') as f:
        for line in f:
            line = line.strip()
            # Blank lines would otherwise produce data_dir itself as an entry.
            if not line or line.startswith('.'):
                continue
            # split() with no argument also handles tab-separated columns.
            filenames.append(os.path.join(data_dir, line.split()[0]))
    return filenames
def load_data(data_dir):
    """Build the list of labelled training examples found under ``data_dir``.

    Paths come from ``file_list(data_dir)``. Only files whose extension is
    exactly ``.JPEG`` are kept. The label is the first ``n<digits>`` token
    in the part of the path below ``data_dir``, falling back to the full
    path if that part has none. The token is resolved through
    ``synset_map`` / ``synset``, which come from the ``synset`` module.

    Args:
        data_dir: training directory containing ``train.txt``.

    Returns:
        A list of dicts with keys ``filename``, ``label_name``,
        ``label_index`` and ``desc``.

    Raises:
        ValueError: if a ``.JPEG`` path contains no ``n<digits>`` token.
        KeyError: if the token is not a key of ``synset_map``.
    """
    label_re = re.compile(r'(n\d+)')

    print("listing files in %s" % data_dir)
    start_time = time.time()
    files = file_list(data_dir)
    duration = time.time() - start_time
    print("took %f sec" % duration)

    data = []
    for img_fn in files:
        if os.path.splitext(img_fn)[1] != '.JPEG':
            continue
        # Search below data_dir first, so an 'n<digits>' inside data_dir
        # itself (e.g. ".../run1/...") cannot be taken as the label; fall
        # back to the whole path as the original code did.
        match = (label_re.search(os.path.relpath(img_fn, data_dir))
                 or label_re.search(img_fn))
        if match is None:
            raise ValueError("no synset id (n<digits>) found in %r" % img_fn)
        label_name = match.group(1)
        label_index = synset_map[label_name]["index"]
        data.append({
            # file_list() already prefixed data_dir; joining it again would
            # double the prefix whenever data_dir is a relative path.
            "filename": img_fn,
            "label_name": label_name,
            "label_index": label_index,
            "desc": synset[label_index],
        })
    return data
def distorted_inputs(num_preprocess_threads=4):
    """Return a shuffled, preprocessed training batch of images and labels.

    Loads the example list with ``load_data(FLAGS.data_dir)`` and feeds the
    filenames and label indexes through ``tf.train.slice_input_producer``
    (shuffled). It builds ``num_preprocess_threads`` read + preprocess
    pipelines and joins them with ``tf.train.batch_join``.

    Args:
        num_preprocess_threads: number of parallel preprocessing pipelines
            passed to ``batch_join`` (default 4, the previous hard-coded
            value).

    Returns:
        ``(images, labels)``. ``images`` is cast to float32 and reshaped to
        ``[FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, 3]``.
        ``labels`` is reshaped to ``[FLAGS.batch_size]``.

    Raises:
        ValueError: if no ``.JPEG`` examples were found, or if
            ``num_preprocess_threads`` is less than 1.
    """
    if num_preprocess_threads < 1:
        raise ValueError("num_preprocess_threads must be >= 1, got %r"
                         % (num_preprocess_threads,))

    data = load_data(FLAGS.data_dir)
    if not data:
        # slice_input_producer would otherwise fail with an obscure error.
        raise ValueError("no .JPEG training examples found under %s"
                         % FLAGS.data_dir)
    filenames = [d['filename'] for d in data]
    label_indexes = [d['label_index'] for d in data]
    filename, label_index = tf.train.slice_input_producer(
        [filenames, label_indexes], shuffle=True)
    # Debug output, printed once instead of once per thread.
    print("filename: %s" % (filename,))

    images_and_labels = []
    for thread_id in range(num_preprocess_threads):
        image_buffer = tf.read_file(filename)
        # Empty bbox list and train=True are passed positionally as before.
        # The local flag is no longer named `train`, which shadowed the
        # imported train() function.
        # NOTE(review): meaning of an empty bbox depends on
        # image_preprocessing -- confirm.
        bbox = []
        is_train = True
        image = image_preprocessing(image_buffer, bbox, is_train, thread_id)
        images_and_labels.append([image, label_index])

    images, label_index_batch = tf.train.batch_join(
        images_and_labels,
        batch_size=FLAGS.batch_size,
        capacity=2 * num_preprocess_threads * FLAGS.batch_size)

    height = FLAGS.input_size
    width = FLAGS.input_size
    depth = 3
    images = tf.cast(images, tf.float32)
    images = tf.reshape(images, shape=[FLAGS.batch_size, height, width, depth])
    return images, tf.reshape(label_index_batch, [FLAGS.batch_size])
def main(_):
    """Build the training input pipeline and model, then hand off to train().

    The network is built by ``inference`` with 1000 classes and four stages
    of two non-bottleneck blocks each. The single positional argument is
    unused.
    """
    image_batch, label_batch = distorted_inputs()
    training_flag = tf.placeholder('bool', [], name='is_training')
    # NOTE(review): training_flag is passed to train(), but the network below
    # is built with the constant is_training=True. If train() feeds this
    # placeholder, the value never reaches inference(). Check whether
    # inference() accepts a bool tensor before wiring it through.
    net_kwargs = {
        'num_classes': 1000,
        'is_training': True,
        'bottleneck': False,
        'num_blocks': [2, 2, 2, 2],
    }
    logits = inference(image_batch, **net_kwargs)
    train(training_flag, logits, image_batch, label_batch)
if __name__ == '__main__':
    # Hand main() to tf.app.run explicitly instead of letting it look main up
    # on the __main__ module (which resolves to this same function).
    tf.app.run(main=main)