Skip to content

Commit

Permalink
[Fix] sar_resnet31 TF + PT (#1513)
Browse files (browse the repository at this point in the history)
  • Loading branch information
felixdittrich92 authored Mar 15, 2024
1 parent b39a7ef commit c2b197d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 14 deletions.
4 changes: 4 additions & 0 deletions docs/source/using_doctr/using_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| TensorFlow | db_resnet50 + master | 72.73 | 74.00 | 84.13 | 75.05 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| TensorFlow | db_resnet50 + sar_resnet31 | 73.23 | 74.51 | 85.34 | 76.03 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| TensorFlow | db_resnet50 + vitstr_small | 68.57 | 69.77 | 78.24 | 69.51 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| TensorFlow | db_resnet50 + vitstr_base | 70.96 | 72.20 | 82.10 | 72.94 |
Expand All @@ -242,6 +244,8 @@ For a comprehensive comparison, we have compiled a detailed benchmark on publicl
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| PyTorch | db_resnet50 + master | 73.90 | 76.66 | 85.84 | 80.07 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| PyTorch | db_resnet50 + sar_resnet31 | 73.58 | 76.33 | 85.64 | 79.88 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| PyTorch | db_resnet50 + vitstr_small | 73.06 | 75.79 | 85.95 | 80.17 |
+----------------+----------------------------------------------------------+------------+---------------+------------+---------------+
| PyTorch | db_resnet50 + vitstr_base | 73.70 | 76.46 | 85.76 | 79.99 |
Expand Down
11 changes: 6 additions & 5 deletions doctr/models/recognition/sar/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,25 +125,26 @@ def forward(
if t == 0:
# step to init the first states of the LSTMCell
hidden_state_init = cell_state_init = torch.zeros(
features.size(0), features.size(1), device=features.device
features.size(0), features.size(1), device=features.device, dtype=features.dtype
)
hidden_state, cell_state = hidden_state_init, cell_state_init
prev_symbol = holistic
elif t == 1:
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
# (N, vocab_size + 1) --> (N, embedding_units)
prev_symbol = torch.zeros(features.size(0), self.vocab_size + 1, device=features.device)
prev_symbol = torch.zeros(
features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
)
prev_symbol = self.embed(prev_symbol)
else:
if gt is not None:
if gt is not None and self.training:
# (N, embedding_units) -2 because of <bos> and <eos> (same)
prev_symbol = self.embed(gt_embedding[:, t - 2])
else:
# -1 to start at timestep where prev_symbol was initialized
index = logits_list[t - 1].argmax(-1)
# update prev_symbol with ones at the index of the previous logit vector
# (N, embedding_units)
prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
prev_symbol = self.embed(self.embed_tgt(index))

# (N, C), (N, C) take the last hidden state and cell state from current timestep
hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
Expand Down
12 changes: 3 additions & 9 deletions doctr/models/recognition/sar/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,23 +177,17 @@ def call(
elif t == 1:
# step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
# (N, vocab_size + 1) --> (N, embedding_units)
prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1])
prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
prev_symbol = self.embed(prev_symbol, **kwargs)
else:
if gt is not None:
if gt is not None and kwargs.get("training", False):
# (N, embedding_units) -2 because of <bos> and <eos> (same)
prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
else:
# -1 to start at timestep where prev_symbol was initialized
index = tf.argmax(logits_list[t - 1], axis=-1)
# update prev_symbol with ones at the index of the previous logit vector
# (N, embedding_units)
index = tf.ones_like(index)
prev_symbol = tf.scatter_nd(
tf.expand_dims(index, axis=1),
prev_symbol,
tf.constant([features.shape[0], features.shape[-1]], dtype=tf.int64),
)
prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)

# (N, C), (N, C) take the last hidden state and cell state from current timestep
_, states = self.lstm_cells(prev_symbol, states, **kwargs)
Expand Down

0 comments on commit c2b197d

Please sign in to comment.