Replace unidecode with anyascii (#1509)
jonatankawalek authored Mar 12, 2024
1 parent 60d4005 commit 709404e
Showing 5 changed files with 37 additions and 36 deletions.
2 changes: 1 addition & 1 deletion .conda/meta.yaml
@@ -37,7 +37,7 @@ requirements:
     - weasyprint >=55.0
     - defusedxml >=0.7.0
     - mplcursors >=0.3
-    - unidecode >=1.0.0
+    - anyascii >=0.3.2
     - tqdm >=4.30.0

 test:
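Note (not part of the diff): at the call-site level this is a drop-in swap, since both packages expose a single str-to-str transliteration function. A quick sketch, assuming anyascii >=0.3.2 is installed; the "€" mapping is the one the metrics tests below rely on:

from anyascii import anyascii

print(anyascii("éléphant"))  # "elephant"
print(anyascii("€"))  # "EUR"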
32 changes: 16 additions & 16 deletions doctr/utils/metrics.py
@@ -7,8 +7,8 @@

 import cv2
 import numpy as np
+from anyascii import anyascii
 from scipy.optimize import linear_sum_assignment
-from unidecode import unidecode

 __all__ = [
     "TextMatch",
@@ -34,16 +34,16 @@ def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]:
     Returns:
     -------
         a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their
-        unidecode counterparts and their lower-case unidecode counterparts match
+        anyascii counterparts and their lower-case anyascii counterparts match
     """
     raw_match = word1 == word2
     caseless_match = word1.lower() == word2.lower()
-    unidecode_match = unidecode(word1) == unidecode(word2)
+    anyascii_match = anyascii(word1) == anyascii(word2)

     # Warning: the order is important here otherwise the pair ("EUR", "€") cannot be matched
-    unicase_match = unidecode(word1).lower() == unidecode(word2).lower()
+    unicase_match = anyascii(word1).lower() == anyascii(word2).lower()

-    return raw_match, caseless_match, unidecode_match, unicase_match
+    return raw_match, caseless_match, anyascii_match, unicase_match


 class TextMatch:
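For illustration (not part of the diff), here is what the four booleans look like for the pair called out in the warning comment above, assuming the helper is importable directly and anyascii maps "€" to "EUR":

from doctr.utils.metrics import string_match

print(string_match("EUR", "€"))  # (False, False, True, True)
print(string_match("True", "true"))  # (False, True, False, True)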
@@ -94,10 +94,10 @@ def update(
             raise AssertionError("prediction size does not match with ground-truth labels size")

         for gt_word, pred_word in zip(gt, pred):
-            _raw, _caseless, _unidecode, _unicase = string_match(gt_word, pred_word)
+            _raw, _caseless, _anyascii, _unicase = string_match(gt_word, pred_word)
             self.raw += int(_raw)
             self.caseless += int(_caseless)
-            self.unidecode += int(_unidecode)
+            self.anyascii += int(_anyascii)
             self.unicase += int(_unicase)

         self.total += len(gt)
@@ -107,23 +107,23 @@ def summary(self) -> Dict[str, float]:
        Returns
        -------
-            a dictionary with the exact match score for the raw data, its lower-case counterpart, its unidecode
-            counterpart and its lower-case unidecode counterpart
+            a dictionary with the exact match score for the raw data, its lower-case counterpart, its anyascii
+            counterpart and its lower-case anyascii counterpart
        """
        if self.total == 0:
            raise AssertionError("you need to update the metric before getting the summary")

        return dict(
            raw=self.raw / self.total,
            caseless=self.caseless / self.total,
-            unidecode=self.unidecode / self.total,
+            anyascii=self.anyascii / self.total,
            unicase=self.unicase / self.total,
        )

    def reset(self) -> None:
        self.raw = 0
        self.caseless = 0
-        self.unidecode = 0
+        self.anyascii = 0
        self.unicase = 0
        self.total = 0

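A usage sketch for the renamed attribute and summary key (the values follow from the definitions above):

from doctr.utils.metrics import TextMatch

metric = TextMatch()
metric.update(["éléphant", "EUR"], ["elephant", "€"])
print(metric.summary())  # {'raw': 0.0, 'caseless': 0.0, 'anyascii': 1.0, 'unicase': 1.0}
metric.reset()  # zeroes raw, caseless, anyascii, unicase and total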
@@ -544,10 +544,10 @@ def update(
            is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh
            # String comparison
            for gt_idx, pred_idx in zip(gt_indices[is_kept], pred_indices[is_kept]):
-                _raw, _caseless, _unidecode, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
+                _raw, _caseless, _anyascii, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx])
                self.raw_matches += int(_raw)
                self.caseless_matches += int(_caseless)
-                self.unidecode_matches += int(_unidecode)
+                self.anyascii_matches += int(_anyascii)
                self.unicase_matches += int(_unicase)

        self.num_gts += gt_boxes.shape[0]
@@ -564,15 +564,15 @@ def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], float]:
        recall = dict(
            raw=self.raw_matches / self.num_gts if self.num_gts > 0 else None,
            caseless=self.caseless_matches / self.num_gts if self.num_gts > 0 else None,
-            unidecode=self.unidecode_matches / self.num_gts if self.num_gts > 0 else None,
+            anyascii=self.anyascii_matches / self.num_gts if self.num_gts > 0 else None,
            unicase=self.unicase_matches / self.num_gts if self.num_gts > 0 else None,
        )

        # Precision
        precision = dict(
            raw=self.raw_matches / self.num_preds if self.num_preds > 0 else None,
            caseless=self.caseless_matches / self.num_preds if self.num_preds > 0 else None,
-            unidecode=self.unidecode_matches / self.num_preds if self.num_preds > 0 else None,
+            anyascii=self.anyascii_matches / self.num_preds if self.num_preds > 0 else None,
            unicase=self.unicase_matches / self.num_preds if self.num_preds > 0 else None,
        )

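The "if ... else None" guards above return None rather than a misleading 0.0 when there is nothing to score. A minimal illustration of the pattern, outside the class:

num_gts = 0  # no ground truths seen yet
anyascii_matches = 0
recall = anyascii_matches / num_gts if num_gts > 0 else None
print(recall)  # None, i.e. recall is undefined rather than zero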
@@ -587,7 +587,7 @@ def reset(self) -> None:
        self.tot_iou = 0.0
        self.raw_matches = 0
        self.caseless_matches = 0
-        self.unidecode_matches = 0
+        self.anyascii_matches = 0
        self.unicase_matches = 0


10 changes: 5 additions & 5 deletions doctr/utils/visualization.py
@@ -11,9 +11,9 @@
 import matplotlib.pyplot as plt
 import mplcursors
 import numpy as np
+from anyascii import anyascii
 from matplotlib.figure import Figure
 from PIL import Image, ImageDraw
-from unidecode import unidecode

 from .common_types import BoundingBox, Polygon4P
 from .fonts import get_font
@@ -327,8 +327,8 @@ def synthesize_page(
        try:
            d.text((0, 0), word["value"], font=font, fill=(0, 0, 0))
        except UnicodeEncodeError:
-            # When character cannot be encoded, use its unidecode version
-            d.text((0, 0), unidecode(word["value"]), font=font, fill=(0, 0, 0))
+            # When character cannot be encoded, use its anyascii version
+            d.text((0, 0), anyascii(word["value"]), font=font, fill=(0, 0, 0))

        # Colorize if draw_proba
        if draw_proba:
@@ -458,8 +458,8 @@ def synthesize_kie_page(
        try:
            d.text((0, 0), prediction["value"], font=font, fill=(0, 0, 0))
        except UnicodeEncodeError:
-            # When character cannot be encoded, use its unidecode version
-            d.text((0, 0), unidecode(prediction["value"]), font=font, fill=(0, 0, 0))
+            # When character cannot be encoded, use its anyascii version
+            d.text((0, 0), anyascii(prediction["value"]), font=font, fill=(0, 0, 0))

        # Colorize if draw_proba
        if draw_proba:
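The same try/except fallback appears in both synthesize functions. A self-contained sketch of the pattern using the standard Pillow API (the word value here is hypothetical, not from the diff):

from anyascii import anyascii
from PIL import Image, ImageDraw, ImageFont

img = Image.new("RGB", (200, 40), (255, 255, 255))
d = ImageDraw.Draw(img)
font = ImageFont.load_default()
value = "привет"  # hypothetical non-Latin word value
try:
    d.text((0, 0), value, font=font, fill=(0, 0, 0))
except UnicodeEncodeError:
    # Some fonts cannot encode these glyphs; fall back to the ASCII transliteration
    d.text((0, 0), anyascii(value), font=font, fill=(0, 0, 0))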
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -50,7 +50,7 @@ dependencies = [
     "Pillow>=9.2.0",
     "defusedxml>=0.7.0",
     "mplcursors>=0.3",
-    "unidecode>=1.0.0",
+    "anyascii>=0.3.2",
     "tqdm>=4.30.0",
 ]

@@ -145,6 +145,7 @@ implicit_reexport = false

 [[tool.mypy.overrides]]
 module = [
+    "anyascii.*",
     "tensorflow.*",
     "torchvision.*",
     "PIL.*",
26 changes: 13 additions & 13 deletions tests/common/test_utils_metrics.py
@@ -5,13 +5,13 @@


 @pytest.mark.parametrize(
-    "gt, pred, raw, caseless, unidecode, unicase",
+    "gt, pred, raw, caseless, anyascii, unicase",
     [
         [["grass", "56", "True", "EUR"], ["grass", "56", "true", "€"], 0.5, 0.75, 0.75, 1],
         [["éléphant", "ça"], ["elephant", "ca"], 0, 0, 1, 1],
     ],
 )
-def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
+def test_text_match(gt, pred, raw, caseless, anyascii, unicase):
     metric = metrics.TextMatch()
     with pytest.raises(AssertionError):
         metric.summary()
@@ -20,10 +20,10 @@ def test_text_match(gt, pred, raw, caseless, unidecode, unicase):
        metric.update(["a", "b"], ["c"])

    metric.update(gt, pred)
-    assert metric.summary() == dict(raw=raw, caseless=caseless, unidecode=unidecode, unicase=unicase)
+    assert metric.summary() == dict(raw=raw, caseless=caseless, anyascii=anyascii, unicase=unicase)

    metric.reset()
-    assert metric.raw == metric.caseless == metric.unidecode == metric.unicase == metric.total == 0
+    assert metric.raw == metric.caseless == metric.anyascii == metric.unicase == metric.total == 0


 @pytest.mark.parametrize(
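To see why the second parametrized case above expects (0, 0, 1, 1), one can check the transliterations directly (a sanity check, not part of the test file):

from anyascii import anyascii

assert anyascii("éléphant") == "elephant"
assert anyascii("ça") == "ca"
# Neither word matches raw or caseless, so only the anyascii and unicase scores are 1.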
@@ -208,8 +208,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou):
            [[[0, 0, 0.5, 0.5]]],
            [["elephant"]],
            0.5,
-            {"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
-            {"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1},
+            {"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
+            {"raw": 1, "caseless": 1, "anyascii": 1, "unicase": 1},
            1,
        ],
        [  # Bad match
@@ -218,8 +218,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou):
            [[[0, 0, 0.5, 0.5]]],
            [["elephant"]],
            0.5,
-            {"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
-            {"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0},
+            {"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
+            {"raw": 0, "caseless": 0, "anyascii": 0, "unicase": 0},
            1,
        ],
        [  # Good match
@@ -228,8 +228,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou):
            [[[0, 0, 0.5, 0.5], [0.6, 0.6, 0.7, 0.7]]],
            [["€", "e"]],
            0.2,
-            {"raw": 0, "caseless": 0, "unidecode": 1, "unicase": 1},
-            {"raw": 0, "caseless": 0, "unidecode": 0.5, "unicase": 0.5},
+            {"raw": 0, "caseless": 0, "anyascii": 1, "unicase": 1},
+            {"raw": 0, "caseless": 0, "anyascii": 0.5, "unicase": 0.5},
            0.13,
        ],
        [  # No preds on 2nd sample
@@ -238,8 +238,8 @@ def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou):
            [[[0, 0, 0.5, 0.5]], None],
            [["elephant"], []],
            0.5,
-            {"raw": 0, "caseless": 0.5, "unidecode": 0, "unicase": 0.5},
-            {"raw": 0, "caseless": 1, "unidecode": 0, "unicase": 1},
+            {"raw": 0, "caseless": 0.5, "anyascii": 0, "unicase": 0.5},
+            {"raw": 0, "caseless": 1, "anyascii": 0, "unicase": 1},
            1,
        ],
    ],
],
Expand All @@ -256,7 +256,7 @@ def test_ocr_metric(gt_boxes, gt_words, pred_boxes, pred_words, iou_thresh, reca
assert _mean_iou == mean_iou
metric.reset()
assert metric.num_gts == metric.num_preds == metric.tot_iou == 0
assert metric.raw_matches == metric.caseless_matches == metric.unidecode_matches == metric.unicase_matches == 0
assert metric.raw_matches == metric.caseless_matches == metric.anyascii_matches == metric.unicase_matches == 0
# Shape check
with pytest.raises(AssertionError):
metric.update(
