-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathmodel.py
143 lines (128 loc) · 5.63 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch.nn as nn
from modules.transformation import TPS_SpatialTransformerNetwork
from modules.feature_extraction import (
VGG_FeatureExtractor,
RCNN_FeatureExtractor,
ResNet_FeatureExtractor,
)
from modules.sequence_modeling import BidirectionalLSTM
from modules.prediction import Attention
class Model(nn.Module):
def __init__(self, opt, SelfSL_layer=False):
super(Model, self).__init__()
self.opt = opt
self.stages = {
"Trans": opt.Transformation,
"Feat": opt.FeatureExtraction,
"Seq": opt.SequenceModeling,
"Pred": opt.Prediction,
}
""" Transformation """
if opt.Transformation == "TPS":
self.Transformation = TPS_SpatialTransformerNetwork(
F=opt.num_fiducial,
I_size=(opt.imgH, opt.imgW),
I_r_size=(opt.imgH, opt.imgW),
I_channel_num=opt.input_channel,
)
else:
print("No Transformation module specified")
""" FeatureExtraction """
if opt.FeatureExtraction == "VGG":
self.FeatureExtraction = VGG_FeatureExtractor(
opt.input_channel, opt.output_channel
)
elif opt.FeatureExtraction == "RCNN":
self.FeatureExtraction = RCNN_FeatureExtractor(
opt.input_channel, opt.output_channel
)
elif opt.FeatureExtraction == "ResNet":
self.FeatureExtraction = ResNet_FeatureExtractor(
opt.input_channel, opt.output_channel
)
else:
raise Exception("No FeatureExtraction module specified")
self.FeatureExtraction_output = opt.output_channel
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
(None, 1)
) # Transform final (imgH/16-1) -> 1
if not SelfSL_layer: # for STR
"""Sequence modeling"""
if opt.SequenceModeling == "BiLSTM":
self.SequenceModeling = nn.Sequential(
BidirectionalLSTM(
self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size
),
BidirectionalLSTM(
opt.hidden_size, opt.hidden_size, opt.hidden_size
),
)
self.SequenceModeling_output = opt.hidden_size
else:
print("No SequenceModeling module specified")
self.SequenceModeling_output = self.FeatureExtraction_output
if not SelfSL_layer: # for STR.
"""Prediction"""
if opt.Prediction == "CTC":
self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
elif opt.Prediction == "Attn":
self.Prediction = Attention(
self.SequenceModeling_output, opt.hidden_size, opt.num_class
)
else:
raise Exception("Prediction is neither CTC or Attn")
else:
"""for self-supervised learning (SelfSL)"""
self.AdaptiveAvgPool_2 = nn.AdaptiveAvgPool2d((None, 1)) # make width -> 1
if SelfSL_layer == "CNN":
self.SelfSL_FFN_input = self.FeatureExtraction_output
if "RotNet" in self.opt.self:
self.SelfSL = nn.Linear(
self.SelfSL_FFN_input, 4
) # 4 = [0, 90, 180, 270] degrees
elif "MoCo" in self.opt.self:
self.SelfSL = nn.Linear(
self.SelfSL_FFN_input, 128
) # 128 is used for MoCo paper.
def forward(self, image, text=None, is_train=True, SelfSL_layer=False):
"""Transformation stage"""
if not self.stages["Trans"] == "None":
image = self.Transformation(image)
""" Feature extraction stage """
visual_feature = self.FeatureExtraction(image)
visual_feature = visual_feature.permute(
0, 3, 1, 2
) # [b, c, h, w] -> [b, w, c, h]
visual_feature = self.AdaptiveAvgPool(
visual_feature
) # [b, w, c, h] -> [b, w, c, 1]
visual_feature = visual_feature.squeeze(3) # [b, w, c, 1] -> [b, w, c]
""" for self supervised learning on Feature extractor (CNN part) """
if SelfSL_layer == "CNN":
visual_feature = visual_feature.permute(0, 2, 1) # [b, w, c] -> [b, c, w]
visual_feature = self.AdaptiveAvgPool_2(
visual_feature
) # [b, c, w] -> [b, c, 1]
visual_feature = visual_feature.squeeze(2) # [b, c, 1] -> [b, c]
prediction_SelfSL = self.SelfSL(
visual_feature
) # [b, c] -> [b, SelfSL_class]
return prediction_SelfSL
""" Sequence modeling stage """
if self.stages["Seq"] == "BiLSTM":
contextual_feature = self.SequenceModeling(
visual_feature
) # [b, num_steps, opt.hidden_size]
else:
contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM
""" Prediction stage """
if self.stages["Pred"] == "CTC":
prediction = self.Prediction(contextual_feature.contiguous())
else:
prediction = self.Prediction(
contextual_feature.contiguous(),
text,
is_train,
batch_max_length=self.opt.batch_max_length,
)
return prediction # [b, num_steps, opt.num_class]