# trainSynth.py
# Forked from backtime92/CRAFT-Reimplementation.
import os
import torch
import torch.optim as optim
import cv2
import time
from data.dataset import SynthTextDataLoader
from craft import CRAFT
from loss.mseloss import Maploss
from torch.autograd import Variable
# TODO:
# 1) Find out whether the data I'm training on is properly formatted and processed.
# 2) Fix the resizing/cropping step that causes the images within a batch to be
#    non-uniform in dimension (which is why I have batch size set to 0, which is slow).
#    NOTE: batch_size is currently 4 in the script below -- confirm which is intended.
def adjust_learning_rate(optimizer, gamma, step, lr):
    """Set every parameter group's learning rate to ``lr * gamma ** step``.

    Adapted from the PyTorch ImageNet example:
    https://github.com/pytorch/examples/blob/master/imagenet/main.py

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``"lr"`` entry is overwritten.
        gamma: Multiplicative decay factor applied once per ``step``.
        step: Number of decays to apply to ``lr``.
        lr: Base (undecayed) learning rate.

    Returns:
        The decayed learning rate that was written to the parameter groups.
    """
    decayed_lr = lr * (gamma**step)
    print(decayed_lr)
    for param_group in optimizer.param_groups:
        param_group["lr"] = decayed_lr
    # Return the computed value instead of reading it back through the loop
    # variable, which is unbound when the optimizer has no parameter groups.
    return decayed_lr
def main():
    """Fine-tune CRAFT from ``result/craft_mlt_25k.pth`` and checkpoint it.

    Builds a ``SynthTextDataLoader`` over ``data/GroundTruth``, wraps CRAFT in
    ``DataParallel`` on CUDA, and optimises ``Maploss`` with Adam. The mean
    loss and mean batch time are printed every ``log_interval`` steps, the
    learning rate is decayed every ``lr_decay_interval`` steps, and the model
    weights are saved every ``checkpoint_interval`` steps.
    """
    synthData_dir = {"synthtext": "data/GroundTruth"}
    target_size = 768
    batch_size = 4
    num_workers = 1  # was 6; lowered while debugging
    lr = 1e-4
    training_lr = lr  # value shown in the log line; updated on each decay
    weight_decay = 5e-4
    gamma = 0.8
    whole_training_step = 100000
    lr_decay_interval = 20000
    log_interval = 5
    checkpoint_interval = 200
    # When True, the first checkpoint also writes one batch's image and label
    # maps to disk and terminates the run (the original hard-coded behaviour).
    # Set to False for a full training run.
    dump_debug_images_and_exit = True

    synthDataLoader = SynthTextDataLoader(target_size, synthData_dir)
    train_loader = torch.utils.data.DataLoader(
        synthDataLoader,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=False,
        pin_memory=True,
    )

    # NOTE(review): CRAFT() may need a "pretrained"-style argument here -- confirm.
    craft = CRAFT()
    # The checkpoint is loaded after wrapping, i.e. into the DataParallel module.
    craft = torch.nn.DataParallel(craft).cuda()
    craft.load_state_dict(torch.load("result/craft_mlt_25k.pth"))

    optimizer = optim.Adam(craft.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = Maploss()

    update_lr_rate_step = 2  # exponent handed to the first decay
    train_step = 0
    loss_value = 0
    batch_time = 0
    window_steps = 0  # steps accumulated since the last log line

    # Nothing in the loop leaves training mode, so switch once up front.
    craft.train()
    while train_step < whole_training_step:
        for index, (image, region_image, affinity_image, confidence_mask) in enumerate(
            train_loader
        ):
            start_time = time.time()

            if train_step > 0 and train_step % lr_decay_interval == 0:
                training_lr = adjust_learning_rate(
                    optimizer, gamma, update_lr_rate_step, lr
                )
                update_lr_rate_step += 1

            # Tensors carry autograd state themselves; the deprecated
            # Variable wrapper is a no-op and is not needed.
            images = image.cuda()
            region_image_label = region_image.cuda()
            affinity_image_label = affinity_image.cuda()
            confidence_mask_label = confidence_mask.cuda()

            output, _ = craft(images)
            # The last axis of the output holds the two predicted maps.
            out1 = output[:, :, :, 0]
            out2 = output[:, :, :, 1]
            loss = criterion(
                region_image_label,
                affinity_image_label,
                out1,
                out2,
                confidence_mask_label,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value += loss.item()
            batch_time += time.time() - start_time
            window_steps += 1

            if train_step > 0 and train_step % log_interval == 0:
                # Divide by the number of steps actually accumulated: the
                # first window covers steps 0..log_interval, one more than
                # the later ones.
                mean_loss = loss_value / window_steps
                avg_batch_time = batch_time / window_steps
                loss_value = 0
                batch_time = 0
                window_steps = 0
                print(
                    "{}, training_step: {}|{}, learning rate: {:.8f}, training_loss: {:.5f}, avg_batch_time: {:.5f}".format(
                        time.strftime("%Y-%m-%d:%H:%M:%S", time.localtime(time.time())),
                        train_step,
                        whole_training_step,
                        training_lr,
                        mean_loss,
                        avg_batch_time,
                    )
                )

            train_step += 1

            if train_step % checkpoint_interval == 0:
                # Name the checkpoint after the global step that triggered it;
                # the per-epoch batch index restarts every epoch and would
                # mislabel or overwrite earlier files.
                print("Saving state, index:", train_step)
                torch.save(
                    craft.state_dict(),
                    "result/CreeAFT_weights_" + repr(train_step) + ".pth",
                )
                # TODO: evaluate the saved checkpoint here; the commented-out
                # calls previously in this spot were test(<checkpoint path>)
                # and getresult().

                # NOTE(review): the nesting of this debug dump could not be
                # recovered from the source it was restored from; it is
                # assumed to belong to the checkpoint branch -- confirm.
                if dump_debug_images_and_exit:
                    image_np = image.numpy()[0, :, :, :].transpose(1, 2, 0) * 255
                    # NOTE(review): unlike `image`, the label maps are not
                    # indexed by sample before the transpose; if they are
                    # (batch, H, W) this puts the batch on the channel axis
                    # -- confirm against the dataset.
                    region_image_np = region_image.numpy().transpose(1, 2, 0) * 255
                    affinity_image_np = affinity_image.numpy().transpose(1, 2, 0) * 255
                    print(image_np.shape)
                    print(region_image_np.shape)
                    print(affinity_image_np.shape)
                    cv2.imwrite("image.jpg", image_np)
                    cv2.imwrite("region_image.jpg", region_image_np)
                    cv2.imwrite("affinity_image.jpg", affinity_image_np)
                    # Equivalent to exit(), without relying on the site builtin.
                    raise SystemExit

            # The while condition is only re-checked between epochs, so stop
            # mid-epoch once the step budget is used up.
            if train_step >= whole_training_step:
                break


if __name__ == "__main__":
    main()