-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathtrain.py
executable file
·122 lines (108 loc) · 4.96 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import tensorflow as tf
import ResNet_lib as resnet
from FlowIO import DataSetLib as OF
import os
import argparse as aps
from configobj import ConfigObj
# Load the training and image settings from the project configuration file.
cfg_struct = ConfigObj("Unit.cfg")
cfg_train = cfg_struct['Train data']
cfg_image = cfg_struct['Image data']

# Command-line flags. Hyper-parameters default to the values found in the
# configuration file; argparse passes string defaults through ``type``, so the
# raw config values are converted to the declared type here.
parser = aps.ArgumentParser(description="manual to this script")
parser.add_argument("--lr", type=float, default=cfg_train["learning_rate"])
parser.add_argument("--classes", type=int, default=cfg_image["num_classes"])
parser.add_argument("--batch_size", type=int, default=cfg_train["batch_size"])
parser.add_argument("--step", type=int, default=cfg_train["step"])
parser.add_argument("--output", type=str, default="model")
parser.add_argument("--save_num", type=int, default=1000)
parser.add_argument("--log_num", type=int, default=50)
parser.add_argument("--model", type=str, default=None)
args = parser.parse_args()

LEARNING_RATE_BASE = args.lr  # base learning rate
N_CLASSES = args.classes  # number of classes
BATCH_SIZE = args.batch_size  # batch size
# This value does not go through argparse, so convert it explicitly: ConfigObj
# hands back strings unless a configspec has been applied.
IMAGE_SIZE = int(cfg_image["image_size"])  # image size (used for both height and width)
NUM_CHANNELS = 3  # image depth (number of channels)
STEP = args.step  # number of training iterations
SAVE_NUM = args.save_num  # checkpoint interval, in steps
MODEL_PATH = args.output  # directory the checkpoints are written to
LOG_NUM = args.log_num  # console logging interval, in steps
MODEL = args.model  # checkpoint to restore before continuing training (optional)
def losses(logits, labels):
    """Build the training loss: mean softmax cross-entropy plus regularization.

    Args:
        logits: Class-score tensor; its static second dimension is used as the
            one-hot depth.
        labels: Integer class-index tensor; it is flattened to 1-D before being
            one-hot encoded.

    Returns:
        The scalar loss tensor. A scalar summary named ``<scope>/loss`` is
        registered as a side effect.
    """
    with tf.variable_scope('loss') as scope:
        labels = tf.reshape(labels, [-1])
        labels = tf.one_hot(labels, depth=logits.get_shape().as_list()[1])
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=labels, name='xentropy_per_example')
        loss = tf.reduce_mean(cross_entropy, name='loss')
        # tf.add_n rejects an empty list, so the regularization term is only
        # added when the graph actually registered regularization losses.
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses:
            loss = loss + tf.add_n(reg_losses)
        tf.summary.scalar(scope.name + '/loss', loss)
    return loss
def trainning(loss, learning_rate,batch_size,num_data):
    """Create the Adam training op with a staircase-decayed learning rate.

    Args:
        loss: Scalar tensor to minimize.
        learning_rate: Initial learning rate.
        batch_size: Batch size; together with ``num_data`` it sets the decay
            interval.
        num_data: Number of samples; the rate is multiplied by 0.99 every
            ``num_data / batch_size`` steps.

    Returns:
        The op that applies one optimization step and increments the
        ``global_step`` variable created here.
    """
    with tf.name_scope('optimizer'):
        step_counter = tf.Variable(0, name='global_step', trainable=False)
        decay_interval = num_data / batch_size
        decayed_rate = tf.train.exponential_decay(
            learning_rate=learning_rate,
            global_step=step_counter,
            decay_steps=decay_interval,
            decay_rate=0.99,
            staircase=True,
        )
        adam = tf.train.AdamOptimizer(learning_rate=decayed_rate)
        # Everything registered under UPDATE_OPS must run before the
        # parameter update itself.
        pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(pending_updates):
            minimize_op = adam.minimize(loss, global_step=step_counter)
    return minimize_op
def evaluation(logits, labels):
    """Build the batch accuracy: fraction of samples whose arg-max matches the label.

    Args:
        logits: Class-score tensor; the prediction is the arg-max along axis 1.
        labels: Integer class-index tensor; it is flattened to 1-D.

    Returns:
        The scalar float32 accuracy tensor. A scalar summary named
        ``<scope>/accuracy`` is registered as a side effect.
    """
    with tf.variable_scope('accuracy') as scope:
        # The labels already hold class indices, so they are compared directly
        # instead of being one-hot encoded and arg-maxed back. That round trip
        # mapped any out-of-range label to class 0, which could be counted as a
        # correct prediction. The cast keeps int32 labels working, because
        # tf.argmax yields int64.
        labels = tf.cast(tf.reshape(labels, [-1]), tf.int64)
        predictions = tf.argmax(logits, 1)
        accuracy = tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
        tf.summary.scalar(scope.name + '/accuracy', accuracy)
    return accuracy
def train(num_data=1000):
    """Build the ResNet training graph and run the training loop.

    The loop runs ``STEP`` iterations. Every iteration writes the merged
    summaries to ``log_files``; every ``LOG_NUM`` steps the loss and accuracy
    are printed; every ``SAVE_NUM`` steps a checkpoint is written under
    ``MODEL_PATH``. When ``MODEL`` is set, that checkpoint is restored before
    training starts.

    Args:
        num_data: Sample count handed to ``trainning`` to derive the
            learning-rate decay interval. Defaults to the value that used to be
            hard-coded here.
    """
    x = tf.placeholder(
        tf.float32,
        [None, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS],
        name='x-input')
    # Labels.
    y_ = tf.placeholder(tf.int64, [None, 1], name='y-input')
    # Whether the network runs in training mode.
    is_train = tf.placeholder(tf.bool, [1], name="is_train")
    # Network output.
    y = resnet.inference(x, resnet.ResNet_mini_demo['layer_50'], N_CLASSES, is_train[0])
    loss = losses(y, y_)
    acc = evaluation(y, y_)
    train_step = trainning(loss, LEARNING_RATE_BASE, BATCH_SIZE, num_data)
    with tf.control_dependencies([train_step]):
        train_op = tf.no_op(name='train')
    # Checkpoint persistence.
    tf.add_to_collection('pred_network', y)
    # tf.global_variables() replaces the deprecated tf.all_variables().
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=50)
    save_path = os.path.join(MODEL_PATH, 'mod.ckpt')
    with tf.Session() as sess:
        # Initialize the network variables.
        tf.global_variables_initializer().run()
        with tf.device("/cpu:0"):
            get_flow = OF(sess, "Train", [IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS], BATCH_SIZE)
            # Tensor(s) yielding the next training batch.
            next_batch = get_flow.get_batch_data()
        if MODEL is not None:
            # Restore a previous checkpoint and continue training from it.
            saver.restore(sess, MODEL)
        # NOTE(review): this summary is built from next_batch, so evaluating
        # `merged` below may pull an additional batch from the input pipeline
        # on every step - confirm against DataSetLib.
        tf.summary.image("R.G.B", tf.expand_dims(next_batch[0][0], 0))
        merged = tf.summary.merge_all()
        # Saver.save fails when the parent directory is missing, so make sure
        # the output directory exists before the first checkpoint.
        if MODEL_PATH and not os.path.isdir(MODEL_PATH):
            os.makedirs(MODEL_PATH)
        log_summary = tf.summary.FileWriter("log_files", sess.graph)
        try:
            # Training iterations.
            for i in range(STEP):
                images, labels = sess.run(next_batch)
                _, loss_value, acc_value, merged_value = sess.run(
                    [train_op, loss, acc, merged],
                    feed_dict={x: images, y_: labels, is_train: [True]})
                log_summary.add_summary(merged_value, i)
                if i % LOG_NUM == 0:
                    print("After %d training step(s), loss on training batch is %g." % (i, loss_value),"acc : ",acc_value)
                # Save a checkpoint.
                if i % SAVE_NUM == 0:
                    saver.save(sess, save_path, global_step=i)
        finally:
            # Flush and close the event file even if training is interrupted.
            log_summary.close()
def main(argv=None):
    """Entry point invoked by ``tf.app.run``; starts the training run."""
    # The settings were already parsed with argparse at module level, so the
    # argument list handed over by the app runner is not needed.
    del argv
    train()
# Script entry point: hand control to TensorFlow's app runner, which calls main().
if __name__ == "__main__":
    tf.app.run(main=main)