Skip to content

Commit

Permalink
TLDR-538 ocr correction scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
oksidgy authored and NastyBoget committed Jan 26, 2024
1 parent 7b20361 commit 1403898
Show file tree
Hide file tree
Showing 9 changed files with 771 additions and 41 deletions.
Empty file.
9 changes: 9 additions & 0 deletions dedoc/scripts/text_blob_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from textblob import TextBlob


class TextBlobCorrector:
    """Spelling corrector that delegates to ``textblob.TextBlob``.

    The class keeps no state, so it relies on the default constructor;
    ``TextBlobCorrector()`` works exactly as before.
    """

    def correct(self, text: str) -> str:
        """Return ``text`` with TextBlob spelling correction applied.

        :param text: text to correct (for example, raw OCR output)
        :return: the corrected text; empty or whitespace-only input is returned unchanged
        """
        if not text.strip():
            # Nothing to correct: skip the TextBlob call for blank input.
            return text
        return str(TextBlob(text).correct())
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ uvicorn>=0.18.0,<=0.23.2
wget==3.2
xgbfir==0.3.1
xgboost>=1.1.1,<1.2.0
xlrd==1.2.0
xlrd==1.2.0
textblob==0.17.1
359 changes: 359 additions & 0 deletions resources/benchmarks/tesseract_benchmark_sage-correction.txt

Large diffs are not rendered by default.

259 changes: 259 additions & 0 deletions resources/benchmarks/tesseract_benchmark_with_correction.txt

Large diffs are not rendered by default.

130 changes: 90 additions & 40 deletions scripts/calc_tesseract_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
import time
import zipfile
from tempfile import TemporaryDirectory
from typing import Dict, List, Tuple
Expand All @@ -11,6 +12,14 @@
from texttable import Texttable

from dedoc.config import get_config
from dedoc.scripts.text_blob_correction import TextBlobCorrector
from scripts.ocr_correction import correction, init_correction_step

WITHOUT_CORRECTION = ""
SAGE_CORRECTION = "_sage-correction"
TEXT_BLOB_CORRECTION = "_textblob-correction"

USE_CORRECTION_OCR = TEXT_BLOB_CORRECTION


def _call_tesseract(image: np.ndarray, language: str, psm: int = 3) -> str:
Expand Down Expand Up @@ -169,9 +178,28 @@ def __create_statistic_tables(statistics: dict, accuracy_values: List) -> Tuple[
return table_common, table_accuracy_per_image


def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -> Tuple[Texttable, Texttable]:
def calculate_accuracy_script(tmp_gt_path: str, tmp_prediction_path: str, accuracy_path: str) -> None:
    """Run the ``accuracy`` executable stored next to this script and append its output to a report file.

    The executable is the accuracy calculation tool built for Ubuntu from source
    (https://github.com/eddieantonio/ocreval).

    :param tmp_gt_path: path to the ground-truth text file
    :param tmp_prediction_path: path to the text file with the recognized (or corrected) text
    :param accuracy_path: report file; the standard output of the tool is appended to it
    :raises OSError: if the report file cannot be opened or the executable cannot be started
    """
    import subprocess  # local import so the function does not depend on the module import block

    accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "accuracy"))
    # Arguments are passed as a list without a shell, so paths that contain spaces or
    # shell metacharacters (file names come from the benchmark archive) are handled safely.
    # "ab" keeps the append semantics of the former `>>` redirection.
    with open(accuracy_path, "ab") as report_file:
        # The exit status is deliberately not checked, as before.
        subprocess.run([accuracy_script_path, tmp_gt_path, tmp_prediction_path], stdout=report_file, check=False)


def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, cache_dir: str) -> Tuple[Texttable, Texttable]:
statistics = {}
accuracy_values = []
correction_times = []

result_dir = os.path.join(cache_dir, "result_ocr")
os.makedirs(result_dir, exist_ok=True)

corrector, corrected_path = None, None
if USE_CORRECTION_OCR == SAGE_CORRECTION:
corrector, corrected_path = init_correction_step(cache_dir)
elif USE_CORRECTION_OCR == TEXT_BLOB_CORRECTION:
corrector = TextBlobCorrector()
corrected_path = os.path.join(cache_dir, "result_corrected")
os.makedirs(corrected_path, exist_ok=True)

with zipfile.ZipFile(benchmark_data_path, "r") as arch_file:
names_dirs = [member.filename for member in arch_file.infolist() if member.file_size > 0]
Expand All @@ -191,41 +219,61 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -
gt_path = os.path.join(base_zip, dataset_name, "gts", f"{base_name}.txt")
imgs_path = os.path.join(base_zip, dataset_name, "imgs", img_name)
accuracy_path = os.path.join(cache_dir_accuracy, f"{dataset_name}_{base_name}_accuracy.txt")

with TemporaryDirectory() as tmpdir:
tmp_gt_path = os.path.join(tmpdir, "tmp_gt.txt")
tmp_ocr_path = os.path.join(tmpdir, "tmp_ocr.txt")

try:
with arch_file.open(gt_path) as gt_file, open(tmp_gt_path, "wb") as tmp_gt_file, open(tmp_ocr_path, "w") as tmp_ocr_file:

gt_text = gt_file.read().decode("utf-8")
word_cnt = len(gt_text.split())

tmp_gt_file.write(gt_text.encode()) # extraction gt from zip
tmp_gt_file.flush()

arch_file.extract(imgs_path, tmpdir)
image = cv2.imread(tmpdir + "/" + imgs_path)

# call ocr
psm = 6 if dataset_name == "english-words" else 4
text = _call_tesseract(image, "rus+eng", psm=psm)
tmp_ocr_file.write(text)
tmp_ocr_file.flush()

# calculation accuracy build for Ubuntu from source https://github.com/eddieantonio/ocreval
accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "accuracy"))
command = f"{accuracy_script_path} {tmp_gt_path} {tmp_ocr_path} >> {accuracy_path}"
os.system(command)

statistics = _update_statistics_by_dataset(statistics, dataset_name, accuracy_path, word_cnt)
accuracy_values.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])

except Exception as ex:
print(ex)
print("If you have problems with libutf8proc.so.2, try the command: `apt install -y libutf8proc-dev`")

if os.path.exists(accuracy_path):
os.remove(accuracy_path)

tmp_gt_path = os.path.join(result_dir, f"{img_name}_gt.txt")
tmp_ocr_path = os.path.join(result_dir, f"{img_name}_ocr.txt")

try:
with arch_file.open(gt_path) as gt_file, open(tmp_gt_path, "wb") as tmp_gt_file, open(tmp_ocr_path, "w") as tmp_ocr_file:

gt_text = gt_file.read().decode("utf-8")
word_cnt = len(gt_text.split())

tmp_gt_file.write(gt_text.encode()) # extraction gt from zip
tmp_gt_file.close()

arch_file.extract(imgs_path, result_dir)
image = cv2.imread(result_dir + "/" + imgs_path)

# call ocr
psm = 6 if dataset_name == "english-words" else 4
text = _call_tesseract(image, "rus+eng", psm=psm)
tmp_ocr_file.write(text)
tmp_ocr_file.close()

# call correction step
time_b = time.time()
if USE_CORRECTION_OCR == SAGE_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = correction(corrector, text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
elif USE_CORRECTION_OCR == TEXT_BLOB_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = corrector.correct(text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
else:
calculate_accuracy_script(tmp_gt_path, tmp_ocr_path, accuracy_path)

statistics = _update_statistics_by_dataset(statistics, dataset_name, accuracy_path, word_cnt)
accuracy_values.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])

except Exception as ex:
print(ex)
print("If you have problems with libutf8proc.so.2, try the command: `apt install -y libutf8proc-dev`")

print(f"Time mean correction ocr = {np.array(correction_times).mean()}")
table_common, table_accuracy_per_image = __create_statistic_tables(statistics, accuracy_values)
return table_common, table_accuracy_per_image

Expand All @@ -240,18 +288,20 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -

benchmark_data_path = os.path.join(cache_dir, f"{base_zip}.zip")
if not os.path.isfile(benchmark_data_path):
wget.download("https://at.ispras.ru/owncloud/index.php/s/HqKt53BWmR8nCVG/download", benchmark_data_path)
wget.download("https://at.ispras.ru/owncloud/index.php/s/wMyKioKInYITpYT", benchmark_data_path)
print(f"Benchmark data downloaded to {benchmark_data_path}")
else:
print(f"Use cached benchmark data from {benchmark_data_path}")
assert os.path.isfile(benchmark_data_path)

table_common, table_accuracy_per_image = __calculate_ocr_reports(cache_dir_accuracy, benchmark_data_path)
table_common, table_accuracy_per_image = __calculate_ocr_reports(cache_dir_accuracy, benchmark_data_path, cache_dir)

table_errors = __get_summary_symbol_error(path_reports=cache_dir_accuracy)

with open(os.path.join(output_dir, "tesseract_benchmark.txt"), "w") as res_file:
res_file.write(f"Tesseract version is {pytesseract.get_tesseract_version()}\nTable 1 - Accuracy for each file\n")
with open(os.path.join(output_dir, f"tesseract_benchmark{USE_CORRECTION_OCR}.txt"), "w") as res_file:
res_file.write(f"Tesseract version is {pytesseract.get_tesseract_version()}\n")
res_file.write(f"Correction step: {USE_CORRECTION_OCR}\n")
res_file.write(f"\nTable 1 - Accuracy for each file\n")
res_file.write(table_accuracy_per_image.draw())
res_file.write(f"\n\nTable 2 - AVG by each type of symbols:\n")
res_file.write(table_common.draw())
Expand Down
Empty file.
43 changes: 43 additions & 0 deletions scripts/ocr_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
from typing import Tuple

import torch
from sage.spelling_correction.corrector import Corrector
from sage.spelling_correction import AvailableCorrectors
from sage.spelling_correction import RuM2M100ModelForSpellingCorrection

'''
Install sage library (for ocr correction step):
git clone https://github.com/ai-forever/sage.git
cd sage
pip install .
pip install -r requirements.txt
Note: sage uses 5.2 GB of GPU memory.
'''
USE_GPU = True


def correction(model: Corrector, ocr_text: str) -> str:
    """Correct OCR text line by line and return the re-joined result.

    Each line (split on "\\n") is passed to ``model.correct`` and the first
    element of the returned value is kept; the line structure is preserved.

    :param model: corrector object exposing ``correct(line)``
    :param ocr_text: text produced by OCR
    :return: corrected text with the same number of lines as ``ocr_text``
    """
    fixed_lines = (model.correct(raw_line)[0] for raw_line in ocr_text.split("\n"))
    return "\n".join(fixed_lines)


def init_correction_step(cache_dir: str) -> Tuple[Corrector, str]:
    """Prepare the output directory and the spelling-correction model.

    Creates ``<cache_dir>/result_corrected`` if needed, loads the pretrained
    ``m2m100_1B`` corrector and moves it to the first CUDA device when CUDA is
    available and ``USE_GPU`` is set.

    :param cache_dir: directory in which the ``result_corrected`` folder is created
    :return: the loaded corrector and the path to the directory for corrected texts
    """
    output_dir = os.path.join(cache_dir, "result_corrected")
    os.makedirs(output_dir, exist_ok=True)

    # 4.49 Gb model (pytorch_model.bin)
    model_name = AvailableCorrectors.m2m100_1B.value
    spell_corrector = RuM2M100ModelForSpellingCorrection.from_pretrained(model_name)

    run_on_gpu = torch.cuda.is_available() and USE_GPU
    if not run_on_gpu:
        print("use CPU")
        return spell_corrector, output_dir

    spell_corrector.model.to(torch.device("cuda:0"))
    print("use CUDA")
    return spell_corrector, output_dir


9 changes: 9 additions & 0 deletions scripts/text_blob_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from textblob import TextBlob


class TextBlobCorrector:
    """Spelling corrector that delegates to ``textblob.TextBlob``.

    The class keeps no state, so it relies on the default constructor;
    ``TextBlobCorrector()`` works exactly as before.
    """

    def correct(self, text: str) -> str:
        """Return ``text`` with TextBlob spelling correction applied.

        :param text: text to correct (for example, raw OCR output)
        :return: the corrected text; empty or whitespace-only input is returned unchanged
        """
        if not text.strip():
            # Nothing to correct: skip the TextBlob call for blank input.
            return text
        return str(TextBlob(text).correct())

0 comments on commit 1403898

Please sign in to comment.