diff --git a/doctr/models/recognition/parseq/pytorch.py b/doctr/models/recognition/parseq/pytorch.py index c3a21cd578..8fff062da9 100644 --- a/doctr/models/recognition/parseq/pytorch.py +++ b/doctr/models/recognition/parseq/pytorch.py @@ -282,7 +282,8 @@ def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = ys[:, i + 1] = pos_prob.squeeze().argmax(-1) # Stop decoding if all sequences have reached the EOS token - if max_len is None and (ys == self.vocab_size).any(dim=-1).all(): + # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export + if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all(): break logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1) diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index c1904c082e..1365a6ac12 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -288,10 +288,11 @@ def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = No ) # Stop decoding if all sequences have reached the EOS token - # We need to check it on True to be compatible with ONNX + # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export if ( - max_len is None - and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True + not self.exportable + and max_len is None + and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) ): break diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py index 5f8ef1a44f..e4df34060b 100644 --- a/tests/pytorch/test_models_recognition_pt.py +++ b/tests/pytorch/test_models_recognition_pt.py @@ -148,7 +148,7 @@ def test_models_onnx_export(arch_name, input_shape): ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()}) assert isinstance(ort_outs, list) and len(ort_outs) == 1 - assert ort_outs[0].shape[0] == batch_size + assert ort_outs[0].shape == pt_logits.shape # Check that the output is close to the PyTorch output - only warn if not close try: assert np.allclose(pt_logits, ort_outs[0], atol=1e-4) diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py index 3e0ba5b8f0..b58272d1de 100644 --- a/tests/tensorflow/test_models_recognition_tf.py +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -226,7 +226,7 @@ def test_models_onnx_export(arch_name, input_shape): ort_outs = ort_session.run(output, {"input": np_dummy_input}) assert isinstance(ort_outs, list) and len(ort_outs) == 1 - assert ort_outs[0].shape[0] == batch_size + assert ort_outs[0].shape == tf_logits.shape # Check that the output is close to the TensorFlow output - only warn if not close try: assert np.allclose(tf_logits, ort_outs[0], atol=1e-4)