From b0abb297b706b8fa02526847f075d8d0aa8f4424 Mon Sep 17 00:00:00 2001 From: Vinh Nguyen Date: Fri, 16 Aug 2019 00:13:05 +1000 Subject: [PATCH 1/2] adding automatic mixed precision training support --- train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train.py b/train.py index 02ec80074..6298a57dc 100644 --- a/train.py +++ b/train.py @@ -103,6 +103,8 @@ def _str_to_bool(s): parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP, help='Maximum amount of checkpoints that will be kept alive. Default: ' + str(MAX_TO_KEEP) + '.') + parser.add_argument('--automatic_mixed_precision', type=_str_to_bool, default=False, + help='Using automatic mixed precision training') return parser.parse_args() @@ -256,6 +258,8 @@ def main(): optimizer = optimizer_factory[args.optimizer]( learning_rate=args.learning_rate, momentum=args.momentum) + if args.auto_mixed_precision: + optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) trainable = tf.trainable_variables() optim = optimizer.minimize(loss, var_list=trainable) From 048500ef7ac793a2ad057b016df145b070ca672d Mon Sep 17 00:00:00 2001 From: Vinh Nguyen Date: Fri, 16 Aug 2019 09:20:11 +1000 Subject: [PATCH 2/2] adding automatic mixed precision training support-fix param naming --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 6298a57dc..b417399f4 100644 --- a/train.py +++ b/train.py @@ -258,7 +258,7 @@ def main(): optimizer = optimizer_factory[args.optimizer]( learning_rate=args.learning_rate, momentum=args.momentum) - if args.auto_mixed_precision: + if args.automatic_mixed_precision: optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) trainable = tf.trainable_variables() optim = optimizer.minimize(loss, var_list=trainable)