-
Notifications
You must be signed in to change notification settings - Fork 7
/
utils.py
19 lines (17 loc) · 822 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import math
def get_cosine_schedule_with_warmup(optimizer,
                                    num_training_steps,
                                    num_cycles=7.0 / 16.0,
                                    num_warmup_steps=0,
                                    last_epoch=-1):
    """Build a ``LambdaLR`` scheduler with linear warmup then cosine decay.

    The multiplier applied to each param group's base learning rate is:

    * ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps``
      (a linear ramp starting at 0);
    * otherwise ``max(0.0, cos(pi * num_cycles * progress))``, where
      ``progress = (step - num_warmup_steps) /
      max(1, num_training_steps - num_warmup_steps)``.

    With the default ``num_cycles = 7/16`` the multiplier at
    ``step == num_training_steps`` is ``cos(7*pi/16)`` (about 0.195), not 0.
    Negative cosine values are clamped to 0.

    Args:
        optimizer: The optimizer whose learning rates are scheduled.
        num_training_steps: Total number of steps (including warmup) that the
            cosine progress is measured against.
        num_cycles: Scale applied to ``pi * progress`` inside the cosine.
        num_warmup_steps: Number of initial steps using the linear ramp.
        last_epoch: Forwarded unchanged to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` driven by the rule above.
    """
    from torch.optim.lr_scheduler import LambdaLR

    def lr_multiplier(step):
        if step < num_warmup_steps:
            # Linear warmup phase.
            return float(step) / float(max(1, num_warmup_steps))

        # Fraction of the post-warmup span that has elapsed; the max(1, ...)
        # guards against division by zero when no decay steps remain.
        elapsed = float(step - num_warmup_steps)
        progress = elapsed / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, math.cos(math.pi * num_cycles * progress))

    return LambdaLR(optimizer, lr_multiplier, last_epoch)