Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLDR-538 tesseract postprocessing #388

Merged
merged 4 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ uvicorn>=0.18.0,<=0.23.2
wget==3.2
xgbfir==0.3.1
xgboost>=1.1.1,<1.2.0
xlrd==1.2.0
xlrd==1.2.0
NastyBoget marked this conversation as resolved.
Show resolved Hide resolved
textblob==0.17.1
359 changes: 359 additions & 0 deletions resources/benchmarks/tesseract_benchmark_sage-correction.txt

Large diffs are not rendered by default.

259 changes: 259 additions & 0 deletions resources/benchmarks/tesseract_benchmark_with_correction.txt
NastyBoget marked this conversation as resolved.
Show resolved Hide resolved

Large diffs are not rendered by default.

131 changes: 90 additions & 41 deletions scripts/calc_tesseract_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import re
import time
import zipfile
from tempfile import TemporaryDirectory
from typing import Dict, List, Tuple

import cv2
Expand All @@ -11,6 +11,14 @@
from texttable import Texttable

from dedoc.config import get_config
from scripts.ocr_correction import correction, init_correction_step
from scripts.text_blob_correction import TextBlobCorrector

# Identifiers of the optional correction step applied to the OCR output.
# The selected value is also appended to the name of the benchmark report file.
WITHOUT_CORRECTION = ""
SAGE_CORRECTION = "_sage-correction"
TEXT_BLOB_CORRECTION = "_textblob-correction"

# Correction step that is applied to the Tesseract output before accuracy is calculated.
USE_CORRECTION_OCR = TEXT_BLOB_CORRECTION


def _call_tesseract(image: np.ndarray, language: str, psm: int = 3) -> str:
Expand Down Expand Up @@ -169,9 +177,28 @@ def __create_statistic_tables(statistics: dict, accuracy_values: List) -> Tuple[
return table_common, table_accuracy_per_image


def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -> Tuple[Texttable, Texttable]:
def calculate_accuracy_script(tmp_gt_path: str, tmp_prediction_path: str, accuracy_path: str) -> None:
    """
    Run the `accuracy` executable located next to this script and append its report to a file.

    The executable is built for Ubuntu from source: https://github.com/eddieantonio/ocreval

    :param tmp_gt_path: path to the text file with the ground-truth text
    :param tmp_prediction_path: path to the text file with the recognized (or corrected) text
    :param accuracy_path: path to the report file, the standard output of the tool is appended to it
    """
    import subprocess  # imported locally so that the function does not depend on the module import block
    import sys

    accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "accuracy"))

    # The arguments are passed as a list without a shell, so paths with spaces or shell
    # metacharacters (e.g. coming from image names in the archive) are handled correctly.
    # Append mode reproduces the former `>>` redirection.
    with open(accuracy_path, "ab") as accuracy_file:
        try:
            subprocess.run([accuracy_script_path, tmp_gt_path, tmp_prediction_path], stdout=accuracy_file, check=False)
        except OSError as error:
            # As with the former shell call, a missing or non-executable tool is reported and does not raise
            print(f"Cannot run {accuracy_script_path}: {error}", file=sys.stderr)


def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, cache_dir: str) -> Tuple[Texttable, Texttable]:
statistics = {}
accuracy_values = []
correction_times = []

result_dir = os.path.join(cache_dir, "result_ocr")
os.makedirs(result_dir, exist_ok=True)

corrector, corrected_path = None, None
if USE_CORRECTION_OCR == SAGE_CORRECTION:
corrector, corrected_path = init_correction_step(cache_dir)
elif USE_CORRECTION_OCR == TEXT_BLOB_CORRECTION:
corrector = TextBlobCorrector()
corrected_path = os.path.join(cache_dir, "result_corrected")
os.makedirs(corrected_path, exist_ok=True)

with zipfile.ZipFile(benchmark_data_path, "r") as arch_file:
names_dirs = [member.filename for member in arch_file.infolist() if member.file_size > 0]
Expand All @@ -191,41 +218,61 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -
gt_path = os.path.join(base_zip, dataset_name, "gts", f"{base_name}.txt")
imgs_path = os.path.join(base_zip, dataset_name, "imgs", img_name)
accuracy_path = os.path.join(cache_dir_accuracy, f"{dataset_name}_{base_name}_accuracy.txt")

with TemporaryDirectory() as tmpdir:
tmp_gt_path = os.path.join(tmpdir, "tmp_gt.txt")
tmp_ocr_path = os.path.join(tmpdir, "tmp_ocr.txt")

try:
with arch_file.open(gt_path) as gt_file, open(tmp_gt_path, "wb") as tmp_gt_file, open(tmp_ocr_path, "w") as tmp_ocr_file:

gt_text = gt_file.read().decode("utf-8")
word_cnt = len(gt_text.split())

tmp_gt_file.write(gt_text.encode()) # extraction gt from zip
tmp_gt_file.flush()

arch_file.extract(imgs_path, tmpdir)
image = cv2.imread(tmpdir + "/" + imgs_path)

# call ocr
psm = 6 if dataset_name == "english-words" else 4
text = _call_tesseract(image, "rus+eng", psm=psm)
tmp_ocr_file.write(text)
tmp_ocr_file.flush()

# calculation accuracy build for Ubuntu from source https://github.com/eddieantonio/ocreval
accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "accuracy"))
command = f"{accuracy_script_path} {tmp_gt_path} {tmp_ocr_path} >> {accuracy_path}"
os.system(command)

statistics = _update_statistics_by_dataset(statistics, dataset_name, accuracy_path, word_cnt)
accuracy_values.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])

except Exception as ex:
print(ex)
print("If you have problems with libutf8proc.so.2, try the command: `apt install -y libutf8proc-dev`")

if os.path.exists(accuracy_path):
os.remove(accuracy_path)

tmp_gt_path = os.path.join(result_dir, f"{img_name}_gt.txt")
tmp_ocr_path = os.path.join(result_dir, f"{img_name}_ocr.txt")

try:
with arch_file.open(gt_path) as gt_file, open(tmp_gt_path, "wb") as tmp_gt_file, open(tmp_ocr_path, "w") as tmp_ocr_file:

gt_text = gt_file.read().decode("utf-8")
word_cnt = len(gt_text.split())

tmp_gt_file.write(gt_text.encode()) # extraction gt from zip
tmp_gt_file.close()

arch_file.extract(imgs_path, result_dir)
image = cv2.imread(result_dir + "/" + imgs_path)

# call ocr
psm = 6 if dataset_name == "english-words" else 4
text = _call_tesseract(image, "rus+eng", psm=psm)
tmp_ocr_file.write(text)
tmp_ocr_file.close()

# call correction step
time_b = time.time()
if USE_CORRECTION_OCR == SAGE_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = correction(corrector, text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
elif USE_CORRECTION_OCR == TEXT_BLOB_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = corrector.correct(text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
else:
calculate_accuracy_script(tmp_gt_path, tmp_ocr_path, accuracy_path)

statistics = _update_statistics_by_dataset(statistics, dataset_name, accuracy_path, word_cnt)
accuracy_values.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])

except Exception as ex:
print(ex)
print("If you have problems with libutf8proc.so.2, try the command: `apt install -y libutf8proc-dev`")

print(f"Time mean correction ocr = {np.array(correction_times).mean()}")
table_common, table_accuracy_per_image = __create_statistic_tables(statistics, accuracy_values)
return table_common, table_accuracy_per_image

Expand All @@ -240,18 +287,20 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -

benchmark_data_path = os.path.join(cache_dir, f"{base_zip}.zip")
if not os.path.isfile(benchmark_data_path):
wget.download("https://at.ispras.ru/owncloud/index.php/s/HqKt53BWmR8nCVG/download", benchmark_data_path)
wget.download("https://at.ispras.ru/owncloud/index.php/s/wMyKioKInYITpYT", benchmark_data_path)
print(f"Benchmark data downloaded to {benchmark_data_path}")
else:
print(f"Use cached benchmark data from {benchmark_data_path}")
assert os.path.isfile(benchmark_data_path)

table_common, table_accuracy_per_image = __calculate_ocr_reports(cache_dir_accuracy, benchmark_data_path)
table_common, table_accuracy_per_image = __calculate_ocr_reports(cache_dir_accuracy, benchmark_data_path, cache_dir)

table_errors = __get_summary_symbol_error(path_reports=cache_dir_accuracy)

with open(os.path.join(output_dir, "tesseract_benchmark.txt"), "w") as res_file:
res_file.write(f"Tesseract version is {pytesseract.get_tesseract_version()}\nTable 1 - Accuracy for each file\n")
with open(os.path.join(output_dir, f"tesseract_benchmark{USE_CORRECTION_OCR}.txt"), "w") as res_file:
res_file.write(f"Tesseract version is {pytesseract.get_tesseract_version()}\n")
res_file.write(f"Correction step: {USE_CORRECTION_OCR}\n")
res_file.write(f"\nTable 1 - Accuracy for each file\n")
res_file.write(table_accuracy_per_image.draw())
res_file.write(f"\n\nTable 2 - AVG by each type of symbols:\n")
res_file.write(table_common.draw())
Expand Down
43 changes: 43 additions & 0 deletions scripts/ocr_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
from typing import Tuple

import torch
from sage.spelling_correction.corrector import Corrector
from sage.spelling_correction import AvailableCorrectors
from sage.spelling_correction import RuM2M100ModelForSpellingCorrection

'''
Install sage library (for ocr correction step):
git clone https://github.com/ai-forever/sage.git
cd sage
pip install .
pip install -r requirements.txt

Note: sage uses about 5.2 GB of GPU memory.
'''
# When True, the correction model is moved to the "cuda:0" device if CUDA is available.
USE_GPU = True


def correction(model: Corrector, ocr_text: str) -> str:
    """
    Correct the recognized text line by line.

    Each line (split by "\\n") is passed to `model.correct` separately, the first item of
    every result is taken and the lines are joined back with "\\n".

    :param model: corrector object with the `correct` method
    :param ocr_text: text recognized by OCR
    :return: corrected text with the same number of lines
    """
    return "\n".join(model.correct(row)[0] for row in ocr_text.split("\n"))


def init_correction_step(cache_dir: str) -> Tuple[Corrector, str]:
    """
    Prepare everything needed for the sage correction step.

    Creates the "result_corrected" directory inside `cache_dir`, loads the pretrained
    m2m100_1B spelling corrector and moves it to "cuda:0" when CUDA is available and USE_GPU is set.

    :param cache_dir: directory where the directory for corrected texts is created
    :return: the loaded corrector and the path to the directory for corrected texts
    """
    output_dir = os.path.join(cache_dir, "result_corrected")
    os.makedirs(output_dir, exist_ok=True)

    # 4.49 Gb model (pytorch_model.bin)
    model_name = AvailableCorrectors.m2m100_1B.value
    spelling_model = RuM2M100ModelForSpellingCorrection.from_pretrained(model_name)

    run_on_gpu = torch.cuda.is_available() and USE_GPU
    if not run_on_gpu:
        print("use CPU")
        return spelling_model, output_dir

    spelling_model.model.to(torch.device("cuda:0"))
    print("use CUDA")
    return spelling_model, output_dir


9 changes: 9 additions & 0 deletions scripts/text_blob_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from textblob import TextBlob


class TextBlobCorrector:
    """
    Text corrector based on the TextBlob library.
    """

    def __init__(self) -> None:
        """
        The corrector has no state and needs no configuration.
        """

    def correct(self, text: str) -> str:
        """
        Correct the given text with `TextBlob.correct`.

        :param text: text to correct
        :return: corrected text converted to a string
        """
        blob = TextBlob(text)
        fixed_blob = blob.correct()
        return str(fixed_blob)
Loading