Skip to content

Commit

Permalink
Add language tool corrector
Browse files Browse the repository at this point in the history
  • Loading branch information
sunveil committed Dec 20, 2023
1 parent d29f21d commit 6739ae0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
22 changes: 19 additions & 3 deletions dedoc/scripts/calc_tesseract_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@
from texttable import Texttable

from dedoc.config import get_config
from dedoc.scripts.language_tool_correction import LanguageToolCorrector
from dedoc.scripts.ocr_correction import init_correction_step, correction

USE_CORRECTION_OCR = False
# Correction modes that USE_CORRECTION_OCR can be set to.
WITHOUT_CORRECTION = 0  # no corrector is created; the benchmark falls through to the uncorrected branch
SAGE_CORRECTION = 1  # corrector comes from init_correction_step() and is applied via correction()
LANGUAGE_TOOL_CORRECTION = 2  # corrector is a LanguageToolCorrector, applied via its correct() method

# Mode used by the correction step of this benchmark script.
USE_CORRECTION_OCR = LANGUAGE_TOOL_CORRECTION


def _call_tesseract(image: np.ndarray, language: str, psm: int = 3) -> str:
Expand Down Expand Up @@ -188,8 +193,10 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, c
os.makedirs(result_dir, exist_ok=True)

corrector, corrected_path = None, None
if USE_CORRECTION_OCR:
if USE_CORRECTION_OCR == SAGE_CORRECTION:
corrector, corrected_path = init_correction_step(cache_dir)
elif USE_CORRECTION_OCR == LANGUAGE_TOOL_CORRECTION:
corrector = LanguageToolCorrector()

with zipfile.ZipFile(benchmark_data_path, "r") as arch_file:
names_dirs = [member.filename for member in arch_file.infolist() if member.file_size > 0]
Expand Down Expand Up @@ -235,14 +242,23 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, c

# call correction step
time_b = time.time()
if USE_CORRECTION_OCR:
if USE_CORRECTION_OCR == SAGE_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = correction(corrector, text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
elif USE_CORRECTION_OCR == LANGUAGE_TOOL_CORRECTION:
tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt")
corrected_text = corrector.correct(text)
correction_times.append(time.time() - time_b)
with open(tmp_corrected_path, "w") as tmp_corrected_file:
tmp_corrected_file.write(corrected_text)
tmp_corrected_file.close()

calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path)
else:
calculate_accuracy_script(tmp_gt_path, tmp_ocr_path, accuracy_path)
Expand Down
9 changes: 9 additions & 0 deletions dedoc/scripts/language_tool_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import language_tool_python


class LanguageToolCorrector:
    """Text corrector that delegates to a ``language_tool_python`` client."""

    def __init__(self) -> None:
        """Create the ``LanguageToolPublicAPI`` client and keep it in ``self.tool``."""
        public_api_client = language_tool_python.LanguageToolPublicAPI()
        self.tool = public_api_client

    def correct(self, text: str) -> str:
        """Return whatever ``self.tool.correct`` produces for ``text``."""
        fixed_text = self.tool.correct(text)
        return fixed_text

0 comments on commit 6739ae0

Please sign in to comment.