Skip to content

Commit

Permalink
TLDR-538 tesseract trustai (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
oksidgy authored Dec 6, 2023
1 parent 1fefda5 commit f3ec0e5
Show file tree
Hide file tree
Showing 5 changed files with 285 additions and 31 deletions.
Binary file added dedoc/scripts/accsum
Binary file not shown.
142 changes: 116 additions & 26 deletions dedoc/scripts/calc_tesseract_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import zipfile
from tempfile import TemporaryDirectory
from typing import Dict, List
from typing import Dict, List, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -79,24 +79,99 @@ def _get_avg_by_dataset(statistics: Dict, dataset: str) -> List:
_get_avg(statistics[dataset]["Accuracy"])]


if __name__ == "__main__":
base_zip = "data_tesseract_benchmarks"
output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "benchmarks"))
cache_dir = os.path.join(get_config()["intermediate_data_path"], "tesseract_data")
os.makedirs(cache_dir, exist_ok=True)
benchmark_data_path = os.path.join(cache_dir, f"{base_zip}.zip")
def __parse_symbol_info(lines: List[str]) -> Tuple[List, int]:
    """
    Parse the per-symbol statistics block of an accsum report.

    The block starts after the last line containing the "Count Missed %Right" header.
    Every following line is expected to hold four columns: count, missed, %right and
    the symbol wrapped in curly braces.

    :param lines: lines of the accsum report
    :return: tuple of (rows [count, missed, %right, symbol] sorted by the missed column in descending order,
        index of the header line inside `lines`)
    :raises ValueError: if the header of the symbol block is not present in `lines`
    """
    header = "Count Missed %Right"
    # compare against whitespace-normalized lines, so the match does not depend on the column padding of the report
    header_positions = [line_num for line_num, line in enumerate(lines) if header in " ".join(line.split())]
    if not header_positions:
        raise ValueError(f'Symbol statistics header "{header}" was not found in the report')
    start_block_line = header_positions[-1]  # the last occurrence is used

    symbols_info = []
    for line in lines[start_block_line + 1:]:
        # example line: "1187 11 99.07 {<\n>}"
        # split only three times: the symbol itself may be a whitespace character, e.g. "{ }"
        row_values = line.split(maxsplit=3)
        if len(row_values) < 4 or not row_values[1].isdigit():
            continue  # blank or malformed line
        row_values[-1] = row_values[-1].strip()[1:-1]  # drop the surrounding braces to get the symbol value
        symbols_info.append(row_values)

    # sort errors by the amount of missed symbols
    symbols_info.sort(key=lambda row: int(row[1]), reverse=True)

    return symbols_info, start_block_line


def __parse_ocr_errors(lines: List[str]) -> List:
    """
    Parse the block of OCR errors of an accsum report.

    The block starts after the first line containing the "Errors Marked Correct-Generated" header.
    Lines that do not look like an error record (e.g. blank lines) are skipped.

    :param lines: lines of the accsum report
    :return: list of [errors count, correct text, generated text] (all values are strings);
        an empty list if the header is not present in `lines`
    """
    header = "Errors Marked Correct-Generated"
    # compare against whitespace-normalized lines, so the match does not depend on the column padding of the report
    start_line = next((line_num for line_num, line in enumerate(lines) if header in " ".join(line.split())), None)
    if start_line is None:
        return []

    ocr_errors = []
    for line in lines[start_line + 1:]:
        # example line: " 2 0 { 6}-{б}"
        errors = re.search(r"\d+", line)  # the first number in the line is the errors count
        chars = re.search(r"{(.*)}-{(.*)}", line)
        if errors is None or chars is None:
            continue  # not an error record
        ocr_errors.append([errors.group(), chars.group(1), chars.group(2)])

    return ocr_errors


def __get_summary_symbol_error(path_reports: str) -> Texttable:
    """
    Build a table with OCR errors grouped by symbol.

    The accsum executable located next to this script is called for all files from `path_reports`,
    its output is saved to `accsum_report.txt` in the parent directory of `path_reports` and parsed.

    :param path_reports: directory with accuracy reports (every regular file inside is passed to accsum)
    :return: table with two columns: the symbol and the list of its errors in the form
        "<errors count> & <correct text> -> <generated text>"
    """
    import subprocess  # local import: used only by this function

    # 1 - call accsum for get summary of all reports
    accuracy_script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "accsum"))
    accsum_report_path = os.path.join(path_reports, "..", "accsum_report.txt")

    file_reports = [os.path.join(path_reports, f) for f in os.listdir(path_reports) if os.path.isfile(os.path.join(path_reports, f))]

    # the arguments are passed as a list (no shell), so paths with spaces or shell metacharacters are handled correctly;
    # opening the report for writing replaces the previous report if any
    with open(accsum_report_path, "wb") as report_file:
        subprocess.run([accuracy_script_path, *file_reports], stdout=report_file, check=False)

    # 2 - parse report info
    with open(accsum_report_path, "r") as f:
        lines = f.readlines()

    symbols_info, start_symbol_block_line = __parse_symbol_info(lines)
    ocr_errors = __parse_ocr_errors(lines[:start_symbol_block_line - 1])

    # 3 - calculate ocr errors according to a symbol
    # to ignore errors with long text (len > 3) or without generated text
    short_errors = [err for err in ocr_errors if err[-1] != "" and len(err[-2]) <= 3 and len(err[-1]) <= 3]

    ocr_errors_by_symbol = {}
    for symbol_info in symbols_info:
        symbol = symbol_info[-1]
        # an error belongs to the symbol if the symbol occurs in the correct text
        ocr_errors_by_symbol[symbol] = [f"{err[0]} & <{err[1]}> -> <{err[2]}>" for err in short_errors if symbol in err[-2]]

    # 4 - create table with OCR errors
    ocr_err_by_symbol_table = Texttable()
    title = [["Symbol", "Cnt Errors & Correct-Generated"]]
    ocr_err_by_symbol_table.add_rows(title)
    for symbol, value in ocr_errors_by_symbol.items():
        if value:  # symbols without errors are not shown
            ocr_err_by_symbol_table.add_row([symbol, value])

    return ocr_err_by_symbol_table


def __create_statistic_tables(statistics: dict, accuracy_values: List) -> Tuple[Texttable, Texttable]:
    """
    Create two tables with benchmark results.

    :param statistics: statistics grouped by dataset name, it is passed to `_get_avg_by_dataset` for every dataset
    :param accuracy_values: rows of the per-image table (dataset, image name, psm, amount of words, accuracy)
    :return: tuple of (table with average values for each dataset, table with accuracy for each image)
    """
    per_image_header = ["Dataset", "Image name", "--psm", "Amount of words", "Accuracy OCR"]
    common_header = ["Dataset", "ASCII_Spacing_Chars", "ASCII_Special_Symbols", "ASCII_Digits",
                     "ASCII_Uppercase_Chars", "Latin1_Special_Symbols", "Cyrillic", "Amount of words", "AVG Accuracy"]

    # accuracy for each image: the header row followed by the given values
    table_accuracy_per_image = Texttable()
    table_accuracy_per_image.add_rows([per_image_header, *accuracy_values])

    # average accuracy for each data set, the data sets are ordered by name
    table_common = Texttable()
    common_rows = [common_header]
    for dataset_name in sorted(statistics):
        common_rows.append([dataset_name, *_get_avg_by_dataset(statistics, dataset_name)])
    table_common.add_rows(common_rows)

    return table_common, table_accuracy_per_image


def __calculate_ocr_reports(cache_dir_accuracy: str, benchmark_data_path: str) -> Tuple[Texttable, Texttable]:
statistics = {}
accuracy_values = []

with zipfile.ZipFile(benchmark_data_path, "r") as arch_file:
names_dirs = [member.filename for member in arch_file.infolist() if member.file_size > 0]
Expand All @@ -115,7 +190,7 @@ def _get_avg_by_dataset(statistics: Dict, dataset: str) -> List:

gt_path = os.path.join(base_zip, dataset_name, "gts", f"{base_name}.txt")
imgs_path = os.path.join(base_zip, dataset_name, "imgs", img_name)
accuracy_path = os.path.join(cache_dir, f"{dataset_name}_{base_name}_accuracy.txt")
accuracy_path = os.path.join(cache_dir_accuracy, f"{dataset_name}_{base_name}_accuracy.txt")

with TemporaryDirectory() as tmpdir:
tmp_gt_path = os.path.join(tmpdir, "tmp_gt.txt")
Expand Down Expand Up @@ -145,30 +220,45 @@ def _get_avg_by_dataset(statistics: Dict, dataset: str) -> List:
os.system(command)

statistics = _update_statistics_by_dataset(statistics, dataset_name, accuracy_path, word_cnt)
accs.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])
accuracy_values.append([dataset_name, base_name, psm, word_cnt, statistics[dataset_name]["Accuracy"][-1]])

except Exception as ex:
print(ex)
print("If you have problems with libutf8proc.so.2, try the command: `apt install -y libutf8proc-dev`")

table_aacuracy_per_image = Texttable()
table_aacuracy_per_image.add_rows(accs)
table_common, table_accuracy_per_image = __create_statistic_tables(statistics, accuracy_values)
return table_common, table_accuracy_per_image

# calculating average accuracy for each data set
table_common = Texttable()

for dataset_name in sorted(statistics.keys()):
row = [dataset_name]
row.extend(_get_avg_by_dataset(statistics, dataset_name))
accs_common.append(row)
table_common.add_rows(accs_common)
if __name__ == "__main__":
    # Entry point of the benchmark: download (or reuse) the benchmark archive, calculate accuracy reports,
    # summarize OCR errors by symbol, then save the resulting tables to a file and print them.

    # base_zip is read as a module-level name inside __calculate_ocr_reports, so it should keep this name
    base_zip = "data_tesseract_benchmarks"
    output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "resources", "benchmarks"))
    cache_dir = os.path.join(get_config()["intermediate_data_path"], "tesseract_data")
    os.makedirs(cache_dir, exist_ok=True)
    # accuracy reports are stored in a separate directory: all files from it are passed to accsum
    cache_dir_accuracy = os.path.join(cache_dir, "accuracy")
    os.makedirs(cache_dir_accuracy, exist_ok=True)

    benchmark_data_path = os.path.join(cache_dir, f"{base_zip}.zip")
    if not os.path.isfile(benchmark_data_path):
        wget.download("https://at.ispras.ru/owncloud/index.php/s/HqKt53BWmR8nCVG/download", benchmark_data_path)
        print(f"Benchmark data downloaded to {benchmark_data_path}")
    else:
        print(f"Use cached benchmark data from {benchmark_data_path}")
    # explicit check instead of assert: asserts are removed when python is run with -O
    if not os.path.isfile(benchmark_data_path):
        raise FileNotFoundError(f"Benchmark data not found: {benchmark_data_path}")

    table_common, table_accuracy_per_image = __calculate_ocr_reports(cache_dir_accuracy, benchmark_data_path)

    table_errors = __get_summary_symbol_error(path_reports=cache_dir_accuracy)

    tesseract_version = pytesseract.get_tesseract_version()

    with open(os.path.join(output_dir, "tesseract_benchmark.txt"), "w") as res_file:
        res_file.write(f"Tesseract version is {tesseract_version}\nTable 1 - Accuracy for each file\n")
        res_file.write(table_accuracy_per_image.draw())
        res_file.write("\n\nTable 2 - AVG by each type of symbols:\n")
        res_file.write(table_common.draw())
        res_file.write("\n\nTable 3 -OCR error by symbol:\n")
        res_file.write(table_errors.draw())

    print(f"Tesseract version is {tesseract_version}")
    print(table_accuracy_per_image.draw())
    print(table_common.draw())
    print(table_errors.draw())
2 changes: 1 addition & 1 deletion dedoc/train_dataset/trainer/errors_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def save_errors(self, error_cnt: Counter, errors_uids: List[str], csv_path: str,
with open(path_file) as file:
lines = file.readlines()
lines_cnt = Counter(lines)
lines.sort(key=lambda l: (-lines_cnt[l], l))
lines.sort(key=lambda value: (-lines_cnt[value], value))
path_out = os.path.join(self.errors_path, f"{int(1000 * len(lines) / errors_total_num):04d}_{file_name}")

with open(path_out, "w") as file_out:
Expand Down
Loading

0 comments on commit f3ec0e5

Please sign in to comment.