Skip to content

Commit

Permalink
Merge pull request #70 from jpuigcerver/confidence-score-language-model
Browse files Browse the repository at this point in the history
Normalize score after LM decoding
  • Loading branch information
yschneider-sinneria authored Oct 9, 2023
2 parents f0c4333 + eb04c71 commit 3d554b9
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 8 deletions.
10 changes: 7 additions & 3 deletions laia/callbacks/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,21 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
img_ids = pl_module.batch_id_fn(batch)
hyps = self.decoder(outputs)["hyp"]

if self.print_confidence_scores:
line_probs = (
self.decoder(outputs)["prob-htr"]
if self.print_line_confidence_scores
else []
)

if self.print_word_confidence_scores:
probs = self.decoder(outputs)["prob-htr-char"]
line_probs = [np.mean(prob) for prob in probs]
word_probs = [
compute_word_prob(self.syms, hyp, prob, self.input_space)
for hyp, prob in zip(hyps, probs)
]

else:
probs = []
line_probs = []
word_probs = []

for i, (img_id, hyp) in enumerate(zip(img_ids, hyps)):
Expand Down
1 change: 1 addition & 0 deletions laia/decoders/ctc_greedy_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __call__(

# Return char-based probability
out["prob-htr-char"] = [prob.tolist() for prob in probs]
out["prob-htr"] = [prob.mean().item() for prob in probs]
return out

@staticmethod
Expand Down
10 changes: 9 additions & 1 deletion laia/decoders/ctc_language_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
sil_token=sil_token,
)
self.temperature = temperature
self.language_model_weight = language_model_weight

def __call__(
self,
Expand Down Expand Up @@ -85,5 +86,12 @@ def __call__(
# Format the output
out = {}
out["hyp"] = [hypothesis[0].tokens.tolist() for hypothesis in hypotheses]
# you can get a log likelihood with hypothesis[0].score

# Normalize confidence score
out["prob-htr"] = [
np.exp(
hypothesis[0].score / ((self.language_model_weight + 1) * length.item())
)
for hypothesis, length in zip(hypotheses, batch_sizes)
]
return out
3 changes: 1 addition & 2 deletions laia/scripts/htr/decode_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def run(
sil_token=decode.input_space,
temperature=decode.temperature,
)
# confidence scores are not supported when using a language model
decode.print_line_confidence_scores = False
# word-level confidence scores are not supported when using a language model
decode.print_word_confidence_scores = False

else:
Expand Down
1 change: 1 addition & 0 deletions tests/callbacks/decode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __call__(self, batch_y_hat):
"prob-htr-char": [
[0.9, 0.9, 0.9, 0.9, 0.9, 0.9] for _ in range(batch_size)
],
"prob-htr": [0.9 for _ in range(batch_size)],
}


Expand Down
11 changes: 9 additions & 2 deletions tests/decoders/ctc_language_decoder_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path

import numpy as np
import pytest
import torch

Expand Down Expand Up @@ -63,7 +64,7 @@


@pytest.mark.parametrize(
["input_tensor", "lm_weight", "expected_string"],
["input_tensor", "lm_weight", "expected_string", "expected_confidence"],
[
(
# Simulate a feature vector of size (n_frame=4, batch_size=1, n_tokens=10)
Expand All @@ -78,6 +79,7 @@
),
0,
"tast",
0.38,
),
(
# For frame 2, tokens with index 1 and 2 are the most probable tokens
Expand All @@ -91,10 +93,13 @@
),
1,
"test",
0.40,
),
],
)
def test_lm_decoding_weight(tmpdir, input_tensor, lm_weight, expected_string):
def test_lm_decoding_weight(
tmpdir, input_tensor, lm_weight, expected_string, expected_confidence
):
tokens_path = Path(tmpdir) / "tokens.txt"
lexicon_path = Path(tmpdir) / "lexicon.txt"
arpa_path = Path(tmpdir) / "lm.arpa"
Expand All @@ -115,5 +120,7 @@ def test_lm_decoding_weight(tmpdir, input_tensor, lm_weight, expected_string):
predicted_tokens = decoder(input_tensor)["hyp"][0]
predicted_string = "".join([tokens_char[char] for char in predicted_tokens])
predicted_string = predicted_string.replace("<space>", " ").strip()
predicted_confidence = decoder(input_tensor)["prob-htr"][0]

assert predicted_string == expected_string
assert np.around(predicted_confidence, 2) == expected_confidence

0 comments on commit 3d554b9

Please sign in to comment.