diff --git a/tencentpretrain/opts.py b/tencentpretrain/opts.py index eec3b73..91a6d28 100755 --- a/tencentpretrain/opts.py +++ b/tencentpretrain/opts.py @@ -102,8 +102,8 @@ def optimization_opts(parser): help="Learning rate.") parser.add_argument("--warmup", type=float, default=0.1, help="Warm up value.") - parser.add_argument("--decay", type=float, default=0.5, - help="decay value.") + parser.add_argument("--lr_decay", type=float, default=0.5, + help="Learning rate decay value.") parser.add_argument("--optimizer", choices=["adamw", "adafactor"], default="adamw", help="Optimizer type.") diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py index 6fdb92c..3b768a3 100755 --- a/tencentpretrain/trainer.py +++ b/tencentpretrain/trainer.py @@ -610,7 +610,7 @@ def worker(local_rank, gpu_ranks, args, model_for_training, model_for_dataloader elif args.scheduler in ["constant_with_warmup"]: custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup) elif args.scheduler in ["tri_stage"]: - custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.decay, args.total_steps) + custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps) else: custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps) diff --git a/tencentpretrain/utils/__init__.py b/tencentpretrain/utils/__init__.py index aa29517..f7512a2 100644 --- a/tencentpretrain/utils/__init__.py +++ b/tencentpretrain/utils/__init__.py @@ -48,5 +48,6 @@ "get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup", "get_cosine_with_hard_restarts_schedule_with_warmup", "get_polynomial_decay_schedule_with_warmup", - "get_constant_schedule", "get_constant_schedule_with_warmup", "str2scheduler", + "get_constant_schedule", "get_constant_schedule_with_warmup", + "get_inverse_square_root_schedule_with_warmup", "get_tri_stage_schedule", "str2scheduler", "FGM", "PGD", "str2adv"]