diff --git a/train.py b/train.py index 02ec80074..b417399f4 100644 --- a/train.py +++ b/train.py @@ -103,6 +103,8 @@ def _str_to_bool(s): parser.add_argument('--max_checkpoints', type=int, default=MAX_TO_KEEP, help='Maximum amount of checkpoints that will be kept alive. Default: ' + str(MAX_TO_KEEP) + '.') + parser.add_argument('--automatic_mixed_precision', type=_str_to_bool, default=False, + help='Using automatic mixed precision training') return parser.parse_args() @@ -256,6 +258,8 @@ def main(): optimizer = optimizer_factory[args.optimizer]( learning_rate=args.learning_rate, momentum=args.momentum) + if args.automatic_mixed_precision: + optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) trainable = tf.trainable_variables() optim = optimizer.minimize(loss, var_list=trainable)