Commit

TLDR-585 added TEDS table benchmark
oksidgy committed Jan 22, 2024
1 parent 0b7ea01 commit 93c0c08
Showing 6 changed files with 299 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dedoc/api/api_utils.py
@@ -133,7 +133,7 @@ def json2html(text: str, paragraph: TreeNode, tables: Optional[List[Table]], tab
     if tables is not None and len(tables) > 0:
         text += "<h3> Tables: </h3>"
         for table in tables:
-            text += __table2html(table, table2id)
+            text += table2html(table, table2id)
             text += "<p>&nbsp;</p>"
     return text

@@ -201,7 +201,7 @@ def __annotations2html(paragraph: TreeNode, table2id: Dict[str, int]) -> str:
     return text.replace("\n", "<br>")


-def __table2html(table: Table, table2id: Dict[str, int]) -> str:
+def table2html(table: Table, table2id: Dict[str, int]) -> str:
     uid = table.metadata.uid
     text = f"<h4> table {table2id[uid]}:</h4>"
     text += f'<table border="1" id={uid} style="border-collapse: collapse; width: 100%;">\n<tbody>\n'
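
Dropping the double-underscore prefix makes table2html importable from dedoc.api.api_utils, which the new benchmark script below relies on. A minimal usage sketch (here, document stands for a hypothetical UnstructuredDocument produced by any dedoc reader):

from dedoc.api.api_utils import table2html

# map each table's uid to a sequential id, then render every table to HTML
table2id = {table.metadata.uid: i for i, table in enumerate(document.tables)}
html = "".join(table2html(table, table2id) for table in document.tables)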
13 changes: 7 additions & 6 deletions dedoc/readers/pdf_reader/pdf_base_reader.py
@@ -92,19 +92,20 @@ def read(self, file_path: str, parameters: Optional[dict] = None) -> Unstructure
         )

         lines, scan_tables, attachments, warnings, other_fields = self._parse_document(file_path, params_for_parse)
-        tables = []
-        for scan_table in scan_tables:
-            metadata = TableMetadata(page_id=scan_table.page_number, uid=scan_table.name, rotated_angle=scan_table.location.rotated_angle)
-            cells_with_meta = [[CellWithMeta.create_from_cell(cell) for cell in row] for row in scan_table.matrix_cells]
-            table = Table(metadata=metadata, cells=cells_with_meta)
-            tables.append(table)
+        tables = [self.scantable2table(scan_table) for scan_table in scan_tables]

         if self._can_contain_attachements(file_path) and self.attachment_extractor.with_attachments(parameters):
             attachments += self.attachment_extractor.extract(file_path=file_path, parameters=parameters)

         result = UnstructuredDocument(lines=lines, tables=tables, attachments=attachments, warnings=warnings, metadata=other_fields)
         return self._postprocess(result)

+    @staticmethod
+    def scantable2table(table: ScanTable) -> Table:
+        metadata = TableMetadata(page_id=table.page_number, uid=table.name, rotated_angle=table.location.rotated_angle)
+        cells_with_meta = [[CellWithMeta.create_from_cell(cell) for cell in row] for row in table.matrix_cells]
+        return Table(metadata=metadata, cells=cells_with_meta)

     def _can_contain_attachements(self, path: str) -> bool:
         can_contain_attachments = False
         mime = get_file_mime_type(path)
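
Moving the ScanTable-to-Table conversion into a static method lets other code convert recognized tables without going through read(). A direct-use sketch (scan_table is a hypothetical ScanTable produced by the table recognizer):

from dedoc.readers.pdf_reader.pdf_base_reader import PdfBaseReader

table = PdfBaseReader.scantable2table(scan_table)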
15 changes: 15 additions & 0 deletions resources/benchmarks/table_benchmark.json
@@ -0,0 +1,15 @@
{
  "mean": 0.9824606866114314,
  "images": {
    "example_with_table0_0.png": 0.9873417721518988,
    "example_with_table0_1.png": 1.0,
    "example_with_table6.png": 1.0,
    "example_with_table4.jpg": 1.0,
    "example_with_table17.jpg": 0.8536585365853658,
    "example_with_table_hor_vert_union.png": 1.0,
    "example_with_table1.png": 1.0,
    "example_with_table_horizontal_union.jpg": 1.0,
    "example_with_table3.png": 1.0,
    "example_with_table5.png": 0.9836065573770492
  }
}
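
Each per-image value is a TEDS similarity in [0, 1], where 1.0 means the recognized table matches the ground truth exactly; "mean" is the arithmetic average over the images. A quick sanity check of the stored file (a sketch, assuming it is run from the repository root):

import json

import numpy as np

# "mean" should equal the average of the per-image TEDS scores
with open("resources/benchmarks/table_benchmark.json") as fp:
    result = json.load(fp)
assert np.isclose(result["mean"], np.mean(list(result["images"].values())))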
114 changes: 114 additions & 0 deletions scripts/benchmark_table/benchmark_table.py
@@ -0,0 +1,114 @@
import os
import zipfile
from pathlib import Path
import json
import pprint

import numpy as np
import wget

from dedoc.api.api_utils import table2html
from dedoc.config import get_config
from dedoc.readers import PdfImageReader
from dedoc.readers.pdf_reader.pdf_image_reader.table_recognizer.table_recognizer import TableRecognizer
from scripts.benchmark_table.metric import TEDS
from tests.test_utils import get_test_config

path_result = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "benchmarks")
path_result = os.path.abspath(path_result)
os.makedirs(path_result, exist_ok=True)
path_result = os.path.join(path_result, "table_benchmark.json")

URL = "https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download"

table_recognizer = TableRecognizer(config=get_test_config())
image_reader = PdfImageReader(config=get_test_config())
teds = TEDS()


def call_metric(pred_json: dict, true_json: dict) -> dict:
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)

    return scores


def get_tables(image_path: str) -> str:
    document = image_reader.read(image_path)

    for table in document.tables:
        table.metadata.uid = "test_id"
    table2id = {"test_id": 0}
    html_tables = [table2html(table, table2id) for table in document.tables]

    # TODO: for now this works only for images with a single table
    return html_tables[0]


def make_predict_json(data_path: Path) -> dict:
    predict_json = {}
    for filename in os.listdir(data_path):
        print(filename)
        file_path = str(data_path / filename)

        predict_json[filename] = {"html": "<html><body>" + get_tables(file_path) + "</body></html>"}

    return predict_json


def download_dataset(data_dir: Path) -> None:
    if not os.path.isdir(data_dir):
        data_dir.mkdir(parents=True)
        pdfs_zip_path = str(data_dir / "benchmark_table_data.zip")
        wget.download(URL, pdfs_zip_path)

        with zipfile.ZipFile(pdfs_zip_path, 'r') as zip_ref:
            zip_ref.extractall(data_dir)
        os.remove(pdfs_zip_path)

        print(f"Benchmark data downloaded to {data_dir}")
    else:
        print(f"Use cached benchmark data from {data_dir}")


def prediction(path_pred: Path, path_images: Path) -> dict:
    pred_json = make_predict_json(path_images)
    with open(path_pred, "w") as fd:
        json.dump(pred_json, fd, indent=2, ensure_ascii=False)

    return pred_json


if __name__ == "__main__":
    data_dir = Path(get_config()["intermediate_data_path"]) / "benchmark_table_data"
    path_images = data_dir / "images"
    path_gt = data_dir / "gt.json"
    path_pred = data_dir / "pred.json"
    download_dataset(data_dir)

    with open(path_gt, "r") as fp:
        gt_json = json.load(fp)
    '''
    Creating base html (based on the method's predictions) for future labeling:
    path_images = data_dir / "images_tmp"
    pred_json = prediction("gt_tmp.json", path_images)
    '''
    pred_json = prediction(path_pred, path_images)
    scores = call_metric(pred_json=pred_json, true_json=gt_json)

    result = dict()
    result["mean"] = np.mean(list(scores.values()))
    result["images"] = scores

    # save benchmarks
    with open(path_result, "w") as fd:
        json.dump(result, fd, indent=2, ensure_ascii=False)
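
To regenerate the benchmark, the script can be run as a module so that the scripts.benchmark_table.metric import resolves (a sketch, assuming the working directory is the repository root, dedoc is installed, and the dataset URL is reachable):

python -m scripts.benchmark_table.benchmark_table

Predictions are dumped to pred.json in the intermediate data directory, and the aggregated scores are written to the resources/benchmarks/table_benchmark.json shown above.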
158 changes: 158 additions & 0 deletions scripts/benchmark_table/metric.py
@@ -0,0 +1,158 @@
# Copyright 2020 IBM
# Author: [email protected]
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.

# Source: https://github.com/ibm-aur-nlp/PubTabNet

from collections import deque

import distance
from apted import APTED, Config
from apted.helpers import Tree
from lxml import etree, html
from tqdm import tqdm


class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, visible=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.visible = visible
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == 'td':
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)


class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        """Get maximum possible value"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences) -> float:
        """Get distance from 0 to 1"""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1: TableTree, node2: TableTree) -> float:
        """Compares attributes of trees"""
        if not node1.visible or not node2.visible:  # hidden cells ("display: none") add no rename cost
            return 0.
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.


class TEDS(object):
    ''' Tree Edit Distance based Similarity
    '''
    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1'
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        ''' Tokenizes table cells
        '''
        self.__tokens__.append('<%s>' % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != 'unk':
            self.__tokens__.append('</%s>' % node.tag)
        if node.tag != 'td' and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        ''' Converts HTML tree to the format required by apted
        '''
        if node.tag == 'td':
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()

            try:
                new_node = TableTree(tag=node.tag,
                                     colspan=int(node.attrib.get('colspan', '1')),
                                     rowspan=int(node.attrib.get('rowspan', '1')),
                                     content=cell,
                                     visible=node.attrib.get('style') != "display: none", *deque())
            except Exception as ex:
                print(f"Bad html file. HTML parse exception. Exception's msg: {ex}")
                raise ex
        else:
            new_node = TableTree(node.tag, None, None, None, True, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != 'td':
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred: str, true: str) -> float:
        ''' Computes TEDS score between the prediction and the ground truth of a
        given sample
        '''
        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath('body/table') and true.xpath('body/table'):
            pred = pred.xpath('body/table')[0]
            true = true.xpath('body/table')[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)

            # named edit_distance to avoid shadowing the `distance` module used in CustomConfig
            edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
            return 1.0 - (float(edit_distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        ''' Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @output: {'FILENAME': TEDS SCORE, ...}
        '''
        samples = true_json.keys()
        scores = [self.evaluate(pred_json.get(filename, {}).get('html', ''), true_json[filename]['html']) for filename in tqdm(samples)]
        scores = dict(zip(samples, scores))
        return scores
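
The returned score is 1 - d / max(n_pred, n_true), where d is the APTED tree edit distance between the two <table> subtrees and n_pred, n_true are their node counts, so identical tables score 1.0 and fully dissimilar ones approach 0.0. A minimal usage sketch (the HTML strings are made-up examples):

from scripts.benchmark_table.metric import TEDS

teds = TEDS()
pred = "<html><body><table><tr><td>cell</td></tr></table></body></html>"
true = "<html><body><table><tr><td>cell</td></tr></table></body></html>"
print(teds.evaluate(pred, true))  # identical tables give 1.0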
3 changes: 3 additions & 0 deletions scripts/benchmark_table/requirements.txt
@@ -0,0 +1,3 @@
# for metric TEDS:
apted==1.0.3
distance==0.1.3
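
These can be installed on top of dedoc's own dependencies before running the benchmark, e.g. (a sketch):

pip install -r scripts/benchmark_table/requirements.txt

The benchmark script itself additionally uses wget, tqdm and numpy, which are assumed to be available in the environment.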
