Skip to content

Commit

Permalink
Update TableTransformer Code with PostProcessing Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
iamayushanand committed May 5, 2024
1 parent e693b91 commit 191363b
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 27 deletions.
4 changes: 2 additions & 2 deletions detr/util/box_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def generalized_box_iou(boxes1, boxes2):
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
#assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
#assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)

lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
Expand Down
2 changes: 1 addition & 1 deletion scripts/view_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def main():
print(filename)
try:
xml_filepath = os.path.join(data_directory, split, filename)
img_filepath = xml_filepath.replace(split, "images").replace(".xml", ".jpg")
img_filepath = xml_filepath.replace(split, "images").replace(".xml", ".png")

bboxes, labels = read_pascal_voc(xml_filepath)
img = Image.open(img_filepath)
Expand Down
54 changes: 54 additions & 0 deletions src/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from inference import TableExtractionPipeline
from PIL import Image, ImageDraw
from fastapi import FastAPI, UploadFile, Body, File, Depends, Form
from pydantic import BaseModel, model_validator
from typing import List
from pdf2image import convert_from_bytes
import numpy as np
from io import BytesIO
import json
import logging
logging.basicConfig(filename='../../logs/fastapi_app.log', level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s')


app = FastAPI()
logging.info("Initialised FastAPI interface.")

pipe = TableExtractionPipeline(str_config_path='structure_config.json', str_model_path='../../model.pth', str_device='cuda')
logging.info("TSR pipeline initialised.")


@app.get("/")
async def return_cells(message: str = Form(...), files: List[UploadFile]=File(...)):
file_dict = {}
message = json.loads(message)
for f in files:
file_dict[f.filename] = BytesIO(await f.read())

logging.info("GET request recieved at the endpoint")
response = []
for bbox_result in message["bbox_result"]:
pdf_path = bbox_result["pdf_file"]

images = convert_from_bytes(file_dict[pdf_path].read(), dpi=300)
img = images[bbox_result["page_num"]-1].convert('RGB')
logging.info("PDF file loaded")
print(img.size)
for i, bbox in enumerate(bbox_result["bbox"]):
img_cropped = img.crop((bbox[0], bbox[1], bbox[2], bbox[3]))
draw = ImageDraw.Draw(img_cropped)
extracted_table = pipe.recognize(img_cropped, f"table_{bbox_result['page_num']}_{i}", [], out_cells=True)
for cell in extracted_table["cells"][0]:
bbx = cell["bbox"]
print(cell)
draw.rectangle(bbx, outline="red", width=10)

#extracted_table = pipe.recognize(img_cropped, f"table_{bbox_result['page_num']}_{i}", [], out_cells=True)
#for cell in extracted_table["cells"][0]:
# bbx = cell["bbox"]
# draw.rectangle(bbx, outline="blue", width=10)
img_cropped.save("cropped.png")
response.append(extracted_table["cells"])
logging.info(f"table_{bbox_result['page_num']}_{i} extracted successfully")
return response
178 changes: 169 additions & 9 deletions src/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import matplotlib.patches as patches
from fitz import Rect
from PIL import Image
import traceback
import matplotlib.patches as patches
from matplotlib.patches import Patch

sys.path.append("../detr")
import util.misc as utils
Expand Down Expand Up @@ -278,9 +281,11 @@ def compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells,

# Compute grids/matrices for comparison
true_relspan_grid = np.array(grits.cells_to_relspan_grid(true_cells))
print("true:",true_relspan_grid.shape)
true_bbox_grid = np.array(grits.cells_to_grid(true_cells, key='bbox'))
true_text_grid = np.array(grits.cells_to_grid(true_cells, key='cell_text'), dtype=object)
pred_relspan_grid = np.array(grits.cells_to_relspan_grid(pred_cells))
print("pred:",pred_relspan_grid.shape)
pred_bbox_grid = np.array(grits.cells_to_grid(pred_cells, key='bbox'))
pred_text_grid = np.array(grits.cells_to_grid(pred_cells, key='cell_text'), dtype=object)

Expand Down Expand Up @@ -462,29 +467,170 @@ def eval_tsr_sample(target, pred_logits, pred_bboxes, mode):
img_words_filepath = target["img_words_path"]
with open(img_words_filepath, 'r') as f:
true_page_tokens = json.load(f)

true_table_structures, true_cells, _ = objects_to_cells(true_bboxes, true_labels, true_scores,
true_page_tokens, structure_class_names,
structure_class_thresholds, structure_class_map)

#print(true_table_structures)
#print(torch.max(pred_logits.softmax(-1), -1))
m = pred_logits.softmax(-1).max(-1)
#print("m")
pred_labels = list(m.indices.detach().cpu().numpy())
#print(pred_labels)
pred_scores = list(m.values.detach().cpu().numpy())
#print(pred_scores)
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, true_img_size)]
#print(pred_bboxes)
_, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores,
true_page_tokens, structure_class_names,
structure_class_thresholds, structure_class_map)

#print(pred_cells)
metrics = compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells,
pred_bboxes, pred_labels, pred_scores, pred_cells)
statistics = compute_statistics(true_table_structures, true_cells)

metrics.update(statistics)
metrics['id'] = target["img_path"].split('/')[-1].split('.')[0]

print(metrics)
return metrics


def visualize_better(labels, bboxes, target, debug_dir):
try:
img_filepath = target["img_path"]
img_filename = img_filepath.split("/")[-1]

bboxes_out_filename = img_filename.replace(".jpg", "_bboxes.jpg")
save_filepath = os.path.join(debug_dir, bboxes_out_filename)

img = Image.open(img_filepath)

ax = plt.gca()
ax.imshow(img, interpolation="lanczos")
plt.gcf().set_size_inches((24, 24))

tables = [bbox for bbox, label in zip(bboxes, labels) if label == 'table']
columns = [bbox for bbox, label in zip(bboxes, labels) if label == 'table column']
rows = [bbox for bbox, label in zip(bboxes, labels) if label == 'table row']
column_headers = [bbox for bbox, label in zip(bboxes, labels) if label == 'table column header']
projected_row_headers = [bbox for bbox, label in zip(bboxes, labels) if label == 'table projected row header']
spanning_cells = [bbox for bbox, label in zip(bboxes, labels) if label == 'table spanning cell']

for column_num, bbox in enumerate(columns):
if column_num % 2 == 0:
linewidth = 2
alpha = 0.6
facecolor = 'none'
edgecolor = 'red'
hatch = '..'
else:
linewidth = 2
alpha = 0.15
facecolor = (1, 0, 0)
edgecolor = (0.8, 0, 0)
hatch = ''
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
edgecolor=edgecolor, facecolor=facecolor, linestyle="-",
hatch=hatch, alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor='red', facecolor='none', linestyle="-",
alpha=0.8)
ax.add_patch(rect)

for row_num, bbox in enumerate(rows):
if row_num % 2 == 1:
linewidth = 2
alpha = 0.6
edgecolor = 'blue'
facecolor = 'none'
hatch = '....'
else:
linewidth = 2
alpha = 0.1
facecolor = (0, 0, 1)
edgecolor = (0, 0, 0.8)
hatch = ''
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
edgecolor=edgecolor, facecolor=facecolor, linestyle="-",
hatch=hatch, alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor='blue', facecolor='none', linestyle="-",
alpha=0.8)
ax.add_patch(rect)

for bbox in column_headers:
linewidth = 3
alpha = 0.3
facecolor = (1, 0, 0.75) #(0.5, 0.45, 0.25)
edgecolor = (1, 0, 0.75) #(1, 0.9, 0.5)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor='none',facecolor=facecolor, alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
edgecolor=edgecolor,facecolor='none',linestyle="-", hatch='///')
ax.add_patch(rect)

for bbox in projected_row_headers:
facecolor = (1, 0.9, 0.5) #(0, 0.75, 1) #(0, 0.4, 0.4)
edgecolor = (1, 0.9, 0.5) #(0, 0.7, 0.95)
alpha = 0.35
linewidth = 3
linestyle="--"
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor='none',facecolor=facecolor, alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
edgecolor=edgecolor,facecolor='none',linestyle=linestyle)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0,
edgecolor=edgecolor,facecolor='none',linestyle=linestyle, hatch='\\\\')
ax.add_patch(rect)

for bbox in spanning_cells:
color = (0.2, 0.5, 0.2) #(0, 0.4, 0.4)
alpha = 0.4
linewidth = 4
linestyle="-"
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor='none',facecolor=color, alpha=alpha)
ax.add_patch(rect)
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth,
edgecolor=color,facecolor='none',linestyle=linestyle) # hatch='//'
ax.add_patch(rect)

table_bbox = tables[0]
plt.xlim([table_bbox[0]-5, table_bbox[2]+5])
plt.ylim([table_bbox[3]+5, table_bbox[1]-5])
plt.xticks([], [])
plt.yticks([], [])

legend_elements = [Patch(facecolor=(0.9, 0.9, 1), edgecolor=(0, 0, 0.8),
label='Row (odd)'),
Patch(facecolor=(1, 1, 1), edgecolor=(0, 0, 0.8),
label='Row (even)', hatch='...'),
Patch(facecolor=(1, 1, 1), edgecolor=(0.8, 0, 0),
label='Column (odd)', hatch='...'),
Patch(facecolor=(1, 0.85, 0.85), edgecolor=(0.8, 0, 0),
label='Column (even)'),
Patch(facecolor=(0.68, 0.8, 0.68), edgecolor=(0.2, 0.5, 0.2),
label='Spanning cell'),
Patch(facecolor=(1, 0.7, 0.925), edgecolor=(1, 0, 0.75),
label='Column header', hatch='///'),
Patch(facecolor=(1, 0.965, 0.825), edgecolor=(1, 0.9, 0.5),
label='Projected row header', hatch='\\\\')]
ax.legend(handles=legend_elements, bbox_to_anchor=(0, -0.02), loc='upper left', borderaxespad=0,
fontsize=16, ncol=4)
plt.gcf().set_size_inches(20, 20)
plt.axis('off')

plt.savefig(save_filepath, bbox_inches='tight', dpi=150)
plt.show()
plt.close()
except:
traceback.print_exc()

def visualize(args, target, pred_logits, pred_bboxes):
img_filepath = target["img_path"]
img_filename = img_filepath.split("/")[-1]
Expand All @@ -501,6 +647,9 @@ def visualize(args, target, pred_logits, pred_bboxes):
pred_bboxes = pred_bboxes.detach().cpu()
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

pred_label_names = [structure_class_names[x] for x in pred_labels]
#visualize_better(pred_label_names, pred_bboxes, target, args.debug_save_dir)
#return
fig,ax = plt.subplots(1)
ax.imshow(img, interpolation='lanczos')

Expand Down Expand Up @@ -550,6 +699,9 @@ def visualize(args, target, pred_logits, pred_bboxes):

for cell in pred_cells:
bbox = cell['bbox']
column_num = cell['column_nums'][0]
row_num = cell['row_nums'][0]

if cell['header']:
alpha = 0.3
else:
Expand All @@ -564,7 +716,7 @@ def visualize(args, target, pred_logits, pred_bboxes):
rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1,
edgecolor="magenta",facecolor='none',linestyle="--")
ax.add_patch(rect)

plt.text(bbox[0], bbox[1], f"{row_num},{column_num}")
fig.set_size_inches((15, 15))
plt.axis('off')
plt.savefig(cells_out_filepath, bbox_inches='tight', dpi=100)
Expand Down Expand Up @@ -611,7 +763,7 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic

loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict

#print(loss_dict)
# reduce losses over all GPUs for logging purposes
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_scaled = {k: v * weight_dict[k]
Expand All @@ -633,6 +785,7 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic
pred_logits_collection += list(outputs['pred_logits'].detach().cpu())
pred_bboxes_collection += list(outputs['pred_boxes'].detach().cpu())

#print(pred_bboxes_collection)
for target in targets:
for k, v in target.items():
if not k == 'img_path':
Expand All @@ -642,12 +795,19 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic
img_words_filepath = os.path.join(args.table_words_dir, img_filename.replace(".jpg", "_words.json"))
target["img_words_path"] = img_words_filepath
targets_collection += targets

# print(targets_collection)
if batch_num % args.eval_step == 0 or batch_num == num_batches:
arguments = zip(targets_collection, pred_logits_collection, pred_bboxes_collection,
repeat(args.mode))
with multiprocessing.Pool(args.eval_pool_size) as pool:
metrics = pool.starmap_async(eval_tsr_sample, arguments).get()
#print(list(arguments)[0])
metrics = map(eval_tsr_sample,targets_collection,pred_logits_collection,pred_bboxes_collection,repeat(args.mode))
#print(metrics)
#print(list(arguments))
#metrics = map(eval_tsr_sample, list(arguments)[0], list(arguments)[1], list(arguments)[2], list(arguments)[3])
#print("metrics:",metrics)
#with multiprocessing.Pool(args.eval_pool_size) as pool:
# metrics = pool.starmap(eval_tsr_sample, arguments).get()
# print(metrics)
tsr_metrics += metrics
pred_logits_collection = []
pred_bboxes_collection = []
Expand Down
2 changes: 2 additions & 0 deletions src/grits.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ def grits_top(true_relative_span_grid, pred_relative_span_grid):
relative to the current grid cell location, in grid coordinate units.
Note that for a non-spanning cell this will always be [0, 0, 1, 1].
"""
#print(true_relative_span_grid)
#print(pred_relative_span_grid)
return factored_2dmss(true_relative_span_grid,
pred_relative_span_grid,
reward_function=iou)
Expand Down
Loading

0 comments on commit 191363b

Please sign in to comment.