diff --git a/dedoc/scripts/calc_tesseract_benchmarks.py b/dedoc/scripts/calc_tesseract_benchmarks.py index fc76d63a..5d1d2288 100644 --- a/dedoc/scripts/calc_tesseract_benchmarks.py +++ b/dedoc/scripts/calc_tesseract_benchmarks.py @@ -11,9 +11,14 @@ from texttable import Texttable from dedoc.config import get_config +from dedoc.scripts.language_tool_correction import LanguageToolCorrector from dedoc.scripts.ocr_correction import init_correction_step, correction -USE_CORRECTION_OCR = False +WITHOUT_CORRECTION = 0 +SAGE_CORRECTION = 1 +LANGUAGE_TOOL_CORRECTION = 2 + +USE_CORRECTION_OCR = LANGUAGE_TOOL_CORRECTION def _call_tesseract(image: np.ndarray, language: str, psm: int = 3) -> str: @@ -188,8 +193,10 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, c os.makedirs(result_dir, exist_ok=True) corrector, corrected_path = None, None - if USE_CORRECTION_OCR: + if USE_CORRECTION_OCR == SAGE_CORRECTION: corrector, corrected_path = init_correction_step(cache_dir) + elif USE_CORRECTION_OCR == LANGUAGE_TOOL_CORRECTION: + corrector = LanguageToolCorrector() with zipfile.ZipFile(benchmark_data_path, "r") as arch_file: names_dirs = [member.filename for member in arch_file.infolist() if member.file_size > 0] @@ -235,7 +242,7 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, c # call correction step time_b = time.time() - if USE_CORRECTION_OCR: + if USE_CORRECTION_OCR == SAGE_CORRECTION: tmp_corrected_path = os.path.join(corrected_path, f"{img_name}_ocr.txt") corrected_text = correction(corrector, text) correction_times.append(time.time() - time_b) @@ -243,6 +250,15 @@ def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str, c tmp_corrected_file.write(corrected_text) tmp_corrected_file.close() + calculate_accuracy_script(tmp_gt_path, tmp_corrected_path, accuracy_path) + elif USE_CORRECTION_OCR == LANGUAGE_TOOL_CORRECTION: + tmp_corrected_path = os.path.join(corrected_path, 
from typing import Optional


class LanguageToolCorrector:
    """
    Text corrector for OCR output that delegates the work to LanguageTool.

    The instance keeps one ``language_tool_python.LanguageToolPublicAPI`` client in ``self.tool``
    and reuses it for every call of :meth:`correct`.
    """

    def __init__(self, language: Optional[str] = None) -> None:
        """
        :param language: language code forwarded to LanguageTool (for example "ru-RU" or "en-US");
            None (default) keeps the library's own default language selection, as before
        """
        # Imported lazily: the benchmark script imports this class regardless of the selected
        # correction mode, so the optional package should be required only when a corrector is created.
        import language_tool_python

        self.tool = language_tool_python.LanguageToolPublicAPI(language=language)

    def correct(self, text: str) -> str:
        """
        Apply LanguageTool corrections to the given text.

        :param text: recognized text to correct
        :return: corrected text; an empty input is returned unchanged without calling the tool
        """
        if not text:
            # Nothing to correct (e.g. OCR produced no text) - do not spend a request on it.
            return text
        return self.tool.correct(text)