-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_resnet.py
130 lines (110 loc) · 6.16 KB
/
train_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import resnet20,getData,os.path
import tensorflow as tf
from datetime import datetime
import numpy as np
with tf.device('/gpu:0'):
    # Training script: builds the resnet20 graph, trains it on the data returned
    # by getData.load_cifar10(), and after every epoch logs train/test accuracy
    # and loss both to stdout and to TensorBoard summary files.
    #
    # NOTE(review): the original indentation was lost in this copy of the file;
    # the whole script is assumed to live under this device scope, and the
    # per-epoch reporting/evaluation is assumed to run once per epoch (it prints
    # "Epoch" and uses the epoch index as summary step) -- confirm.
    flags = tf.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_integer('max_epoches', 200, 'Number of epoches to run trainer.')
    flags.DEFINE_integer('batch_size', 128,
                         'Batch size. Must divide dataset sizes without remainder.')
    flags.DEFINE_string('train_dir', 'res_logs',
                        'Directory to put the training data.')

    # Initial learning rate; hard-coded here (not a flag) and decayed by the
    # epoch schedule at the bottom of the training loop.
    learning_rate = 0.1

    # Put logs for each run in a separate directory. The timestamp is taken
    # once so the train/ and test/ folders always share the same run directory
    # (two separate now() calls could straddle a second boundary).
    run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    train_logdir = FLAGS.train_dir + '/' + run_stamp + '/train/'
    test_logdir = FLAGS.train_dir + '/' + run_stamp + '/test/'

    # Define input placeholders
    images_placeholder = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='images')
    labels_placeholder = tf.placeholder(tf.int64, shape=None, name='image-labels')
    keeprob_placeholder = tf.placeholder(tf.float32, shape=None, name='keep_prob')
    isTrain_placeholder = tf.placeholder(tf.bool, name='phase_train')

    # Operation for the classifier's result
    logits = resnet20.residual_net(images_placeholder, keeprob_placeholder, isTrain_placeholder)
    # Operation for the loss function
    loss = resnet20.loss(logits, labels_placeholder)
    # Create a variable to track the global step
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Operation for the training step
    train_step = resnet20.training(loss, learning_rate, global_step)
    # Operation calculating the accuracy of our predictions
    accuracy = resnet20.evaluation(logits, labels_placeholder)
    # Operation merging summary data for TensorBoard
    merged = tf.summary.merge_all()
    # Saver for model checkpoints. Checkpointing is currently not performed
    # anywhere in this script; the saver is kept so the graph is unchanged.
    saver = tf.train.Saver()

    # Load CIFAR-10 data
    data_sets = getData.load_cifar10()

    # Number of full batches per pass; a trailing partial batch is dropped.
    train_batches_per_epoch = data_sets['train_data'].shape[0] // FLAGS.batch_size
    test_batches_per_epoch = data_sets['test_data'].shape[0] // FLAGS.batch_size

    # Epochs (1-based) after which the learning rate is multiplied by 0.1.
    lr_decay_epochs = (5, 10, 20)

    # -------------------------------------------------------------------------
    # Run the TensorFlow graph
    # -------------------------------------------------------------------------
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        train_writer = tf.summary.FileWriter(train_logdir, sess.graph)
        test_writer = tf.summary.FileWriter(test_logdir)
        try:
            # Generate input data batches
            zipped_data = zip(data_sets['train_data'], data_sets['train_label'])
            batches = getData.gen_batch(list(zipped_data), FLAGS.batch_size, FLAGS.max_epoches)

            for i in range(FLAGS.max_epoches):
                # ---- one training pass over the data ----
                for j in range(train_batches_per_epoch):
                    # Get next input data batch
                    batch = next(batches)
                    images_batch, labels_batch = zip(*batch)
                    feed_dict = {
                        images_placeholder: images_batch,
                        labels_placeholder: labels_batch,
                        keeprob_placeholder: 0.5,
                        isTrain_placeholder: True,
                    }
                    # Perform a single training step
                    _, train_loss = sess.run([train_step, loss], feed_dict=feed_dict)

                # ---- training metrics, measured on the last batch of the epoch ----
                summary, train_accuracy, train_loss = sess.run(
                    [merged, accuracy, loss], feed_dict=feed_dict)
                print('Epoch {:d}, training accuracy {:g}'.format(i, train_accuracy))
                print('Epoch {:d}, training loss {:g}'.format(i, train_loss))
                train_summary = tf.Summary(value=[
                    tf.Summary.Value(tag="accuracy", simple_value=train_accuracy),
                    tf.Summary.Value(tag="loss", simple_value=train_loss),
                ])
                train_writer.add_summary(summary, i)
                train_writer.add_summary(train_summary, i)

                # ---- evaluation over the test set ----
                acc = []
                test_losses = []
                for k in range(test_batches_per_epoch):
                    # Slice by the batch index k. (The previous code sliced by
                    # the epoch index, so every iteration saw the same -- or an
                    # empty -- slice of the test set.)
                    start = k * FLAGS.batch_size
                    stop = start + FLAGS.batch_size
                    test_feed_dict = {
                        images_placeholder: data_sets['test_data'][start:stop],
                        labels_placeholder: data_sets['test_label'][start:stop],
                        # Keep every unit at evaluation time; dropout must
                        # only be active while training.
                        keeprob_placeholder: 1.0,
                        isTrain_placeholder: False,
                    }
                    test_loss, test_accuracy = sess.run(
                        [loss, accuracy], feed_dict=test_feed_dict)
                    acc.append(test_accuracy)
                    test_losses.append(test_loss)

                # Average of the per-batch values (all batches are full-sized).
                avg_acc = float(np.mean(np.asarray(acc)))
                avg_loss = float(np.mean(np.asarray(test_losses)))
                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag="accuracy", simple_value=avg_acc),
                    tf.Summary.Value(tag="loss", simple_value=avg_loss),
                ])
                test_writer.add_summary(test_summary, i)
                print('Test accuracy {:g}'.format(avg_acc))
                print('Test loss {:g}'.format(avg_loss))

                # ---- step-wise learning-rate decay ----
                if (i + 1) in lr_decay_epochs:
                    learning_rate *= 0.1
                    # NOTE(review): this builds a new training op (and adds
                    # nodes to the graph) each time. If resnet20.training
                    # creates optimizer variables, they would not be covered
                    # by the initializer run above -- verify.
                    train_step = resnet20.training(loss, learning_rate, global_step)
        finally:
            # Flush pending events and release the event files even if
            # training is interrupted.
            train_writer.close()
            test_writer.close()