
Commit

TLDR-585 fixed bug, include cells' content in metric
oksidgy committed Jan 24, 2024
1 parent 26dfa77 commit bf0c7a8
Showing 3 changed files with 54 additions and 50 deletions.
23 changes: 12 additions & 11 deletions resources/benchmarks/table_benchmark.json
@@ -1,15 +1,16 @@
 {
-  "mean": 0.9824606866114314,
+  "mode_metric_structure_only": false,
+  "mean": 0.9468374367023571,
   "images": {
-    "example_with_table0_0.png": 0.9873417721518988,
-    "example_with_table0_1.png": 1.0,
-    "example_with_table6.png": 1.0,
-    "example_with_table4.jpg": 1.0,
-    "example_with_table17.jpg": 0.8536585365853658,
-    "example_with_table_hor_vert_union.png": 1.0,
-    "example_with_table1.png": 1.0,
-    "example_with_table_horizontal_union.jpg": 1.0,
-    "example_with_table3.png": 1.0,
-    "example_with_table5.png": 0.9836065573770492
+    "example_with_table0_0.png": 0.9525583036909738,
+    "example_with_table0_1.png": 0.9264351862896008,
+    "example_with_table6.png": 0.989010989010989,
+    "example_with_table4.jpg": 0.908436211832951,
+    "example_with_table17.jpg": 0.8078952936402488,
+    "example_with_table_hor_vert_union.png": 0.9896091617933723,
+    "example_with_table1.png": 0.9781560283687943,
+    "example_with_table_horizontal_union.jpg": 0.9925757575757576,
+    "example_with_table3.png": 0.9778008866078716,
+    "example_with_table5.png": 0.9458965482130129
   }
 }
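Note: the mean drops from 0.982 to 0.947 because the corrected metric now compares cell contents as well as table structure ("mode_metric_structure_only": false), not because recognition quality changed. As a sanity check, a minimal sketch (plain Python, not part of the commit) of how the reported "mean" is derived from the per-image scores above, mirroring benchmark_table.py:

    import numpy as np

    # New per-image TEDS scores from table_benchmark.json above.
    scores = {
        "example_with_table0_0.png": 0.9525583036909738,
        "example_with_table0_1.png": 0.9264351862896008,
        "example_with_table6.png": 0.989010989010989,
        "example_with_table4.jpg": 0.908436211832951,
        "example_with_table17.jpg": 0.8078952936402488,
        "example_with_table_hor_vert_union.png": 0.9896091617933723,
        "example_with_table1.png": 0.9781560283687943,
        "example_with_table_horizontal_union.jpg": 0.9925757575757576,
        "example_with_table3.png": 0.9778008866078716,
        "example_with_table5.png": 0.9458965482130129,
    }

    # "mean" is the unweighted average over all benchmark images.
    print(np.mean(list(scores.values())))  # 0.9468374367023571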
15 changes: 9 additions & 6 deletions scripts/benchmark_table/benchmark_table.py
@@ -11,17 +11,17 @@
 from dedoc.readers.pdf_reader.pdf_image_reader.table_recognizer.table_recognizer import TableRecognizer
 from scripts.benchmark_table.metric import TEDS

-path_result = Path(__file__).parent / ".." / "resources" / "benchmarks"
+path_result = Path(__file__).parent / ".." / ".." / "resources" / "benchmarks"
 path_result.absolute().mkdir(parents=True, exist_ok=True)

 URL = "https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download"

 table_recognizer = TableRecognizer(config=get_config())
 image_reader = PdfImageReader(config=get_config())
-teds = TEDS()


-def call_metric(pred_json: dict, true_json: dict) -> dict:
+def call_metric(pred_json: dict, true_json: dict, structure_only: bool = False) -> dict:
+    teds = TEDS(structure_only=structure_only)
     scores = teds.batch_evaluate(pred_json, true_json)
     pp = pprint.PrettyPrinter()
     pp.pprint(scores)
@@ -83,21 +83,24 @@ def prediction(path_pred: Path, path_images: Path) -> dict:
 path_pred = data_dir / "pred.json"
 download_dataset(data_dir)

+mode_metric_structure_only = False
+
 with open(path_gt, "r") as fp:
     gt_json = json.load(fp)
 '''
 Creating base html (based on method predictions for future labeling)
 path_images = data_dir / "images_tmp"
 pred_json = prediction("gt_tmp.json", path_images)
 '''
-pred_json = prediction(path_pred, path_images)
-scores = call_metric(pred_json=pred_json, true_json=gt_json)
+pred_json = prediction(path_pred, path_images)
+scores = call_metric(pred_json=pred_json, true_json=gt_json, structure_only=mode_metric_structure_only)

 result = dict()
+result["mode_metric_structure_only"] = mode_metric_structure_only
 result["mean"] = np.mean([score for score in scores.values()])
 result["images"] = scores

 # save benchmarks
 file_result = path_result / "table_benchmark.json"
 with file_result.open("w") as fd:
-    json.dump(str(file_result), fd, indent=2, ensure_ascii=False)
+    json.dump(result, fd, indent=2, ensure_ascii=False)
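The deleted json.dump(str(file_result), ...) line was a second bug: it wrote the output file's own path into the benchmark JSON instead of the results. Since call_metric now builds its own TEDS instance, the same predictions can also be scored in either mode; a hedged sketch (not part of the commit, reusing path_gt, path_pred, path_images, prediction and call_metric from this script):

    with open(path_gt, "r") as fp:
        gt_json = json.load(fp)
    pred_json = prediction(path_pred, path_images)

    structure_scores = call_metric(pred_json=pred_json, true_json=gt_json, structure_only=True)
    content_scores = call_metric(pred_json=pred_json, true_json=gt_json, structure_only=False)
    # With the rename() fix in metric.py, content-aware scores can only be
    # lower than or equal to the structure-only ones.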
66 changes: 33 additions & 33 deletions scripts/benchmark_table/metric.py
@@ -30,14 +30,15 @@ def __init__(self, tag, colspan=None, rowspan=None, content=None, visible=None,
         self.children = list(children)

     def bracket(self):
-        """Show tree using brackets notation"""
-        if self.tag == 'td':
+        """Show tree using brackets notation
+        """
+        if self.tag == "td":
             result = f'"tag": {self.tag}, "colspan": {self.colspan}, "rowspan": {self.rowspan}, "text": {self.content}'
         else:
             result = f'"tag": {self.tag}'
         for child in self.children:
             result += child.bracket()
-        return "{{{}}}".format(result)
+        return "{{" + result + "}}"


 class CustomConfig(Config):
@@ -54,44 +54,44 @@ def normalized_distance(self, *sequences) -> float:

     def rename(self, node1: TableTree, node2: TableTree) -> float:
         """Compares attributes of trees"""
-        if not node1.visible or node2.visible:
-            return 0.
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
             return 1.
-        if node1.tag == 'td':
+        if node1.tag == "td":
+            if not node1.visible or not node2.visible:
+                return 0.
             if node1.content or node2.content:
                 return self.normalized_distance(node1.content, node2.content)
         return 0.


 class TEDS(object):
-    ''' Tree Edit Distance basead Similarity
-    '''
+    """ Tree Edit Distance based Similarity
+    """

     def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
-        assert isinstance(n_jobs, int) and (n_jobs >= 1), 'n_jobs must be an integer greather than 1'
+        assert isinstance(n_jobs, int) and (n_jobs >= 1), "n_jobs must be an integer greather than 1"
         self.structure_only = structure_only
         self.n_jobs = n_jobs
         self.ignore_nodes = ignore_nodes
         self.__tokens__ = []

     def tokenize(self, node):
-        ''' Tokenizes table cells
-        '''
-        self.__tokens__.append('<%s>' % node.tag)
+        """ Tokenizes table cells
+        """
+        self.__tokens__.append(f"<{node.tag}>")
         if node.text is not None:
             self.__tokens__ += list(node.text)
         for n in node.getchildren():
             self.tokenize(n)
-        if node.tag != 'unk':
-            self.__tokens__.append('</%s>' % node.tag)
-        if node.tag != 'td' and node.tail is not None:
+        if node.tag != "unk":
+            self.__tokens__.append(f"</{node.tag}>")
+        if node.tag != "td" and node.tail is not None:
             self.__tokens__ += list(node.tail)

     def load_html_tree(self, node, parent=None):
-        ''' Converts HTML tree to the format required by apted
-        '''
-        global __tokens__
-        if node.tag == 'td':
+        """ Converts HTML tree to the format required by apted
+        """
+        if node.tag == "td":
             if self.structure_only:
                 cell = []
             else:
@@ -101,35 +102,34 @@ def load_html_tree(self, node, parent=None):

             try:
                 new_node = TableTree(tag=node.tag,
-                                     colspan=int(node.attrib.get('colspan', '1')),
-                                     rowspan=int(node.attrib.get('rowspan', '1')),
+                                     colspan=int(node.attrib.get("colspan", "1")),
+                                     rowspan=int(node.attrib.get("rowspan", "1")),
                                      content=cell,
-                                     visible=False if node.attrib.get('style') == "display: none" else True, *deque())
+                                     visible=False if node.attrib.get("style") == "display: none" else True, *deque())
             except Exception as ex:
                 print(f"Bad html file. HTML parse exception. Exception's msg: {ex}")
                 raise ex
         else:
             new_node = TableTree(node.tag, None, None, None, True, *deque())
         if parent is not None:
             parent.children.append(new_node)
-        if node.tag != 'td':
+        if node.tag != "td":
             for n in node.getchildren():
                 self.load_html_tree(n, new_node)
         if parent is None:
             return new_node

     def evaluate(self, pred: str, true: str) -> float:
-        ''' Computes TEDS score between the prediction and the ground truth of a
-        given sample
-        '''
+        """ Computes TEDS score between the prediction and the ground truth of a given sample
+        """
         if (not pred) or (not true):
             return 0.0
-        parser = html.HTMLParser(remove_comments=True, encoding='utf-8')
+        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
         pred = html.fromstring(pred, parser=parser)
         true = html.fromstring(true, parser=parser)
-        if pred.xpath('body/table') and true.xpath('body/table'):
-            pred = pred.xpath('body/table')[0]
-            true = true.xpath('body/table')[0]
+        if pred.xpath("body/table") and true.xpath("body/table"):
+            pred = pred.xpath("body/table")[0]
+            true = true.xpath("body/table")[0]
         if self.ignore_nodes:
             etree.strip_tags(pred, *self.ignore_nodes)
             etree.strip_tags(true, *self.ignore_nodes)
@@ -145,13 +145,13 @@ def evaluate(self, pred: str, true: str) -> float:
             return 0.0

     def batch_evaluate(self, pred_json, true_json):
-        ''' Computes TEDS score between the prediction and the ground truth of
+        """ Computes TEDS score between the prediction and the ground truth of
         a batch of samples
         @params pred_json: {'FILENAME': 'HTML CODE', ...}
         @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
         @output: {'FILENAME': 'TEDS SCORE', ...}
-        '''
+        """
         samples = true_json.keys()
-        scores = [self.evaluate(pred_json.get(filename, '')['html'], true_json[filename]['html']) for filename in tqdm(samples)]
+        scores = [self.evaluate(pred_json.get(filename, "")["html"], true_json[filename]["html"]) for filename in tqdm(samples)]
         scores = dict(zip(samples, scores))
         return scores
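The rename() change above is the bug the commit title refers to: the old guard `if not node1.visible or node2.visible: return 0.` returned zero cost for any pair where node2 was visible, so differing cell texts (and even tag or span mismatches) were never penalized. A minimal sketch of the effect, not from the commit, using the HTML format evaluate() expects:

    from scripts.benchmark_table.metric import TEDS

    # Identical structure, different cell text.
    pred = "<html><body><table><tr><td>cat</td></tr></table></body></html>"
    true = "<html><body><table><tr><td>dog</td></tr></table></body></html>"

    print(TEDS(structure_only=True).evaluate(pred, true))   # 1.0: structure matches
    print(TEDS(structure_only=False).evaluate(pred, true))  # < 1.0: cell text now counts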
