From 191363bde81e652de719c2f1505cc4fef32bc724 Mon Sep 17 00:00:00 2001 From: Ayush Anand Date: Sun, 5 May 2024 07:00:08 +0530 Subject: [PATCH] Update TableTransformer Code with PostProcessing Fix --- detr/util/box_ops.py | 4 +- scripts/view_annotations.py | 2 +- src/api.py | 54 +++++++++++ src/eval.py | 178 ++++++++++++++++++++++++++++++++++-- src/grits.py | 2 + src/inference.py | 80 ++++++++++++++-- src/main.py | 3 +- src/postprocess.py | 20 +++- src/table_datasets.py | 15 ++- 9 files changed, 331 insertions(+), 27 deletions(-) create mode 100644 src/api.py diff --git a/detr/util/box_ops.py b/detr/util/box_ops.py index 9c088e5..18c4614 100644 --- a/detr/util/box_ops.py +++ b/detr/util/box_ops.py @@ -48,8 +48,8 @@ def generalized_box_iou(boxes1, boxes2): """ # degenerate boxes gives inf / nan results # so do an early check - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + #assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + #assert (boxes2[:, 2:] >= boxes2[:, :2]).all() iou, union = box_iou(boxes1, boxes2) lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) diff --git a/scripts/view_annotations.py b/scripts/view_annotations.py index d9249f0..913e59c 100644 --- a/scripts/view_annotations.py +++ b/scripts/view_annotations.py @@ -96,7 +96,7 @@ def main(): print(filename) try: xml_filepath = os.path.join(data_directory, split, filename) - img_filepath = xml_filepath.replace(split, "images").replace(".xml", ".jpg") + img_filepath = xml_filepath.replace(split, "images").replace(".xml", ".png") bboxes, labels = read_pascal_voc(xml_filepath) img = Image.open(img_filepath) diff --git a/src/api.py b/src/api.py new file mode 100644 index 0000000..1bf8f96 --- /dev/null +++ b/src/api.py @@ -0,0 +1,54 @@ +from inference import TableExtractionPipeline +from PIL import Image, ImageDraw +from fastapi import FastAPI, UploadFile, Body, File, Depends, Form +from pydantic import BaseModel, model_validator +from typing import List +from pdf2image import convert_from_bytes +import numpy as np +from io import BytesIO +import json +import logging +logging.basicConfig(filename='../../logs/fastapi_app.log', level=logging.DEBUG, + format='%(asctime)s %(levelname)s %(name)s %(threadName)s : %(message)s') + + +app = FastAPI() +logging.info("Initialised FastAPI interface.") + +pipe = TableExtractionPipeline(str_config_path='structure_config.json', str_model_path='../../model.pth', str_device='cuda') +logging.info("TSR pipeline initialised.") + + +@app.get("/") +async def return_cells(message: str = Form(...), files: List[UploadFile]=File(...)): + file_dict = {} + message = json.loads(message) + for f in files: + file_dict[f.filename] = BytesIO(await f.read()) + + logging.info("GET request recieved at the endpoint") + response = [] + for bbox_result in message["bbox_result"]: + pdf_path = bbox_result["pdf_file"] + + images = convert_from_bytes(file_dict[pdf_path].read(), dpi=300) + img = images[bbox_result["page_num"]-1].convert('RGB') + logging.info("PDF file loaded") + print(img.size) + for i, bbox in enumerate(bbox_result["bbox"]): + img_cropped = img.crop((bbox[0], bbox[1], bbox[2], bbox[3])) + draw = ImageDraw.Draw(img_cropped) + extracted_table = pipe.recognize(img_cropped, f"table_{bbox_result['page_num']}_{i}", [], out_cells=True) + for cell in extracted_table["cells"][0]: + bbx = cell["bbox"] + print(cell) + draw.rectangle(bbx, outline="red", width=10) + + #extracted_table = pipe.recognize(img_cropped, f"table_{bbox_result['page_num']}_{i}", [], out_cells=True) + #for cell in extracted_table["cells"][0]: + # bbx = cell["bbox"] + # draw.rectangle(bbx, outline="blue", width=10) + img_cropped.save("cropped.png") + response.append(extracted_table["cells"]) + logging.info(f"table_{bbox_result['page_num']}_{i} extracted successfully") + return response diff --git a/src/eval.py b/src/eval.py index e3a0565..6df4d7e 100644 --- a/src/eval.py +++ b/src/eval.py @@ -20,6 +20,9 @@ import matplotlib.patches as patches from fitz import Rect from PIL import Image +import traceback +import matplotlib.patches as patches +from matplotlib.patches import Patch sys.path.append("../detr") import util.misc as utils @@ -278,9 +281,11 @@ def compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells, # Compute grids/matrices for comparison true_relspan_grid = np.array(grits.cells_to_relspan_grid(true_cells)) + print("true:",true_relspan_grid.shape) true_bbox_grid = np.array(grits.cells_to_grid(true_cells, key='bbox')) true_text_grid = np.array(grits.cells_to_grid(true_cells, key='cell_text'), dtype=object) pred_relspan_grid = np.array(grits.cells_to_relspan_grid(pred_cells)) + print("pred:",pred_relspan_grid.shape) pred_bbox_grid = np.array(grits.cells_to_grid(pred_cells, key='bbox')) pred_text_grid = np.array(grits.cells_to_grid(pred_cells, key='cell_text'), dtype=object) @@ -462,29 +467,170 @@ def eval_tsr_sample(target, pred_logits, pred_bboxes, mode): img_words_filepath = target["img_words_path"] with open(img_words_filepath, 'r') as f: true_page_tokens = json.load(f) - + true_table_structures, true_cells, _ = objects_to_cells(true_bboxes, true_labels, true_scores, true_page_tokens, structure_class_names, structure_class_thresholds, structure_class_map) - + #print(true_table_structures) + #print(torch.max(pred_logits.softmax(-1), -1)) m = pred_logits.softmax(-1).max(-1) + #print("m") pred_labels = list(m.indices.detach().cpu().numpy()) + #print(pred_labels) pred_scores = list(m.values.detach().cpu().numpy()) + #print(pred_scores) pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, true_img_size)] + #print(pred_bboxes) _, pred_cells, _ = objects_to_cells(pred_bboxes, pred_labels, pred_scores, true_page_tokens, structure_class_names, structure_class_thresholds, structure_class_map) - + #print(pred_cells) metrics = compute_metrics(mode, true_bboxes, true_labels, true_scores, true_cells, pred_bboxes, pred_labels, pred_scores, pred_cells) statistics = compute_statistics(true_table_structures, true_cells) metrics.update(statistics) metrics['id'] = target["img_path"].split('/')[-1].split('.')[0] - + print(metrics) return metrics +def visualize_better(labels, bboxes, target, debug_dir): + try: + img_filepath = target["img_path"] + img_filename = img_filepath.split("/")[-1] + + bboxes_out_filename = img_filename.replace(".jpg", "_bboxes.jpg") + save_filepath = os.path.join(debug_dir, bboxes_out_filename) + + img = Image.open(img_filepath) + + ax = plt.gca() + ax.imshow(img, interpolation="lanczos") + plt.gcf().set_size_inches((24, 24)) + + tables = [bbox for bbox, label in zip(bboxes, labels) if label == 'table'] + columns = [bbox for bbox, label in zip(bboxes, labels) if label == 'table column'] + rows = [bbox for bbox, label in zip(bboxes, labels) if label == 'table row'] + column_headers = [bbox for bbox, label in zip(bboxes, labels) if label == 'table column header'] + projected_row_headers = [bbox for bbox, label in zip(bboxes, labels) if label == 'table projected row header'] + spanning_cells = [bbox for bbox, label in zip(bboxes, labels) if label == 'table spanning cell'] + + for column_num, bbox in enumerate(columns): + if column_num % 2 == 0: + linewidth = 2 + alpha = 0.6 + facecolor = 'none' + edgecolor = 'red' + hatch = '..' + else: + linewidth = 2 + alpha = 0.15 + facecolor = (1, 0, 0) + edgecolor = (0.8, 0, 0) + hatch = '' + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, + edgecolor=edgecolor, facecolor=facecolor, linestyle="-", + hatch=hatch, alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='red', facecolor='none', linestyle="-", + alpha=0.8) + ax.add_patch(rect) + + for row_num, bbox in enumerate(rows): + if row_num % 2 == 1: + linewidth = 2 + alpha = 0.6 + edgecolor = 'blue' + facecolor = 'none' + hatch = '....' + else: + linewidth = 2 + alpha = 0.1 + facecolor = (0, 0, 1) + edgecolor = (0, 0, 0.8) + hatch = '' + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, + edgecolor=edgecolor, facecolor=facecolor, linestyle="-", + hatch=hatch, alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='blue', facecolor='none', linestyle="-", + alpha=0.8) + ax.add_patch(rect) + + for bbox in column_headers: + linewidth = 3 + alpha = 0.3 + facecolor = (1, 0, 0.75) #(0.5, 0.45, 0.25) + edgecolor = (1, 0, 0.75) #(1, 0.9, 0.5) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='none',facecolor=facecolor, alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, + edgecolor=edgecolor,facecolor='none',linestyle="-", hatch='///') + ax.add_patch(rect) + + for bbox in projected_row_headers: + facecolor = (1, 0.9, 0.5) #(0, 0.75, 1) #(0, 0.4, 0.4) + edgecolor = (1, 0.9, 0.5) #(0, 0.7, 0.95) + alpha = 0.35 + linewidth = 3 + linestyle="--" + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='none',facecolor=facecolor, alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1, + edgecolor=edgecolor,facecolor='none',linestyle=linestyle) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, + edgecolor=edgecolor,facecolor='none',linestyle=linestyle, hatch='\\\\') + ax.add_patch(rect) + + for bbox in spanning_cells: + color = (0.2, 0.5, 0.2) #(0, 0.4, 0.4) + alpha = 0.4 + linewidth = 4 + linestyle="-" + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor='none',facecolor=color, alpha=alpha) + ax.add_patch(rect) + rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=linewidth, + edgecolor=color,facecolor='none',linestyle=linestyle) # hatch='//' + ax.add_patch(rect) + + table_bbox = tables[0] + plt.xlim([table_bbox[0]-5, table_bbox[2]+5]) + plt.ylim([table_bbox[3]+5, table_bbox[1]-5]) + plt.xticks([], []) + plt.yticks([], []) + + legend_elements = [Patch(facecolor=(0.9, 0.9, 1), edgecolor=(0, 0, 0.8), + label='Row (odd)'), + Patch(facecolor=(1, 1, 1), edgecolor=(0, 0, 0.8), + label='Row (even)', hatch='...'), + Patch(facecolor=(1, 1, 1), edgecolor=(0.8, 0, 0), + label='Column (odd)', hatch='...'), + Patch(facecolor=(1, 0.85, 0.85), edgecolor=(0.8, 0, 0), + label='Column (even)'), + Patch(facecolor=(0.68, 0.8, 0.68), edgecolor=(0.2, 0.5, 0.2), + label='Spanning cell'), + Patch(facecolor=(1, 0.7, 0.925), edgecolor=(1, 0, 0.75), + label='Column header', hatch='///'), + Patch(facecolor=(1, 0.965, 0.825), edgecolor=(1, 0.9, 0.5), + label='Projected row header', hatch='\\\\')] + ax.legend(handles=legend_elements, bbox_to_anchor=(0, -0.02), loc='upper left', borderaxespad=0, + fontsize=16, ncol=4) + plt.gcf().set_size_inches(20, 20) + plt.axis('off') + + plt.savefig(save_filepath, bbox_inches='tight', dpi=150) + plt.show() + plt.close() + except: + traceback.print_exc() + def visualize(args, target, pred_logits, pred_bboxes): img_filepath = target["img_path"] img_filename = img_filepath.split("/")[-1] @@ -501,6 +647,9 @@ def visualize(args, target, pred_logits, pred_bboxes): pred_bboxes = pred_bboxes.detach().cpu() pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] + pred_label_names = [structure_class_names[x] for x in pred_labels] + #visualize_better(pred_label_names, pred_bboxes, target, args.debug_save_dir) + #return fig,ax = plt.subplots(1) ax.imshow(img, interpolation='lanczos') @@ -550,6 +699,9 @@ def visualize(args, target, pred_logits, pred_bboxes): for cell in pred_cells: bbox = cell['bbox'] + column_num = cell['column_nums'][0] + row_num = cell['row_nums'][0] + if cell['header']: alpha = 0.3 else: @@ -564,7 +716,7 @@ def visualize(args, target, pred_logits, pred_bboxes): rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=1, edgecolor="magenta",facecolor='none',linestyle="--") ax.add_patch(rect) - + plt.text(bbox[0], bbox[1], f"{row_num},{column_num}") fig.set_size_inches((15, 15)) plt.axis('off') plt.savefig(cells_out_filepath, bbox_inches='tight', dpi=100) @@ -611,7 +763,7 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict - + #print(loss_dict) # reduce losses over all GPUs for logging purposes loss_dict_reduced = utils.reduce_dict(loss_dict) loss_dict_reduced_scaled = {k: v * weight_dict[k] @@ -633,6 +785,7 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic pred_logits_collection += list(outputs['pred_logits'].detach().cpu()) pred_bboxes_collection += list(outputs['pred_boxes'].detach().cpu()) + #print(pred_bboxes_collection) for target in targets: for k, v in target.items(): if not k == 'img_path': @@ -642,12 +795,19 @@ def evaluate(args, model, criterion, postprocessors, data_loader, base_ds, devic img_words_filepath = os.path.join(args.table_words_dir, img_filename.replace(".jpg", "_words.json")) target["img_words_path"] = img_words_filepath targets_collection += targets - + # print(targets_collection) if batch_num % args.eval_step == 0 or batch_num == num_batches: arguments = zip(targets_collection, pred_logits_collection, pred_bboxes_collection, repeat(args.mode)) - with multiprocessing.Pool(args.eval_pool_size) as pool: - metrics = pool.starmap_async(eval_tsr_sample, arguments).get() + #print(list(arguments)[0]) + metrics = map(eval_tsr_sample,targets_collection,pred_logits_collection,pred_bboxes_collection,repeat(args.mode)) + #print(metrics) + #print(list(arguments)) + #metrics = map(eval_tsr_sample, list(arguments)[0], list(arguments)[1], list(arguments)[2], list(arguments)[3]) + #print("metrics:",metrics) + #with multiprocessing.Pool(args.eval_pool_size) as pool: + # metrics = pool.starmap(eval_tsr_sample, arguments).get() + # print(metrics) tsr_metrics += metrics pred_logits_collection = [] pred_bboxes_collection = [] diff --git a/src/grits.py b/src/grits.py index c9d4828..1f5720d 100644 --- a/src/grits.py +++ b/src/grits.py @@ -383,6 +383,8 @@ def grits_top(true_relative_span_grid, pred_relative_span_grid): relative to the current grid cell location, in grid coordinate units. Note that for a non-spanning cell this will always be [0, 0, 1, 1]. """ + #print(true_relative_span_grid) + #print(pred_relative_span_grid) return factored_2dmss(true_relative_span_grid, pred_relative_span_grid, reward_function=iou) diff --git a/src/inference.py b/src/inference.py index 564dd76..0fdddd5 100644 --- a/src/inference.py +++ b/src/inference.py @@ -325,10 +325,23 @@ def objects_to_structures(objects, tokens, class_thresholds): if iob(obj['bbox'], header_obj['bbox']) >= 0.5: obj['column header'] = True + row_rect = Rect() + for obj in rows: + row_rect.include_rect(obj['bbox']) + column_rect = Rect() + for obj in columns: + column_rect.include_rect(obj['bbox']) + table['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]] + table['bbox'] = table['row_column_bbox'] + + # Process the rows and columns into a complete segmented table + columns = postprocess.align_columns(columns, table['row_column_bbox']) + rows = postprocess.align_rows(rows, table['row_column_bbox']) # Refine table structures rows = postprocess.refine_rows(rows, table_tokens, class_thresholds['table row']) columns = postprocess.refine_columns(columns, table_tokens, class_thresholds['table column']) + # Shrink table bbox to just the total height of the rows # and the total width of the columns row_rect = Rect() @@ -352,10 +365,53 @@ def objects_to_structures(objects, tokens, class_thresholds): if len(rows) > 0 and len(columns) > 1: structure = refine_table_structure(structure, class_thresholds) + structure['columns'] = fill_column_gaps(structure['columns']) + structure['rows'] = fill_row_gaps(structure['rows']) table_structures.append(structure) return table_structures +def fill_column_gaps(columns): + if len(columns) == 0: + return columns + + mean_width = 0 + for obj in columns: + mean_width += abs(obj['bbox'][2]-obj['bbox'][1]) + mean_width /= len(columns) + insertions = [] + + for i in range(len(columns)-1): + gap_width = columns[i+1]['bbox'][0]-columns[i]['bbox'][2] + if gap_width > (mean_width/3): + new_col = {'label':'table column', 'score': 0.99} + new_col['bbox'] = [columns[i]['bbox'][2], columns[i]['bbox'][1], columns[i+1]['bbox'][0],columns[i]['bbox'][3]] + insertions.append((i+1, new_col)) + + for i in insertions: + columns.insert(i[0], i[1]) + return columns + +def fill_row_gaps(rows): + if len(rows) == 0: + return rows + mean_width = 0 + for obj in rows: + mean_width += abs(obj['bbox'][3]-obj['bbox'][1]) + mean_width /= len(rows) + insertions = [] + for i in range(len(rows)-1): + gap_width = rows[i+1]['bbox'][1]-rows[i]['bbox'][3] + if gap_width > (mean_width/3): + new_row = {'label':'table row', 'score': 0.99, 'column header': False} + new_row['bbox'] = [rows[i]['bbox'][0], rows[i]['bbox'][3], rows[i]['bbox'][2], rows[i+1]['bbox'][1]] + insertions.append((i+1, new_row)) + + + for i in insertions: + rows.insert(i[0], i[1]) + return rows + def structure_to_cells(table_structure, tokens): """ Assuming the row, column, spanning cell, and header bounding boxes have @@ -434,9 +490,10 @@ def structure_to_cells(table_structure, tokens): confidence_score = 0 # Dilate rows and columns before final extraction - #dilated_columns = fill_column_gaps(columns, table_bbox) + + #dilated_columns = fill_column_gaps(columns) dilated_columns = columns - #dilated_rows = fill_row_gaps(rows, table_bbox) + #dilated_rows = fill_row_gaps(rows) dilated_rows = rows for cell in cells: column_rect = Rect() @@ -620,9 +677,12 @@ def visualize_cells(img, cells, out_path): plt.imshow(img, interpolation="lanczos") plt.gcf().set_size_inches(20, 20) ax = plt.gca() + print("cells:",cells) for cell in cells: bbox = cell['bbox'] + column_num = cell['column_nums'][0] + row_num = cell['row_nums'][0] if cell['column header']: facecolor = (1, 0, 0.45) @@ -652,6 +712,7 @@ def visualize_cells(img, cells, out_path): rect = patches.Rectangle(bbox[:2], bbox[2]-bbox[0], bbox[3]-bbox[1], linewidth=0, edgecolor=edgecolor,facecolor='none',linestyle='-', hatch=hatch, alpha=0.2) ax.add_patch(rect) + plt.text(bbox[0], bbox[1], f"{row_num},{column_num}") plt.xticks([], []) plt.yticks([], []) @@ -753,7 +814,7 @@ def detect(self, img, tokens=None, out_objects=True, out_crops=False, crop_paddi return out_formats - def recognize(self, img, tokens=None, out_objects=False, out_cells=False, + def recognize(self, img, img_file, tokens=None, out_objects=False, out_cells=False, out_html=False, out_csv=False): out_formats = {} if self.str_model is None: @@ -796,7 +857,7 @@ def recognize(self, img, tokens=None, out_objects=False, out_cells=False, if out_csv: tables_csvs = [cells_to_csv(cells) for cells in tables_cells] out_formats['csv'] = tables_csvs - + # output_result("cells", [tables_cells], args, img, img_file) return out_formats def extract(self, img, tokens=None, out_objects=True, out_crops=False, out_cells=False, @@ -828,7 +889,7 @@ def output_result(key, val, args, img, img_file): with open(os.path.join(args.out_dir, out_file), 'w') as f: json.dump(val, f) if args.visualize: - out_file = img_file.replace(".jpg", "_fig_tables.jpg") + out_file = img_file.replace(".png", "_fig_tables.png") out_path = os.path.join(args.out_dir, out_file) visualize_detected_tables(img, val, out_path) elif not key == 'image' and not key == 'tokens': @@ -842,13 +903,13 @@ def output_result(key, val, args, img, img_file): with open(os.path.join(args.out_dir, out_words_file), 'w') as f: json.dump(cropped_table['tokens'], f) elif key == 'cells': - out_file = img_file.replace(".jpg", "_{}_objects.json".format(idx)) + out_file = img_file.replace(".png", "_{}_objects.json".format(idx)) with open(os.path.join(args.out_dir, out_file), 'w') as f: json.dump(elem, f) if args.verbose: print(elem) if args.visualize: - out_file = img_file.replace(".jpg", "_fig_cells.jpg") + out_file = img_file.replace(".png", "_fig_cells.png") out_path = os.path.join(args.out_dir, out_file) visualize_cells(img, elem, out_path) else: @@ -884,7 +945,7 @@ def main(): for count, img_file in enumerate(img_files): print("({}/{})".format(count+1, num_files)) img_path = os.path.join(args.image_dir, img_file) - img = Image.open(img_path) + img = Image.open(img_path).convert('RGB') print("Image loaded.") if not args.words_dir is None: @@ -910,11 +971,12 @@ def main(): tokens = [] if args.mode == 'recognize': - extracted_table = pipe.recognize(img, tokens, out_objects=args.objects, out_cells=args.csv, + extracted_table = pipe.recognize(img, img_file, tokens, out_objects=args.objects, out_cells=args.csv, out_html=args.html, out_csv=args.csv) print("Table(s) recognized.") for key, val in extracted_table.items(): + print("Key:",key) output_result(key, val, args, img, img_file) if args.mode == 'detect': diff --git a/src/main.py b/src/main.py index 74cd13c..b175eef 100644 --- a/src/main.py +++ b/src/main.py @@ -311,13 +311,14 @@ def train(args, model, criterion, postprocessors, device): max_batches_per_epoch=max_batches_per_epoch, print_freq=1000) print("Epoch completed in ", datetime.now() - epoch_timing) - + lr_scheduler.step() pubmed_stats, coco_evaluator = evaluate(model, criterion, postprocessors, data_loader_val, dataset_val, device, None) + print("pubmed: AP50: {:.3f}, AP75: {:.3f}, AP: {:.3f}, AR: {:.3f}". format(pubmed_stats['coco_eval_bbox'][1], pubmed_stats['coco_eval_bbox'][2], diff --git a/src/postprocess.py b/src/postprocess.py index 25feaee..409822e 100644 --- a/src/postprocess.py +++ b/src/postprocess.py @@ -114,6 +114,19 @@ def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, for column in columns: column['page'] = page_num + row_rect = Rect() + for obj in rows: + row_rect.include_rect(obj['bbox']) + column_rect = Rect() + for obj in columns: + column_rect.include_rect(obj['bbox']) + table_object['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]] + table_object['bbox'] = table_object['row_column_bbox'] + + # Process the rows and columns into a complete segmented table + columns = align_columns(columns, table_object['row_column_bbox']) + rows = align_rows(rows, table_object['row_column_bbox']) + #Refine table structures rows = refine_rows(rows, tokens_in_table, class_thresholds['table row']) columns = refine_columns(columns, tokens_in_table, class_thresholds['table column']) @@ -125,6 +138,7 @@ def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, row_rect.include_rect(obj['bbox']) column_rect = Rect() for obj in columns: + print(obj) column_rect.include_rect(obj['bbox']) table_object['row_column_bbox'] = [column_rect[0], row_rect[1], column_rect[2], row_rect[3]] table_object['bbox'] = table_object['row_column_bbox'] @@ -140,7 +154,7 @@ def objects_to_table_structures(table_object, objects_in_table, tokens_in_table, if len(rows) > 0 and len(columns) > 1: table_structures = refine_table_structures(table_object['bbox'], table_structures, tokens_in_table, class_thresholds) - + return table_structures @@ -459,7 +473,7 @@ def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_hi num_objects = len(objects) suppression = [False for obj in objects] - + for object2_num in range(1, num_objects): object2_rect = Rect(objects[object2_num]['bbox']) object2_area = object2_rect.get_area() @@ -717,6 +731,8 @@ def table_structure_to_cells(table_structures, table_spans, table_bbox): cell['subcell'] = False for supercell in supercells: supercell_rect = Rect(list(supercell['bbox'])) + if cell_rect.get_area() == 0: + continue if (supercell_rect.intersect(cell_rect).get_area() / cell_rect.get_area()) > 0.5: cell['subcell'] = True diff --git a/src/table_datasets.py b/src/table_datasets.py index 1fbe017..cbc3d71 100644 --- a/src/table_datasets.py +++ b/src/table_datasets.py @@ -507,12 +507,13 @@ def __init__(self, root, transforms=None, max_size=None, do_crop=True, make_coco self.transforms = transforms self.do_crop=do_crop self.make_coco = make_coco - self.image_extension = image_extension + self.image_extension = ".png" self.include_eval = include_eval self.class_map = class_map self.class_list = list(class_map) self.class_set = set(class_map.values()) self.class_set.remove(class_map['no object']) + #self.make_coco = True try: @@ -522,6 +523,7 @@ def __init__(self, root, transforms=None, max_size=None, do_crop=True, make_coco except: lines = os.listdir(root) xml_page_ids = set([f.strip().replace(".xml", "") for f in lines if f.strip().endswith(".xml")]) + #print(xml_page_ids) image_directory = os.path.join(root, "..", "images") try: @@ -530,8 +532,11 @@ def __init__(self, root, transforms=None, max_size=None, do_crop=True, make_coco except: lines = os.listdir(image_directory) png_page_ids = set([f.strip().replace(self.image_extension, "") for f in lines if f.strip().endswith(self.image_extension)]) - - self.page_ids = sorted(xml_page_ids.intersection(png_page_ids)) + #print(png_page_ids) + + self.page_ids = list(sorted(xml_page_ids.intersection(png_page_ids))) + #print(self.page_ids) + #self.page_ids = png_page_ids if not max_size is None: random.shuffle(self.page_ids) self.page_ids = self.page_ids[:max_size] @@ -555,8 +560,10 @@ def __init__(self, root, transforms=None, max_size=None, do_crop=True, make_coco self.dataset['images'] = [{'id': idx} for idx, _ in enumerate(self.page_ids)] self.dataset['annotations'] = [] ann_id = 0 + print(self.page_ids) for image_id, page_id in enumerate(self.page_ids): annot_path = os.path.join(self.root, page_id + ".xml") + print(annot_path) bboxes, labels = read_pascal_voc(annot_path, class_map=self.class_map) # Reduce class set @@ -609,6 +616,8 @@ def createIndex(self): self.catToImgs = catToImgs self.imgs = imgs self.cats = cats + #print(imgs) + #print(anns) def __getitem__(self, idx): # load images ad masks