Skip to content

Commit

Permalink
fix parseq onnx exported (#1585)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored May 6, 2024
1 parent 56db176 commit 1a7e49c
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
3 changes: 2 additions & 1 deletion doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,8 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] =
ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

# Stop decoding if all sequences have reached the EOS token
if max_len is None and (ys == self.vocab_size).any(dim=-1).all():
# NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
break

logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)
Expand Down
7 changes: 4 additions & 3 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,11 @@ def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = No
)

# Stop decoding if all sequences have reached the EOS token
# We need to check it on True to be compatible with ONNX
# NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
if (
max_len is None
and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True
not self.exportable
and max_len is None
and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
):
break

Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/test_models_recognition_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_models_onnx_export(arch_name, input_shape):
ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()})

assert isinstance(ort_outs, list) and len(ort_outs) == 1
assert ort_outs[0].shape[0] == batch_size
assert ort_outs[0].shape == pt_logits.shape
# Check that the output is close to the PyTorch output - only warn if not close
try:
assert np.allclose(pt_logits, ort_outs[0], atol=1e-4)
Expand Down
2 changes: 1 addition & 1 deletion tests/tensorflow/test_models_recognition_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_models_onnx_export(arch_name, input_shape):
ort_outs = ort_session.run(output, {"input": np_dummy_input})

assert isinstance(ort_outs, list) and len(ort_outs) == 1
assert ort_outs[0].shape[0] == batch_size
assert ort_outs[0].shape == tf_logits.shape
# Check that the output is close to the TensorFlow output - only warn if not close
try:
assert np.allclose(tf_logits, ort_outs[0], atol=1e-4)
Expand Down

0 comments on commit 1a7e49c

Please sign in to comment.