TLDR-590 fix code style in scripts directory
NastyBoget committed Feb 1, 2024
1 parent 6a3cfed commit 18463e8
Showing 21 changed files with 242 additions and 248 deletions.
1 change: 0 additions & 1 deletion .flake8
@@ -14,7 +14,6 @@ exclude =
     .github,
     *__init__.py,
     resources,
-    scripts,
     venv,
     build,
     dedoc.egg-info
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -3,7 +3,7 @@ repos:
     rev: 5.0.4
     hooks:
       - id: flake8
-        exclude: \.github|.*__init__\.py|resources|scripts|examples|docs|venv|build|dedoc\.egg-info
+        exclude: \.github|.*__init__\.py|resources|docs|venv|build|dedoc\.egg-info
         args:
           - "--config=.flake8"
         additional_dependencies: [
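With scripts gone from both exclude lists, flake8 now checks the benchmark scripts, and the diffs below keep their intentional console output by suppressing the lint per line. A minimal illustration of that pattern (the message text is hypothetical):

    # A bare "# noqa" exempts just this one line from all flake8 checks,
    # so a deliberate print survives in a now-linted script:
    print("benchmark started")  # noqa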
14 changes: 7 additions & 7 deletions scripts/benchmark.py
@@ -33,18 +33,18 @@ def get_cpu_performance() -> float:


 cpu_performance = get_cpu_performance()
-print('"cpu_performance" = {}'.format(cpu_performance))
+print(f'"cpu_performance" = {cpu_performance}')  # noqa

 with TemporaryDirectory() as path_base:
     path_out = os.path.join(path_base, "dataset.zip")
     wget.download(data_url, path_out)
-    with zipfile.ZipFile(path_out, 'r') as zip_ref:
+    with zipfile.ZipFile(path_out, "r") as zip_ref:
         zip_ref.extractall(path_base)
-    print(path_base)
+    print(path_base)  # noqa

     failed = []
     result = OrderedDict()
-    result["version"] = requests.get("{}/version".format(host)).text
+    result["version"] = requests.get(f"{host}/version").text
     result["cpu_performance"] = cpu_performance
     tasks = [
         Task("images", "images", {}),
@@ -60,7 +60,7 @@ def get_cpu_performance() -> float:
         Task("pdf", "pdf", {"pdf_with_text_layer": "false"}),
         Task("pdf_tables", "pdf_tables", {})
     ]
-    print(tasks)
+    print(tasks)  # noqa
     for directory, name, parameters in tasks:
         total_size = 0
         total_time = 0
@@ -90,5 +90,5 @@ def get_cpu_performance() -> float:

     with open(path_result, "w") as file_out:
         json.dump(obj=result, fp=file_out, indent=4, ensure_ascii=False)
-    print("save result in" + path_result)
-    print(failed)
+    print(f"save result in {path_result}")  # noqa
+    print(failed)  # noqa
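The other recurring rewrite in this file replaces str.format with f-strings; the two forms produce identical strings. A toy check (the host value is hypothetical, not taken from the script):

    host = "http://localhost:1231"
    # Both expressions yield "http://localhost:1231/version":
    assert "{}/version".format(host) == f"{host}/version"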
16 changes: 8 additions & 8 deletions scripts/benchmark_pdf_attachments.py
@@ -39,7 +39,7 @@ def get_reader_attachments(reader: BaseReader, input_dir: str, attachments_dir:
                 shutil.copy(attachment.tmp_file_path, os.path.join(file_attachments_dir, attachment_name))
                 attachment_names.append(attachment_name)

-        print(f"{file_name}: {len(attachment_names)} attachments, {len(document.attachments)} in result")
+        print(f"{file_name}: {len(attachment_names)} attachments, {len(document.attachments)} in result")  # noqa
         result_dict[file_name] = sorted(attachment_names)

     return result_dict
@@ -70,7 +70,7 @@ def get_attachments(attachments_extractor: AbstractAttachmentsExtractor, input_d
                 shutil.copy(attachment.tmp_file_path, os.path.join(file_attachments_dir, attachment_name))
                 attachment_names.append(attachment_name)

-        print(f"{file_name}: {len(attachment_names)} attachments, {len(attachments)} in result")
+        print(f"{file_name}: {len(attachment_names)} attachments, {len(attachments)} in result")  # noqa
         result_dict[file_name] = sorted(attachment_names)

     return result_dict
@@ -99,9 +99,9 @@ def _get_attachment_name(attachment: AttachedFile, png_files: int, json_files: i
         zip_ref.extractall(data_dir)
     os.remove(archive_path)

-    print(f"Benchmark data downloaded to {data_dir}")
+    print(f"Benchmark data downloaded to {data_dir}")  # noqa
 else:
-    print(f"Use cached benchmark data from {data_dir}")
+    print(f"Use cached benchmark data from {data_dir}")  # noqa

 in_dir = os.path.join(data_dir, "with_attachments")
 out_dir = os.path.join(in_dir, "extracted_attachments")
@@ -112,17 +112,17 @@ def _get_attachment_name(attachment: AttachedFile, png_files: int, json_files: i

 benchmarks_dict = {}

-print("Get tabby attachments")
+print("Get tabby attachments")  # noqa
 tabby_reader = PdfTabbyReader(config={})
 tabby_out_dir = os.path.join(out_dir, "tabby")
 benchmarks_dict["tabby"] = get_reader_attachments(reader=tabby_reader, input_dir=in_dir, attachments_dir=tabby_out_dir)

-print("Get pdfminer attachments")
+print("Get pdfminer attachments")  # noqa
 pdfminer_reader = PdfTxtlayerReader(config={})
 pdfminer_out_dir = os.path.join(out_dir, "pdfminer")
 benchmarks_dict["pdfminer"] = get_reader_attachments(reader=pdfminer_reader, input_dir=in_dir, attachments_dir=pdfminer_out_dir)

-print("Get common attachments")
+print("Get common attachments")  # noqa
 common_out_dir = os.path.join(out_dir, "common")
 pdf_attachments_extractor = PDFAttachmentsExtractor(config={})
 benchmarks_dict["common"] = get_attachments(attachments_extractor=pdf_attachments_extractor, input_dir=in_dir, attachments_dir=common_out_dir)
@@ -131,4 +131,4 @@ def _get_attachment_name(attachment: AttachedFile, png_files: int, json_files: i
 with open(os.path.join(json_out_dir, "benchmark_pdf_attachments.json"), "w") as f:
     json.dump(benchmarks_dict, f, ensure_ascii=False, indent=2)

-print(f"Attachments were extracted to {out_dir}")
+print(f"Attachments were extracted to {out_dir}")  # noqa
12 changes: 6 additions & 6 deletions scripts/benchmark_pdf_miner.py
@@ -24,16 +24,16 @@
     wget.download(URL, pdfs_zip_path)
     wget.download(URL_GT, pdfs_zip_gt_path)

-    with zipfile.ZipFile(pdfs_zip_path, 'r') as zip_ref:
+    with zipfile.ZipFile(pdfs_zip_path, "r") as zip_ref:
         zip_ref.extractall(data_dir)
     os.remove(pdfs_zip_path)
-    with zipfile.ZipFile(pdfs_zip_gt_path, 'r') as zip_ref:
+    with zipfile.ZipFile(pdfs_zip_gt_path, "r") as zip_ref:
         zip_ref.extractall(data_dir)
     os.remove(pdfs_zip_gt_path)

-    print(f"Benchmark data downloaded to {data_dir}")
+    print(f"Benchmark data downloaded to {data_dir}")  # noqa
 else:
-    print(f"Use cached benchmark data from {data_dir}")
+    print(f"Use cached benchmark data from {data_dir}")  # noqa

 pdfs_path = data_dir / "PdfMiner Params"
 pdfs_gt_path = data_dir / "PdfMiner Params GT"
@@ -53,7 +53,7 @@
         accuracy_path = Path(tmpdir) / "accuracy.txt"
         if accuracy_path.exists():
             accuracy_path.unlink()
-        command = f"{accuracy_script_path} \"{gt_path}\" {tmp_ocr_path} >> {accuracy_path}"
+        command = f'{accuracy_script_path} "{gt_path}" {tmp_ocr_path} >> {accuracy_path}'
         os.system(command)

         with open(accuracy_path, "r") as f:
@@ -68,4 +68,4 @@
 with (Path(output_dir) / "benchmark_pdf_miner.json").open("w") as f:
     json.dump(info, f, ensure_ascii=False, indent=2)

-print(f"save result in {output_dir}")
+print(f"save result in {output_dir}")  # noqa
31 changes: 14 additions & 17 deletions scripts/benchmark_table/benchmark_table.py
@@ -1,8 +1,9 @@
-import zipfile
-from pathlib import Path
 import json
 import pprint
-from typing import Optional, List
+import zipfile
+from pathlib import Path
+from typing import List, Optional
+
 import numpy as np
 import wget

@@ -47,7 +48,7 @@ def get_tables(image_path: Path) -> str:
 def make_predict_json(data_path: Path) -> dict:
     predict_json = {}
     for pathname in Path.iterdir(data_path):
-        print(pathname)
+        print(pathname)  # noqa

         predict_json[pathname.name] = {"html": "<html><body>" + get_tables(pathname) + "</body></html>"}

@@ -56,18 +57,18 @@ def make_predict_json(data_path: Path) -> dict:

 def download_dataset(data_dir: Path, name_zip: str, url: str) -> None:
     if Path.exists(data_dir):
-        print(f"Use cached benchmark data from {data_dir}")
+        print(f"Use cached benchmark data from {data_dir}")  # noqa
         return

     data_dir.mkdir(parents=True, exist_ok=True)
     pdfs_zip_path = data_dir / name_zip
     wget.download(url, str(data_dir))

-    with zipfile.ZipFile(pdfs_zip_path, 'r') as zip_ref:
+    with zipfile.ZipFile(pdfs_zip_path, "r") as zip_ref:
         zip_ref.extractall(data_dir)
     pdfs_zip_path.unlink()

-    print(f"Benchmark data downloaded to {data_dir}")
+    print(f"Benchmark data downloaded to {data_dir}")  # noqa


 def prediction(path_pred: Path, path_images: Path) -> dict:
@@ -83,19 +84,17 @@ def benchmark_on_our_data() -> dict:
     path_images = data_dir / "images"
     path_gt = data_dir / "gt.json"
     path_pred = data_dir / "pred.json"
-    download_dataset(data_dir,
-                     name_zip="benchmark_table_data.zip",
-                     url="https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download")
+    download_dataset(data_dir, name_zip="benchmark_table_data.zip", url="https://at.ispras.ru/owncloud/index.php/s/Xaf4OyHj6xN2RHH/download")

     mode_metric_structure_only = False

     with open(path_gt, "r") as fp:
         gt_json = json.load(fp)
-    '''
+    """
     Creating base html (based on method predictions for future labeling)
     path_images = data_dir / "images_tmp"
     pred_json = prediction("gt_tmp.json", path_images)
-    '''
+    """
     pred_json = prediction(path_pred, path_images)
     scores = call_metric(pred_json=pred_json, true_json=gt_json, structure_only=mode_metric_structure_only)

@@ -113,7 +112,7 @@ def benchmark_on_generated_table() -> dict:
     Article generation information https://arxiv.org/pdf/1905.13391.pdf
     Note: generate the 1st table tape category
     Note: don't use header table tag <th>, replacing on <td> tag
-    Note: all generated data (four categories) you can download from 
+    Note: all generated data (four categories) you can download from
     TODO: some tables have a low quality. Should to trace the reason.
     All generated data (all categories) we can download from https://at.ispras.ru/owncloud/index.php/s/cjpCIR7I0G4JzZU
     """
@@ -129,7 +128,7 @@ def benchmark_on_generated_table() -> dict:
     # make common ground-truth file
     common_gt_json = {}
     for pathname in Path.iterdir(path_gt):
-        image_name = pathname.name.split(".")[0] + '.png'
+        image_name = pathname.name.split(".")[0] + ".png"
         with open(pathname, "r") as fp:
             table_html = fp.read()
             # exclude header tags
@@ -146,9 +145,7 @@ def benchmark_on_generated_table() -> dict:
     path_pred = data_dir / "pred.json"

     pred_json = prediction(path_pred, path_images)
-    scores = call_metric(pred_json=pred_json, true_json=common_gt_json,
-                         structure_only=mode_metric_structure_only,
-                         ignore_nodes=['span', 'style', 'head', 'h4'])
+    scores = call_metric(pred_json=pred_json, true_json=common_gt_json, structure_only=mode_metric_structure_only, ignore_nodes=["span", "style", "head", "h4"])

     result = dict()
     result["mode_metric_structure_only"] = mode_metric_structure_only
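The docstring above notes that header cells are rewritten from <th> to <td> before comparison; the actual replacement sits in the collapsed lines of the hunk. One way to express such a rewrite, as an illustrative sketch only:

    import re

    table_html = "<table><tr><th>header</th></tr></table>"  # toy input
    # Rewrite <th ...> and </th> as <td ...> and </td>; <thead> stays untouched:
    table_html = re.sub(r"<(/?)th([\s>])", r"<\1td\2", table_html)
    assert table_html == "<table><tr><td>header</td></tr></table>"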
59 changes: 34 additions & 25 deletions scripts/benchmark_table/metric.py
@@ -11,28 +11,30 @@

 # Source: https://github.com/ibm-aur-nlp/PubTabNet

+from collections import deque
+from typing import Optional
+
 import distance
 from apted import APTED, Config
 from apted.helpers import Tree
 from lxml import etree, html
-from collections import deque
-
 from tqdm import tqdm


 class TableTree(Tree):
-    def __init__(self, tag, colspan=None, rowspan=None, content=None, visible=None, *children):
+    def __init__(self, tag: str, colspan=None, rowspan=None, content=None, visible=None, *children):  # noqa
         self.tag = tag
         self.colspan = colspan
         self.rowspan = rowspan
         self.content = content
         self.visible = visible
         self.children = list(children)

-    def bracket(self):
-        """Show tree using brackets notation
+    def bracket(self) -> str:
         """
-        if self.tag == "td" or self.tag == 'th':
+        Show tree using brackets notation
+        """
+        if self.tag == "td" or self.tag == "th":
             result = f'"tag": {self.tag}, "colspan": {self.colspan}, "rowspan": {self.rowspan}, "text": {self.content}'
         else:
             result = f'"tag": {self.tag}'
@@ -43,18 +45,22 @@ def bracket(self):

 class CustomConfig(Config):
     @staticmethod
-    def maximum(*sequences):
-        """Get maximum possible value
+    def maximum(*sequences):  # noqa
+        """
+        Get maximum possible value
         """
         return max(map(len, sequences))

-    def normalized_distance(self, *sequences) -> float:
-        """Get distance from 0 to 1
+    def normalized_distance(self, *sequences) -> float:  # noqa
+        """
+        Get distance from 0 to 1
         """
         return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

     def rename(self, node1: TableTree, node2: TableTree) -> float:
-        """Compares attributes of trees"""
+        """
+        Compares attributes of trees
+        """
         if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
             return 1.
         if node1.tag == "td":
@@ -66,18 +72,20 @@ def rename(self, node1: TableTree, node2: TableTree) -> float:


 class TEDS(object):
-    """ Tree Edit Distance based Similarity
+    """
+    Tree Edit Distance based Similarity
     """

-    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
+    def __init__(self, structure_only: bool = False, n_jobs: int = 1, ignore_nodes: Optional[list] = None) -> None:
         assert isinstance(n_jobs, int) and (n_jobs >= 1), "n_jobs must be an integer greather than 1"
         self.structure_only = structure_only
         self.n_jobs = n_jobs
         self.ignore_nodes = ignore_nodes
         self.__tokens__ = []

-    def tokenize(self, node):
-        """ Tokenizes table cells
+    def tokenize(self, node: TableTree) -> None:
+        """
+        Tokenizes table cells
         """
         self.__tokens__.append(f"<{node.tag}>")
         if node.text is not None:
@@ -89,11 +97,11 @@ def tokenize(self, node):
         if node.tag != "td" and node.tail is not None:
             self.__tokens__ += list(node.tail)

-    def get_span(self, node, name_span: str) -> int:
+    def get_span(self, node: TableTree, name_span: str) -> int:
         value = int(node.attrib.get(name_span, "1"))
         return 1 if value <= 0 else value

-    def load_html_tree(self, node, parent=None):
+    def load_html_tree(self, node: TableTree, parent: Optional[TableTree] = None) -> TableTree:
         """ Converts HTML tree to the format required by apted
         """
         if node.tag == "td":
@@ -109,9 +117,9 @@ def load_html_tree(self, node, parent=None):
                                      colspan=self.get_span(node, "colspan"),
                                      rowspan=self.get_span(node, "rowspan"),
                                      content=cell,
-                                     visible=False if node.attrib.get("style") == "display: none" else True, *deque())
+                                     visible=False if node.attrib.get("style") == "display: none" else True, *deque())  # noqa
             except Exception as ex:
-                print(f"Bad html file. HTML parse exception. Exception's msg: {ex}")
+                print(f"Bad html file. HTML parse exception. Exception's msg: {ex}")  # noqa
                 raise ex
         else:
             new_node = TableTree(node.tag, None, None, None, True, *deque())
@@ -148,12 +156,13 @@ def evaluate(self, pred: str, true: str) -> float:
         else:
             return 0.0

-    def batch_evaluate(self, pred_json, true_json):
-        """ Computes TEDS score between the prediction and the ground truth of
-        a batch of samples
-        @params pred_json: {'FILENAME': 'HTML CODE', ...}
-        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
-        @output: {'FILENAME': 'TEDS SCORE', ...}
+    def batch_evaluate(self, pred_json: dict, true_json: dict) -> dict:
+        """
+        Computes TEDS score between the prediction and the ground truth of a batch of samples
+        :param pred_json: {'FILENAME': 'HTML CODE', ...}
+        :param true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
+        :return: {'FILENAME': 'TEDS SCORE', ...}
         """
         samples = true_json.keys()
         scores = [self.evaluate(pred_json.get(filename, "")["html"], true_json[filename]["html"]) for filename in tqdm(samples)]
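To make the typed interface above concrete, a minimal usage sketch (the file name and HTML are hypothetical, and the import assumes metric.py is importable from the working directory). Note that batch_evaluate indexes the pred_json values with ["html"], so despite its docstring both JSONs use the {"FILENAME": {"html": ...}} shape:

    from metric import TEDS  # assumed import path

    html = "<html><body><table><tr><td>cell</td></tr></table></body></html>"
    pred_json = {"table_0.png": {"html": html}}
    true_json = {"table_0.png": {"html": html}}

    teds = TEDS(structure_only=False, ignore_nodes=["span", "style", "head", "h4"])
    scores = teds.batch_evaluate(pred_json, true_json)  # identical trees score 1.0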
(Diffs for the remaining 14 changed files were not loaded.)
