cawb.py (forked from Huangdebo/CAWB)
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 6 19:10:49 2021
@author: hdb
"""
import torch.optim as optim
import torch
import torch.nn as nn
import argparse
import math
from copy import copy
import matplotlib.pyplot as plt
class CosineAnnealingWarmbootingLR:
    # CAWB learning rate scheduler: given the warm-booting steps, compute the learning rate schedule automatically
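    # Arguments (as used below):
    #   epochs       - total number of training epochs; the last segment ends here
    #   eta_min      - lower bound factor of the cosine curve (LR floor = peak * eta_min)
    #   steps        - epochs at which the LR "reboots" to a new (scaled) peak;
    #                  note: sorted in place, so the caller's list is modified
    #   step_scale   - factor applied to the peak LR at every reboot
    #   lf           - optional custom curve lf(iters, gap) replacing the cosine
    #   batchs       - batches per epoch (only used to size the warmup)
    #   warmup_epoch - epochs of linear per-batch warmup before the first segment
    #   epoch_scale  - slight extension of each non-final segment so the LR does
    #                  not fully decay to the floor right before a reboot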
def __init__(self, optimizer, epochs=0, eta_min=0.05, steps=[], step_scale=0.8, lf=None, batchs=0, warmup_epoch=0, epoch_scale=1.0):
self.warmup_iters = batchs * warmup_epoch
self.optimizer = optimizer
self.eta_min = eta_min
self.iters = -1
self.iters_batch = -1
self.base_lr = [group['lr'] for group in optimizer.param_groups]
self.step_scale = step_scale
steps.sort()
self.steps = [warmup_epoch] + [i for i in steps if (i < epochs and i > warmup_epoch)] + [epochs]
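        # e.g. epochs=150, warmup_epoch=0, steps=[50, 100, 150] -> self.steps = [0, 50, 100, 150]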
self.gap = 0
self.last_epoch = 0
self.lf = lf
self.epoch_scale = epoch_scale
# Initialize epochs and base learning rates
for group in optimizer.param_groups:
group.setdefault('initial_lr', group['lr'])
    def step(self, external_iter=None):
self.iters += 1
if external_iter is not None:
self.iters = external_iter
# cos warm boot policy
iters = self.iters + self.last_epoch
scale = 1.0
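        # locate the boot segment containing the current epoch: 'gap' becomes the
        # cosine period of that segment and 'scale' shrinks the peak LR by
        # step_scale once for every segment already completed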
for i in range(len(self.steps)-1):
if (iters <= self.steps[i+1]):
self.gap = self.steps[i+1] - self.steps[i]
iters = iters - self.steps[i]
if i != len(self.steps)-2:
self.gap += self.epoch_scale
break
scale *= self.step_scale
if self.lf is None:
for group, lr in zip(self.optimizer.param_groups, self.base_lr):
group['lr'] = scale * lr * ((((1 + math.cos(iters * math.pi / self.gap)) / 2) ** 1.0) * (1.0 - self.eta_min) + self.eta_min)
else:
for group, lr in zip(self.optimizer.param_groups, self.base_lr):
group['lr'] = scale * lr * self.lf(iters, self.gap)
return self.optimizer.param_groups[0]['lr']
def step_batch(self):
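        # linear per-batch warmup: the LR ramps from 0 up to base_lr over the
        # first warmup_iters batches; once warmup is over this returns None and
        # the per-epoch step() takes control of the LR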
self.iters_batch += 1
if self.iters_batch < self.warmup_iters:
rate = self.iters_batch / self.warmup_iters
for group, lr in zip(self.optimizer.param_groups, self.base_lr):
group['lr'] = lr * rate
return self.optimizer.param_groups[0]['lr']
else:
return None
def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir='./LR.png'):
# Plot LR simulating training for full epochs
optimizer, scheduler = copy(optimizer), copy(scheduler) # do not modify originals
y = []
for _ in range(scheduler.last_epoch):
y.append(None)
for _ in range(scheduler.last_epoch, epochs):
y.append(scheduler.step())
plt.plot(y, '.-', label='LR')
plt.xlabel('epoch')
plt.ylabel('LR')
plt.grid()
plt.xlim(0, epochs)
plt.ylim(0)
plt.tight_layout()
plt.savefig(save_dir, dpi=200)
class model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3,3,3)
def forward(self, x):
return self.conv(x)
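# minimal demo: a toy model and a dummy dataset, used only so the optimizer has
# real parameters to schedule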
def train(opt):
net = model()
data = [1] * 100
optimizer = optim.Adam(net.parameters(), lr=0.1)
lf = lambda x, y=opt.epochs: (((1 + math.cos(x * math.pi / y)) / 2) ** 1.0) * 0.8 + 0.2
# lf = lambda x, y=opt.epochs: (1.0 - (x / y)) * 0.9 + 0.1
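    # a custom lf is called as lf(iters_within_segment, segment_length) and must
    # return an LR multiplier; the lambda above is the cosine curve, the commented
    # one a linear decay (floored at 0.2 and 0.1 respectively)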
scheduler = CosineAnnealingWarmbootingLR(optimizer, epochs=opt.epochs, steps=opt.cawb_steps, step_scale=0.7,
lf=lf, batchs=len(data), warmup_epoch=0)
    # last_epoch = 20
    # scheduler.last_epoch = last_epoch  # set this when resuming from a checkpoint
    # plot_lr_scheduler(optimizer, scheduler, opt.epochs)  # cannot currently plot the warmup curve
for i in range(opt.epochs):
for b in range(len(data)):
            lr = scheduler.step_batch()  # per-batch warmup; call before the backward pass
# training
# loss
# backward
        scheduler.step()  # advance the cosine / warm-booting schedule once per epoch
return 0
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=150)
parser.add_argument('--scheduler_lr', type=str, default='cawb', help='the learning rate scheduler, cos/cawb')
parser.add_argument('--cawb_steps', nargs='+', type=int, default=[50, 100, 150], help='the cawb learning rate scheduler steps')
opt = parser.parse_args()
train(opt)
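# example invocation with the arguments defined above (a sketch; adjust to taste):
#   python cawb.py --epochs 150 --cawb_steps 50 100 150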