Commit
TLDR-585 added TEDS table benchmark (#398)
* TLDR-585 added TEDS table benchmark

* TLDR-585 fixed after review

* TLDR-585 fixed bug, include cells's content in metric

* TLDR-591 added table generation benchmark

* TLDR-585 fixed after review
oksidgy authored Jan 29, 2024
1 parent ca7442d commit 189c267
Showing 8 changed files with 862 additions and 11 deletions.
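
For context: TEDS (Tree-Edit-Distance-based Similarity) treats the predicted table and the ground-truth table as HTML trees and maps the tree edit distance between them to a similarity in [0, 1]. The metric itself lives in scripts/benchmark_table/metric.py (imported by the new benchmark script below and not shown in this diff); the snippet here is only a sketch of the conventional normalization step, not the repository's implementation.

# Sketch of the standard TEDS normalization, assuming the usual definition of the metric.
def teds_score(tree_edit_distance: float, n_nodes_pred: int, n_nodes_true: int) -> float:
    # 1.0 means the two table trees are identical, 0.0 means maximally different.
    return 1.0 - tree_edit_distance / max(n_nodes_pred, n_nodes_true)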
4 changes: 2 additions & 2 deletions dedoc/api/api_utils.py
@@ -133,7 +133,7 @@ def json2html(text: str, paragraph: TreeNode, tables: Optional[List[Table]], tab
     if tables is not None and len(tables) > 0:
         text += "<h3> Tables: </h3>"
         for table in tables:
-            text += __table2html(table, table2id)
+            text += table2html(table, table2id)
             text += "<p>&nbsp;</p>"
     return text

@@ -201,7 +201,7 @@ def __annotations2html(paragraph: TreeNode, table2id: Dict[str, int]) -> str:
     return text.replace("\n", "<br>")


-def __table2html(table: Table, table2id: Dict[str, int]) -> str:
+def table2html(table: Table, table2id: Dict[str, int]) -> str:
     uid = table.metadata.uid
     text = f"<h4> table {table2id[uid]}:</h4>"
     text += f'<table border="1" id={uid} style="border-collapse: collapse; width: 100%;">\n<tbody>\n'
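Dropping the leading underscores takes the helper out of the module-private naming convention, so the new benchmark script can import it directly; a minimal sketch of what this enables (the commented call is illustrative only):

# Sketch: after the rename the helper can be imported outside the module.
from dedoc.api.api_utils import table2html  # was __table2html, private by convention

# The second argument maps a table uid to the numeric index used in the generated heading:
# html = table2html(table, {table.metadata.uid: 0})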
6 changes: 6 additions & 0 deletions dedoc/readers/pdf_reader/data_classes/tables/scantable.py
@@ -5,6 +5,7 @@
 import numpy as np
 from dedocutils.data_structures import BBox

+from dedoc.data_structures import CellWithMeta, Table, TableMetadata
 from dedoc.readers.pdf_reader.data_classes.tables.cell import Cell
 from dedoc.readers.pdf_reader.data_classes.tables.location import Location

@@ -27,6 +28,11 @@ def extended(self, table: "ScanTable") -> None:
         # extend order
         self.order = max(self.order, table.order)

+    def to_table(self) -> Table:
+        metadata = TableMetadata(page_id=self.page_number, uid=self.name, rotated_angle=self.location.rotated_angle)
+        cells_with_meta = [[CellWithMeta.create_from_cell(cell) for cell in row] for row in self.matrix_cells]
+        return Table(metadata=metadata, cells=cells_with_meta)
+
     @staticmethod
     def get_cells_text(attr_cells: List[List[Cell]]) -> List[List[str]]:
         attrs = []
10 changes: 1 addition & 9 deletions dedoc/readers/pdf_reader/pdf_base_reader.py
@@ -13,10 +13,7 @@
 import dedoc.utils.parameter_utils as param_utils
 from dedoc.attachments_extractors.concrete_attachments_extractors.pdf_attachments_extractor import PDFAttachmentsExtractor
 from dedoc.common.exceptions.bad_file_error import BadFileFormatError
-from dedoc.data_structures.cell_with_meta import CellWithMeta
 from dedoc.data_structures.line_with_meta import LineWithMeta
-from dedoc.data_structures.table import Table
-from dedoc.data_structures.table_metadata import TableMetadata
 from dedoc.data_structures.unstructured_document import UnstructuredDocument
 from dedoc.extensions import recognized_extensions, recognized_mimes
 from dedoc.readers.base_reader import BaseReader
@@ -92,12 +89,7 @@ def read(self, file_path: str, parameters: Optional[dict] = None) -> Unstructure
         )

         lines, scan_tables, attachments, warnings, other_fields = self._parse_document(file_path, params_for_parse)
-        tables = []
-        for scan_table in scan_tables:
-            metadata = TableMetadata(page_id=scan_table.page_number, uid=scan_table.name, rotated_angle=scan_table.location.rotated_angle)
-            cells_with_meta = [[CellWithMeta.create_from_cell(cell) for cell in row] for row in scan_table.matrix_cells]
-            table = Table(metadata=metadata, cells=cells_with_meta)
-            tables.append(table)
+        tables = [scan_table.to_table() for scan_table in scan_tables]

         if self._can_contain_attachements(file_path) and self.attachment_extractor.with_attachments(parameters):
             attachments += self.attachment_extractor.extract(file_path=file_path, parameters=parameters)
16 changes: 16 additions & 0 deletions resources/benchmarks/table_benchmark.json
@@ -0,0 +1,16 @@
{
  "mode_metric_structure_only": false,
  "mean": 0.9468374367023571,
  "images": {
    "example_with_table0_0.png": 0.9525583036909738,
    "example_with_table0_1.png": 0.9264351862896008,
    "example_with_table6.png": 0.989010989010989,
    "example_with_table4.jpg": 0.908436211832951,
    "example_with_table17.jpg": 0.8078952936402488,
    "example_with_table_hor_vert_union.png": 0.9896091617933723,
    "example_with_table1.png": 0.9781560283687943,
    "example_with_table_horizontal_union.jpg": 0.9925757575757576,
    "example_with_table3.png": 0.9778008866078716,
    "example_with_table5.png": 0.9458965482130129
  }
}
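
The "mean" field is the unweighted average of the per-image TEDS scores under "images"; a quick way to re-check it (assuming the snippet is run from the repository root):

# Sketch: recompute the reported mean from the per-image scores.
import json
import numpy as np

with open("resources/benchmarks/table_benchmark.json") as fp:
    benchmark = json.load(fp)

assert np.isclose(np.mean(list(benchmark["images"].values())), benchmark["mean"])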
506 changes: 506 additions & 0 deletions resources/benchmarks/table_benchmark_on_generated_data.json

Large diffs are not rendered by default.

167 changes: 167 additions & 0 deletions scripts/benchmark_table/benchmark_table.py
@@ -0,0 +1,167 @@
import zipfile
from pathlib import Path
import json
import pprint
from typing import Optional, List
import numpy as np
import wget

from dedoc.api.api_utils import table2html
from dedoc.config import get_config
from dedoc.readers import PdfImageReader
from dedoc.readers.pdf_reader.pdf_image_reader.table_recognizer.table_recognizer import TableRecognizer
from scripts.benchmark_table.metric import TEDS

path_result = Path(__file__).parent / ".." / ".." / "resources" / "benchmarks"
path_result.absolute().mkdir(parents=True, exist_ok=True)

table_recognizer = TableRecognizer(config=get_config())
image_reader = PdfImageReader(config=get_config())

GENERATED_BENCHMARK = "on_generated_data"
OURDATA_BENCHMARK = "on_our_data"
TYPE_BENCHMARK = OURDATA_BENCHMARK


def call_metric(pred_json: dict, true_json: dict, structure_only: bool = False, ignore_nodes: Optional[List] = None) -> dict:
    teds = TEDS(structure_only=structure_only, ignore_nodes=ignore_nodes)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)

    return scores


def get_tables(image_path: Path) -> str:
    document = image_reader.read(str(image_path))

    for table in document.tables:
        table.metadata.uid = "test_id"
    table2id = {"test_id": 0}
    html_tables = [table2html(table, table2id) for table in document.tables]

    # TODO: for now this works only with one table per image
    return html_tables[0]


def make_predict_json(data_path: Path) -> dict:
    predict_json = {}
    for pathname in Path.iterdir(data_path):
        print(pathname)

        predict_json[pathname.name] = {"html": "<html><body>" + get_tables(pathname) + "</body></html>"}

    return predict_json


def download_dataset(data_dir: Path, name_zip: str, url: str) -> None:
    if Path.exists(data_dir):
        print(f"Use cached benchmark data from {data_dir}")
        return

    data_dir.mkdir(parents=True, exist_ok=True)
    pdfs_zip_path = data_dir / name_zip
    wget.download(url, str(data_dir))

    with zipfile.ZipFile(pdfs_zip_path, 'r') as zip_ref:
        zip_ref.extractall(data_dir)
    pdfs_zip_path.unlink()

    print(f"Benchmark data downloaded to {data_dir}")


def prediction(path_pred: Path, path_images: Path) -> dict:
    pred_json = make_predict_json(path_images)
    with path_pred.open("w") as fd:
        json.dump(pred_json, fd, indent=2, ensure_ascii=False)

    return pred_json


def benchmark_on_our_data() -> dict:
    data_dir = Path(get_config()["intermediate_data_path"]) / "benchmark_table_data"
    path_images = data_dir / "images"
    path_gt = data_dir / "gt.json"
    path_pred = data_dir / "pred.json"
    download_dataset(data_dir,
                     name_zip="benchmark_table_data.zip",
                     url="https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download")

    mode_metric_structure_only = False

    with open(path_gt, "r") as fp:
        gt_json = json.load(fp)
    '''
    Creating base html (based on method predictions for future labeling)
    path_images = data_dir / "images_tmp"
    pred_json = prediction("gt_tmp.json", path_images)
    '''
    pred_json = prediction(path_pred, path_images)
    scores = call_metric(pred_json=pred_json, true_json=gt_json, structure_only=mode_metric_structure_only)

    result = dict()
    result["mode_metric_structure_only"] = mode_metric_structure_only
    result["mean"] = np.mean([score for score in scores.values()])
    result["images"] = scores

    return result


def benchmark_on_generated_table() -> dict:
    """
    Generated data are taken from https://github.com/hassan-mahmood/TIES_DataGeneration
    The generation procedure is described in the article https://arxiv.org/pdf/1905.13391.pdf
    Note: only tables of the 1st category (table type) are generated
    Note: the header tag <th> is not used, it is replaced with the <td> tag
    Note: all generated data (all four categories) can be downloaded from https://at.ispras.ru/owncloud/index.php/s/cjpCIR7I0G4JzZU
    TODO: some tables have low quality, the reason should be investigated
    """

    data_dir = Path(get_config()["intermediate_data_path"]) / "visualizeimgs" / "category1"
    path_images = data_dir / "img_500"
    path_gt = data_dir / "html_500"
    download_dataset(data_dir,
                     name_zip="benchmark_table_data_generated_500_tables_category_1.zip",
                     url="https://at.ispras.ru/owncloud/index.php/s/gItWxupnF2pve6B/download")
    mode_metric_structure_only = True

    # make a common ground-truth file
    common_gt_json = {}
    for pathname in Path.iterdir(path_gt):
        image_name = pathname.name.split(".")[0] + ".png"
        with open(pathname, "r") as fp:
            table_html = fp.read()
            # exclude header tags
            table_html = table_html.replace("<th ", "<td ")
            table_html = table_html.replace("</th>", "</td>")

        common_gt_json[image_name] = {"html": table_html}

    file_common_gt = data_dir / "common_gt.json"
    with file_common_gt.open("w") as fd:
        json.dump(common_gt_json, fd, indent=2, ensure_ascii=False)

    # calculate metrics
    path_pred = data_dir / "pred.json"

    pred_json = prediction(path_pred, path_images)
    scores = call_metric(pred_json=pred_json, true_json=common_gt_json,
                         structure_only=mode_metric_structure_only,
                         ignore_nodes=["span", "style", "head", "h4"])

    result = dict()
    result["mode_metric_structure_only"] = mode_metric_structure_only
    result["mean"] = np.mean([score for score in scores.values()])
    result["images"] = scores

    return result


if __name__ == "__main__":
    result = benchmark_on_our_data() if TYPE_BENCHMARK == OURDATA_BENCHMARK else benchmark_on_generated_table()

    # save benchmarks
    file_result = path_result / f"table_benchmark_{TYPE_BENCHMARK}.json"
    with file_result.open("w") as fd:
        json.dump(result, fd, indent=2, ensure_ascii=False)
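
Both benchmark modes compare JSON files of the same shape: image file names mapped to a dict whose "html" key holds the table HTML (predictions are additionally wrapped in <html><body>...</body></html> by make_predict_json). A hand-written ground-truth entry could look like the sketch below; the image name and cell values are made up for illustration.

# Hypothetical gt.json entry in the shape consumed by call_metric / TEDS.batch_evaluate.
import json

gt_json = {
    "some_table_image.png": {
        "html": "<html><body><table><tr><td>header A</td><td>header B</td></tr>"
                "<tr><td>1</td><td>2</td></tr></table></body></html>"
    }
}

with open("gt.json", "w") as fd:
    json.dump(gt_json, fd, indent=2, ensure_ascii=False)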