Skip to content

Commit

Permalink
feat: Added proper kernel sizing to silence loading warnings (#152)
Browse files (browse the repository at this point in the history)
* feat: Specified input_shape in DBNet

* feat: Added proper kernel init in CRNN

* feat: Updated DB & CRNN checkpoints

* fix: Fixed URL

* feat: Initialized kernels of SAR

* feat: Initialized kernels of SAR

* feat: Updated SAR checkpoints
  • Loading branch information
fg-mindee authored Mar 17, 2021
1 parent 377d624 commit 19a8667
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
9 changes: 6 additions & 3 deletions doctr/models/detection/differentiable_binarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
'fpn_channels': 128,
'input_shape': (1024, 1024, 3),
'post_processor': 'DBPostProcessor',
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/db_resnet50-df8d0071.zip',
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/db_resnet50-091c08a5.zip',
},
}

Expand Down Expand Up @@ -275,10 +275,13 @@ def __init__(
self.feat_extractor = feature_extractor

self.fpn = FeaturePyramidNetwork(channels=fpn_channels)
# Initialize kernels: pass symbolic inputs through the FPN to get its output shape, used as input_shape for the heads
_inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape]
output_shape = tuple(self.fpn(_inputs).shape)

self.probability_head = keras.Sequential(
[
*conv_sequence(64, 'relu', True, kernel_size=3),
*conv_sequence(64, 'relu', True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
layers.BatchNormalization(),
layers.Activation('relu'),
Expand All @@ -288,7 +291,7 @@ def __init__(
)
self.threshold_head = keras.Sequential(
[
*conv_sequence(64, 'relu', True, kernel_size=3),
*conv_sequence(64, 'relu', True, kernel_size=3, input_shape=output_shape[1:]),
layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'),
layers.BatchNormalization(),
layers.Activation('relu'),
Expand Down
8 changes: 6 additions & 2 deletions doctr/models/recognition/crnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
'post_processor': 'CTCPostProcessor',
'vocab': ('3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|za1ù8,OG€P-'
'kçHëÀÂ2É/ûIJ\'j(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l'),
'url': 'https://github.com/mindee/doctr/releases/download/v0.1-models/crnn_vgg16bn-b37097a8.zip',
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/crnn_vgg16_bn-f29aa0aa.zip',
},
'crnn_resnet31': {
'mean': (0.694, 0.695, 0.693),
Expand Down Expand Up @@ -57,9 +57,13 @@ def __init__(
) -> None:
super().__init__(cfg=cfg)
self.feat_extractor = feature_extractor

# Initialize kernels: unpack the feature map dimensions so the first LSTM can be given an explicit input_shape
h, w, c = self.feat_extractor.output_shape[1:]

self.decoder = Sequential(
[
layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True, input_shape=(w, h * c))),
layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)),
layers.Dense(units=vocab_size + 1)
]
Expand Down
14 changes: 8 additions & 6 deletions doctr/models/recognition/sar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
'post_processor': 'SARPostProcessor',
'vocab': ('3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|za1ù8,OG€P-'
'kçHëÀÂ2É/ûIJ\'j(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l'),
'url': 'https://github.com/mindee/doctr/releases/download/v0.1-models/sar_vgg16bn-0d7e2c26.zip',
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/sar_vgg16_bn-e0be6df9.zip',
},
'sar_resnet31': {
'mean': (.5, .5, .5),
Expand All @@ -35,7 +35,7 @@
'post_processor': 'SARPostProcessor',
'vocab': ('3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|za1ù8,OG€P-'
'kçHëÀÂ2É/ûIJ\'j(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l'),
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/sar_resnet31-ea202587.zip',
'url': 'https://github.com/mindee/doctr/releases/download/v0.1.0/sar_resnet31-sha4182d.zip',
},
}

Expand Down Expand Up @@ -114,13 +114,13 @@ def __init__(

super().__init__()
self.vocab_size = vocab_size
self.embed = layers.Dense(embedding_units, use_bias=False)
self.attention_module = AttentionModule(attention_units)
self.output_dense = layers.Dense(vocab_size + 1, use_bias=True)
self.max_length = max_length
self.embed = layers.Dense(embedding_units, use_bias=False, input_shape=(self.vocab_size + 1,))
self.lstm_decoder = layers.StackedRNNCells(
[layers.LSTMCell(rnn_units, dtype=tf.float32, implementation=1) for _ in range(num_decoder_layers)]
)
self.attention_module = AttentionModule(attention_units)
self.output_dense = layers.Dense(vocab_size + 1, use_bias=True, input_shape=(2 * rnn_units,))
self.max_length = max_length

def call(
self,
Expand Down Expand Up @@ -211,6 +211,8 @@ def __init__(
layers.LSTM(units=rnn_units, return_sequences=False)
]
)
# Initialize the kernels; the build shape omits the height axis (output_shape[2:]), presumably because reduce_max collapses it before the encoder
self.encoder.build(input_shape=(None,) + self.feat_extractor.output_shape[2:])

self.decoder = SARDecoder(
rnn_units, max_length, vocab_size, embedding_units, attention_units, num_decoders,
Expand Down

0 comments on commit 19a8667

Please sign in to comment.