( # For frame 2, tokens with index 1 and 2 are the most probable (the model hesitates between them); simulate a feature vector of size (n_frame=4, batch_size=1, n_tokens=10)
1) * length.item()) + ) + for hypothesis, length in zip(hypotheses, batch_sizes) + ] return out diff --git a/laia/scripts/htr/decode_ctc.py b/laia/scripts/htr/decode_ctc.py index 9d695e26..62a537e7 100755 --- a/laia/scripts/htr/decode_ctc.py +++ b/laia/scripts/htr/decode_ctc.py @@ -73,8 +73,7 @@ def run( sil_token=decode.input_space, temperature=decode.temperature, ) - # confidence scores are not supported when using a language model - decode.print_line_confidence_scores = False + # word-level confidence scores are not supported when using a language model decode.print_word_confidence_scores = False else: diff --git a/tests/callbacks/decode_test.py b/tests/callbacks/decode_test.py index c2750970..4990c0b6 100644 --- a/tests/callbacks/decode_test.py +++ b/tests/callbacks/decode_test.py @@ -15,6 +15,7 @@ def __call__(self, batch_y_hat): "prob-htr-char": [ [0.9, 0.9, 0.9, 0.9, 0.9, 0.9] for _ in range(batch_size) ], + "prob-htr": [0.9 for _ in range(batch_size)], } diff --git a/tests/decoders/ctc_language_decoder_test.py b/tests/decoders/ctc_language_decoder_test.py index c0ac159a..2da2439f 100644 --- a/tests/decoders/ctc_language_decoder_test.py +++ b/tests/decoders/ctc_language_decoder_test.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import pytest import torch @@ -63,7 +64,7 @@ @pytest.mark.parametrize( - ["input_tensor", "lm_weight", "expected_string"], + ["input_tensor", "lm_weight", "expected_string", "expected_confidence"], [ ( # Simulate a feature vector of size (n_frame=4, batch_size=1, n_tokens=10) @@ -78,6 +79,7 @@ ), 0, "tast", + 0.38, ), ( # For frame 2, tokens with index 1 and 2 are the most probable tokens (the model a feature vector of size (n_frame=4, batch_size=1, n_tokens=10) @@ -91,10 +93,13 @@ ), 1, "test", + 0.40, ), ], ) -def test_lm_decoding_weight(tmpdir, input_tensor, lm_weight, expected_string): +def test_lm_decoding_weight( + tmpdir, input_tensor, lm_weight, expected_string, expected_confidence +): tokens_path = 
predicted_string = predicted_string.replace("<space>", " ").strip()