-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
299 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"mean": 0.9824606866114314, | ||
"images": { | ||
"example_with_table0_0.png": 0.9873417721518988, | ||
"example_with_table0_1.png": 1.0, | ||
"example_with_table6.png": 1.0, | ||
"example_with_table4.jpg": 1.0, | ||
"example_with_table17.jpg": 0.8536585365853658, | ||
"example_with_table_hor_vert_union.png": 1.0, | ||
"example_with_table1.png": 1.0, | ||
"example_with_table_horizontal_union.jpg": 1.0, | ||
"example_with_table3.png": 1.0, | ||
"example_with_table5.png": 0.9836065573770492 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import os | ||
import zipfile | ||
from pathlib import Path | ||
import json | ||
import pprint | ||
|
||
import numpy as np | ||
import wget | ||
|
||
from dedoc.api.api_utils import table2html | ||
from dedoc.config import get_config | ||
from dedoc.readers import PdfImageReader | ||
from dedoc.readers.pdf_reader.pdf_image_reader.table_recognizer.table_recognizer import TableRecognizer | ||
from scripts.benchmark_table.metric import TEDS | ||
from tests.test_utils import get_test_config | ||
|
||
path_result = os.path.join(os.path.dirname(__file__), "..", "..", "resources", "benchmarks") | ||
path_result = os.path.abspath(path_result) | ||
os.makedirs(path_result, exist_ok=True) | ||
path_result = os.path.join(path_result, "table_benchmark.json") | ||
|
||
URL = "https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download" | ||
|
||
table_recognizer = TableRecognizer(config=get_test_config()) | ||
image_reader = PdfImageReader(config=get_test_config()) | ||
teds = TEDS() | ||
|
||
|
||
def call_metric(pred_json: dict, true_json: dict) -> dict: | ||
scores = teds.batch_evaluate(pred_json, true_json) | ||
pp = pprint.PrettyPrinter() | ||
pp.pprint(scores) | ||
|
||
return scores | ||
|
||
|
||
def get_tables(image_path: str) -> str: | ||
document = image_reader.read(image_path) | ||
|
||
for table in document.tables: | ||
table.metadata.uid = "test_id" | ||
table2id = {"test_id": 0} | ||
html_tables = [table2html(table, table2id) for table in document.tables] | ||
|
||
# TODO: while works with one table in an image | ||
return html_tables[0] | ||
|
||
|
||
def make_predict_json(data_path: Path) -> dict: | ||
predict_json = {} | ||
for filename in os.listdir(data_path): | ||
print(filename) | ||
file_path = str(data_path / filename) | ||
|
||
predict_json[filename] = {"html": "<html><body>" + get_tables(file_path) + "</body></html>"} | ||
|
||
return predict_json | ||
|
||
|
||
def download_dataset(data_dir: Path) -> None: | ||
|
||
if not os.path.isdir(data_dir): | ||
data_dir.mkdir(parents=True) | ||
pdfs_zip_path = str(data_dir / "benchmark_table_data.zip") | ||
wget.download(URL, pdfs_zip_path) | ||
|
||
with zipfile.ZipFile(pdfs_zip_path, 'r') as zip_ref: | ||
zip_ref.extractall(data_dir) | ||
os.remove(pdfs_zip_path) | ||
|
||
print(f"Benchmark data downloaded to {data_dir}") | ||
else: | ||
print(f"Use cached benchmark data from {data_dir}") | ||
|
||
|
||
def prediction(path_pred: Path, path_images: Path) -> dict: | ||
pred_json = make_predict_json(path_images) | ||
with open(path_pred, "w") as fd: | ||
json.dump(pred_json, fd, indent=2, ensure_ascii=False) | ||
|
||
return pred_json | ||
|
||
|
||
if __name__ == "__main__": | ||
data_dir = Path(get_config()["intermediate_data_path"]) / "benchmark_table_data" | ||
path_images = data_dir / "images" | ||
path_gt = data_dir / "gt.json" | ||
path_pred = data_dir / "pred.json" | ||
download_dataset(data_dir) | ||
|
||
with open(path_gt, "r") as fp: | ||
gt_json = json.load(fp) | ||
''' | ||
Creating base html (based on method predictions for future labeling) | ||
path_images = data_dir / "images_tmp" | ||
pred_json = prediction("gt_tmp.json", path_images) | ||
''' | ||
pred_json = prediction(path_pred, path_images) | ||
scores = call_metric(pred_json=pred_json, true_json=gt_json) | ||
|
||
result = dict() | ||
result["mean"] = np.mean([score for score in scores.values()]) | ||
result["images"] = scores | ||
|
||
# save benchmarks | ||
with open(path_result, "w") as fd: | ||
json.dump(result, fd, indent=2, ensure_ascii=False) | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
# Copyright 2020 IBM | ||
# Author: [email protected] | ||
# | ||
# This is free software; you can redistribute it and/or modify | ||
# it under the terms of the Apache 2.0 License. | ||
# | ||
# This software is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# Apache 2.0 License for more details. | ||
|
||
# Source: https://github.com/ibm-aur-nlp/PubTabNet | ||
|
||
import distance | ||
from apted import APTED, Config | ||
from apted.helpers import Tree | ||
from lxml import etree, html | ||
from collections import deque | ||
|
||
from tqdm import tqdm | ||
|
||
|
||
class TableTree(Tree): | ||
def __init__(self, tag, colspan=None, rowspan=None, content=None, visible=None, *children): | ||
self.tag = tag | ||
self.colspan = colspan | ||
self.rowspan = rowspan | ||
self.content = content | ||
self.visible = visible | ||
self.children = list(children) | ||
|
||
def bracket(self): | ||
"""Show tree using brackets notation""" | ||
if self.tag == 'td': | ||
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ | ||
(self.tag, self.colspan, self.rowspan, self.content) | ||
else: | ||
result = '"tag": %s' % self.tag | ||
for child in self.children: | ||
result += child.bracket() | ||
return "{{{}}}".format(result) | ||
|
||
|
||
class CustomConfig(Config): | ||
@staticmethod | ||
def maximum(*sequences): | ||
"""Get maximum possible value | ||
""" | ||
return max(map(len, sequences)) | ||
|
||
def normalized_distance(self, *sequences) -> float: | ||
"""Get distance from 0 to 1 | ||
""" | ||
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) | ||
|
||
def rename(self, node1: TableTree, node2: TableTree) -> float: | ||
"""Compares attributes of trees""" | ||
if not node1.visible or node2.visible: | ||
return 0. | ||
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): | ||
return 1. | ||
if node1.tag == 'td': | ||
if node1.content or node2.content: | ||
return self.normalized_distance(node1.content, node2.content) | ||
return 0. | ||
|
||
|
||
class TEDS(object): | ||
''' Tree Edit Distance basead Similarity | ||
''' | ||
def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): | ||
assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1' | ||
self.structure_only = structure_only | ||
self.n_jobs = n_jobs | ||
self.ignore_nodes = ignore_nodes | ||
self.__tokens__ = [] | ||
|
||
def tokenize(self, node): | ||
''' Tokenizes table cells | ||
''' | ||
self.__tokens__.append('<%s>' % node.tag) | ||
if node.text is not None: | ||
self.__tokens__ += list(node.text) | ||
for n in node.getchildren(): | ||
self.tokenize(n) | ||
if node.tag != 'unk': | ||
self.__tokens__.append('</%s>' % node.tag) | ||
if node.tag != 'td' and node.tail is not None: | ||
self.__tokens__ += list(node.tail) | ||
|
||
def load_html_tree(self, node, parent=None): | ||
''' Converts HTML tree to the format required by apted | ||
''' | ||
global __tokens__ | ||
if node.tag == 'td': | ||
if self.structure_only: | ||
cell = [] | ||
else: | ||
self.__tokens__ = [] | ||
self.tokenize(node) | ||
cell = self.__tokens__[1:-1].copy() | ||
|
||
try: | ||
new_node = TableTree(tag=node.tag, | ||
colspan=int(node.attrib.get('colspan', '1')), | ||
rowspan=int(node.attrib.get('rowspan', '1')), | ||
content=cell, | ||
visible=False if node.attrib.get('style') == "display: none" else True, *deque()) | ||
except Exception as ex: | ||
print(f"Bad html file. HTML parse exception. Exception's msg: {ex}") | ||
raise ex | ||
else: | ||
new_node = TableTree(node.tag, None, None, None, True, *deque()) | ||
if parent is not None: | ||
parent.children.append(new_node) | ||
if node.tag != 'td': | ||
for n in node.getchildren(): | ||
self.load_html_tree(n, new_node) | ||
if parent is None: | ||
return new_node | ||
|
||
def evaluate(self, pred: str, true: str) -> float: | ||
''' Computes TEDS score between the prediction and the ground truth of a | ||
given sample | ||
''' | ||
if (not pred) or (not true): | ||
return 0.0 | ||
parser = html.HTMLParser(remove_comments=True, encoding='utf-8') | ||
pred = html.fromstring(pred, parser=parser) | ||
true = html.fromstring(true, parser=parser) | ||
if pred.xpath('body/table') and true.xpath('body/table'): | ||
pred = pred.xpath('body/table')[0] | ||
true = true.xpath('body/table')[0] | ||
if self.ignore_nodes: | ||
etree.strip_tags(pred, *self.ignore_nodes) | ||
etree.strip_tags(true, *self.ignore_nodes) | ||
n_nodes_pred = len(pred.xpath(".//*")) | ||
n_nodes_true = len(true.xpath(".//*")) | ||
n_nodes = max(n_nodes_pred, n_nodes_true) | ||
tree_pred = self.load_html_tree(pred) | ||
tree_true = self.load_html_tree(true) | ||
|
||
distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance() | ||
return 1.0 - (float(distance) / n_nodes) | ||
else: | ||
return 0.0 | ||
|
||
def batch_evaluate(self, pred_json, true_json): | ||
''' Computes TEDS score between the prediction and the ground truth of | ||
a batch of samples | ||
@params pred_json: {'FILENAME': 'HTML CODE', ...} | ||
@params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} | ||
@output: {'FILENAME': 'TEDS SCORE', ...} | ||
''' | ||
samples = true_json.keys() | ||
scores = [self.evaluate(pred_json.get(filename, '')['html'], true_json[filename]['html']) for filename in tqdm(samples)] | ||
scores = dict(zip(samples, scores)) | ||
return scores |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# for metric TEDS: | ||
apted==1.0.3 | ||
distance==0.1.3 |