Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Byronnar authored Nov 7, 2019
1 parent 86803e1 commit f20dd93
Showing 1 changed file with 73 additions and 71 deletions.
144 changes: 73 additions & 71 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,57 @@
from core.yolov3 import YOLOV3
from core.config import cfg


class YoloTrain(object):
def __init__(self): # 从config文件
self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
self.num_classes = len(self.classes)
self.learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT
self.learn_rate_end = cfg.TRAIN.LEARN_RATE_END
self.first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS
def __init__(self): # 从config文件获取到一些变量
self.anchor_per_scale = cfg.YOLO.ANCHOR_PER_SCALE
self.classes = utils.read_class_names(cfg.YOLO.CLASSES)
self.num_classes = len(self.classes)
self.learn_rate_init = cfg.TRAIN.LEARN_RATE_INIT
self.learn_rate_end = cfg.TRAIN.LEARN_RATE_END
self.first_stage_epochs = cfg.TRAIN.FISRT_STAGE_EPOCHS
self.second_stage_epochs = cfg.TRAIN.SECOND_STAGE_EPOCHS
self.warmup_periods = cfg.TRAIN.WARMUP_EPOCHS
self.initial_weight = cfg.TRAIN.INITIAL_WEIGHT
self.time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY
self.max_bbox_per_scale = 150
self.train_logdir = "./data/log/train"
self.trainset = Dataset('train')
self.testset = Dataset('test')
self.steps_per_period = len(self.trainset)
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

with tf.name_scope('define_input'):
self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')
self.label_sbbox = tf.placeholder(dtype=tf.float32, name='label_sbbox')
self.label_mbbox = tf.placeholder(dtype=tf.float32, name='label_mbbox')
self.label_lbbox = tf.placeholder(dtype=tf.float32, name='label_lbbox')
self.warmup_periods = cfg.TRAIN.WARMUP_EPOCHS
self.initial_weight = cfg.TRAIN.INITIAL_WEIGHT
self.time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
self.moving_ave_decay = cfg.YOLO.MOVING_AVE_DECAY
self.max_bbox_per_scale = 150
self.train_logdir = "./data/log/train" # 日志保存地址
self.trainset = Dataset('train')
self.testset = Dataset('test')
self.steps_per_period = len(self.trainset)
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))

with tf.name_scope('define_input'): # 定义输入层
self.input_data = tf.placeholder(dtype=tf.float32, name='input_data')
self.label_sbbox = tf.placeholder(dtype=tf.float32, name='label_sbbox')
self.label_mbbox = tf.placeholder(dtype=tf.float32, name='label_mbbox')
self.label_lbbox = tf.placeholder(dtype=tf.float32, name='label_lbbox')
self.true_sbboxes = tf.placeholder(dtype=tf.float32, name='sbboxes')
self.true_mbboxes = tf.placeholder(dtype=tf.float32, name='mbboxes')
self.true_lbboxes = tf.placeholder(dtype=tf.float32, name='lbboxes')
self.trainable = tf.placeholder(dtype=tf.bool, name='training')
self.trainable = tf.placeholder(dtype=tf.bool, name='training')

with tf.name_scope("define_loss"):
with tf.name_scope("define_loss"): # 定义损失函数
self.model = YOLOV3(self.input_data, self.trainable)
self.net_var = tf.global_variables()
self.giou_loss, self.conf_loss, self.prob_loss = self.model.compute_loss(
self.label_sbbox, self.label_mbbox, self.label_lbbox,
self.true_sbboxes, self.true_mbboxes, self.true_lbboxes)
self.label_sbbox, self.label_mbbox, self.label_lbbox,
self.true_sbboxes, self.true_mbboxes, self.true_lbboxes)
self.loss = self.giou_loss + self.conf_loss + self.prob_loss

with tf.name_scope('learn_rate'):
with tf.name_scope('learn_rate'): # 定义学习率
self.global_step = tf.Variable(1.0, dtype=tf.float64, trainable=False, name='global_step')
warmup_steps = tf.constant(self.warmup_periods * self.steps_per_period,
dtype=tf.float64, name='warmup_steps')
train_steps = tf.constant( (self.first_stage_epochs + self.second_stage_epochs)* self.steps_per_period,
dtype=tf.float64, name='train_steps')
dtype=tf.float64, name='warmup_steps')
train_steps = tf.constant((self.first_stage_epochs + self.second_stage_epochs) * self.steps_per_period,
dtype=tf.float64, name='train_steps')
self.learn_rate = tf.cond(
pred=self.global_step < warmup_steps,
true_fn=lambda: self.global_step / warmup_steps * self.learn_rate_init,
false_fn=lambda: self.learn_rate_end + 0.5 * (self.learn_rate_init - self.learn_rate_end) *
(1 + tf.cos(
(self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))
(1 + tf.cos(
(self.global_step - warmup_steps) / (train_steps - warmup_steps) * np.pi))
)
global_step_update = tf.assign_add(self.global_step, 1.0)

Expand All @@ -71,16 +72,17 @@ def __init__(self): # 从config文件
但是这又使得训练速度变慢了。因此,采用逐渐增大的学习率,从而达到既可以尽量避免出现nan,又可以等训练过程稳定了再增大训练速度的目的。
'''

with tf.name_scope("define_weight_decay"):
with tf.name_scope("define_weight_decay"): # 指数平滑,可以让算法在最后不那么震荡,结果更有鲁棒性
moving_ave = tf.train.ExponentialMovingAverage(self.moving_ave_decay).apply(tf.trainable_variables())

# 指定需要恢复的参数。层等信息, 位置提前,减少模型体积。
with tf.name_scope('loader_and_saver'):
variables_to_restore = [v for v in self.net_var if v.name.split('/')[0] not in ['conv_sbbox', 'conv_mbbox', 'conv_lbbox']]
variables_to_restore = [v for v in self.net_var if
v.name.split('/')[0] not in ['conv_sbbox', 'conv_mbbox', 'conv_lbbox']]
self.loader = tf.train.Saver(variables_to_restore)
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=1)

with tf.name_scope("define_first_stage_train"):
with tf.name_scope("define_first_stage_train"): # 第一阶段训练,只训练指定层
self.first_stage_trainable_var_list = []
for var in tf.trainable_variables():
var_name = var.op.name
Expand All @@ -89,35 +91,34 @@ def __init__(self): # 从config文件
self.first_stage_trainable_var_list.append(var)

first_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
var_list=self.first_stage_trainable_var_list)
var_list=self.first_stage_trainable_var_list)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
with tf.control_dependencies([first_stage_optimizer, global_step_update]):
with tf.control_dependencies([moving_ave]):
self.train_op_with_frozen_variables = tf.no_op()

with tf.name_scope("define_second_stage_train"):
with tf.name_scope("define_second_stage_train"): # 第二阶段训练,释放所有层
second_stage_trainable_var_list = tf.trainable_variables()
second_stage_optimizer = tf.train.AdamOptimizer(self.learn_rate).minimize(self.loss,
var_list=second_stage_trainable_var_list)
var_list=second_stage_trainable_var_list)

with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
with tf.control_dependencies([second_stage_optimizer, global_step_update]):
with tf.control_dependencies([moving_ave]):
self.train_op_with_all_variables = tf.no_op()

with tf.name_scope('summary'):
tf.summary.scalar("learn_rate", self.learn_rate)
tf.summary.scalar("giou_loss", self.giou_loss)
tf.summary.scalar("conf_loss", self.conf_loss)
tf.summary.scalar("prob_loss", self.prob_loss)
tf.summary.scalar("learn_rate", self.learn_rate)
tf.summary.scalar("giou_loss", self.giou_loss)
tf.summary.scalar("conf_loss", self.conf_loss)
tf.summary.scalar("prob_loss", self.prob_loss)
tf.summary.scalar("total_loss", self.loss)

logdir = "./data/log/"
logdir = "./data/log/" # 日志保存地址
if os.path.exists(logdir): shutil.rmtree(logdir)
os.mkdir(logdir)
self.write_op = tf.summary.merge_all()
self.summary_writer = tf.summary.FileWriter(logdir, graph=self.sess.graph)

self.summary_writer = tf.summary.FileWriter(logdir, graph=self.sess.graph)

def train(self):
self.sess.run(tf.global_variables_initializer())
Expand All @@ -130,43 +131,43 @@ def train(self):
self.first_stage_epochs = 0

# 阶段学习率
for epoch in range(1, 1+self.first_stage_epochs+self.second_stage_epochs):
for epoch in range(1, 1 + self.first_stage_epochs + self.second_stage_epochs):
if epoch <= self.first_stage_epochs:
train_op = self.train_op_with_frozen_variables
else:
train_op = self.train_op_with_all_variables

# tqdm is a visualization tool that displays an Iterable object in a progree bar
pbar = tqdm(self.trainset)
train_epoch_loss, test_epoch_loss = [], []

for train_data in pbar:
_, summary, train_step_loss, global_step_val = self.sess.run(
[train_op, self.write_op, self.loss, self.global_step],feed_dict={
self.input_data: train_data[0],
self.label_sbbox: train_data[1],
self.label_mbbox: train_data[2],
self.label_lbbox: train_data[3],
self.true_sbboxes: train_data[4],
self.true_mbboxes: train_data[5],
self.true_lbboxes: train_data[6],
self.trainable: True,
})
[train_op, self.write_op, self.loss, self.global_step], feed_dict={
self.input_data: train_data[0],
self.label_sbbox: train_data[1],
self.label_mbbox: train_data[2],
self.label_lbbox: train_data[3],
self.true_sbboxes: train_data[4],
self.true_mbboxes: train_data[5],
self.true_lbboxes: train_data[6],
self.trainable: True,
})

train_epoch_loss.append(train_step_loss)
self.summary_writer.add_summary(summary, global_step_val)
pbar.set_description("train loss: %.2f" %train_step_loss)
pbar.set_description("train loss: %.2f" % train_step_loss)

for test_data in self.testset:
test_step_loss = self.sess.run( self.loss, feed_dict={
self.input_data: test_data[0],
self.label_sbbox: test_data[1],
self.label_mbbox: test_data[2],
self.label_lbbox: test_data[3],
self.true_sbboxes: test_data[4],
self.true_mbboxes: test_data[5],
self.true_lbboxes: test_data[6],
self.trainable: False,
test_step_loss = self.sess.run(self.loss, feed_dict={
self.input_data: test_data[0],
self.label_sbbox: test_data[1],
self.label_mbbox: test_data[2],
self.label_lbbox: test_data[3],
self.true_sbboxes: test_data[4],
self.true_mbboxes: test_data[5],
self.true_lbboxes: test_data[6],
self.trainable: False,
})

test_epoch_loss.append(test_step_loss)
Expand All @@ -175,7 +176,8 @@ def train(self):
ckpt_file = "./checkpoint/yolov3_train_loss=%.4f.ckpt" % train_epoch_loss
log_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
print("=> Epoch: %2d Time: %s Train loss: %.2f Test loss: %.2f Saving %s ..."
%(epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
self.saver.save(self.sess, ckpt_file, global_step=epoch)
% (epoch, log_time, train_epoch_loss, test_epoch_loss, ckpt_file))
self.saver.save(self.sess, ckpt_file, global_step=epoch)


if __name__ == '__main__': YoloTrain().train()

0 comments on commit f20dd93

Please sign in to comment.