Skip to content

Commit

Permalink
Learning rate decay (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric8932 authored Oct 11, 2023
1 parent 554c078 commit f3d0f2c
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tencentpretrain/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def optimization_opts(parser):
help="Learning rate.")
parser.add_argument("--warmup", type=float, default=0.1,
help="Warm up value.")
- parser.add_argument("--decay", type=float, default=0.5,
- help="decay value.")
+ parser.add_argument("--lr_decay", type=float, default=0.5,
+ help="Learning rate decay value.")
parser.add_argument("--optimizer", choices=["adamw", "adafactor"],
default="adamw",
help="Optimizer type.")
Expand Down
2 changes: 1 addition & 1 deletion tencentpretrain/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ def worker(local_rank, gpu_ranks, args, model_for_training, model_for_dataloader
elif args.scheduler in ["constant_with_warmup"]:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup)
elif args.scheduler in ["tri_stage"]:
- custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.decay, args.total_steps)
+ custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps*args.lr_decay, args.total_steps)
else:
custom_scheduler = str2scheduler[args.scheduler](custom_optimizer, args.total_steps*args.warmup, args.total_steps)

Expand Down
3 changes: 2 additions & 1 deletion tencentpretrain/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,6 @@
"get_linear_schedule_with_warmup", "get_cosine_schedule_with_warmup",
"get_cosine_with_hard_restarts_schedule_with_warmup",
"get_polynomial_decay_schedule_with_warmup",
- "get_constant_schedule", "get_constant_schedule_with_warmup", "str2scheduler",
+ "get_constant_schedule", "get_constant_schedule_with_warmup",
+ "get_inverse_square_root_schedule_with_warmup", "get_tri_stage_schedule", "str2scheduler",
"FGM", "PGD", "str2adv"]

0 comments on commit f3d0f2c

Please sign in to comment.