diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f476f49
--- /dev/null
+++ b/README.md
@@ -0,0 +1,38 @@
+# INSTRUCTIONS FOR THE STANDALONE SCRIPTS
+
+## Installation
+
+The scripts require [python >= 3.8](https://www.python.org/downloads/release/python-380/) to run.
+We will create a fresh virtual environment in which to install all required packages.
+```sh
+mkvirtualenv -p /usr/bin/python3 DUDEeval
+```
+
+Using poetry and the provided pyproject.toml, we then install all required packages:
+```sh
+workon DUDEeval
+pip3 install poetry
+poetry install
+```
+
+## Way of working
+
+Open a terminal in the repository directory and run:
+```sh
+python3 evaluate_submission.py -g=gt/DUDE_demo_gt.json -s=submissions/DUDE_demo_submission_perfect.json
+```
+
+Required parameters:
+-g: Path to the Ground Truth file. This is the file provided for the competition; you can get it on the Downloads page of the Task in the Competition portal.
+-s: Path to your method's results file.
+
+Optional parameters:
+-t: ANLS threshold. By default 0.5 is used. This can be used to check the tolerance to OCR errors. See the Scene-Text VQA paper for more info.
+-a: Boolean flag to get the scores broken down by answer type. The ground truth file is required to contain this information (currently not available to the public).
+-o: Path to a directory where the file 'results.json' with per-sample results will be copied.
+-c: Measure ECE to score the calibration of answer confidences.
+
+### Bonus:
+
+To pretty-print a JSON file you can use:
+`python -m json.tool file.json`
diff --git a/evaluate_submission.py b/evaluate_submission.py
new file mode 100644
index 0000000..85d5a20
--- /dev/null
+++ b/evaluate_submission.py
@@ -0,0 +1,373 @@
+import os, json
+import logging
+import argparse
+import numpy as np
+from munkres import Munkres, print_matrix, make_cost_matrix
+import evaluate as HF_evaluate
+
+question_ids_to_exclude = []
+
+answer_types = {
+    "abstractive": "Abstractive",
+    "extractive": "Extractive",
+    "not answerable": "Not Answerable",
+}
+
+
+def save_json(file_path, data):
+    with open(file_path, "w+") as json_file:
+        json.dump(data, json_file)
+
+
+def levenshtein_distance(s1, s2):
+    if len(s1) > len(s2):
+        s1, s2 = s2, s1
+
+    distances = range(len(s1) + 1)
+    for i2, c2 in enumerate(s2):
+        distances_ = [i2 + 1]
+        for i1, c1 in enumerate(s1):
+            if c1 == c2:
+                distances_.append(distances[i1])
+            else:
+                distances_.append(1 + min((distances[i1], distances[i1 + 1], distances_[-1])))
+        distances = distances_
+    return distances[-1]
+
+
+def parse_answers(pred_answers):
+    if len(pred_answers) == 0:
+        logging.warning("Mistaken unanswerable prediction")
+        pred_answers = ""
+
+    if isinstance(pred_answers, list):
+        if len(pred_answers) > 1:
+            logging.warning("Mistaken list prediction, assuming first")
+        pred_answers = pred_answers[0]
+    return pred_answers
+
+
+def get_NLS(gt_answers, pred_answers, threshold):
+    values = []
+
+    pred_answers = parse_answers(pred_answers)
+
+    for answer in gt_answers:
+        # preprocess both the answers - gt and prediction
+        gt_answer = " ".join(answer.strip().lower().split())
+        det_answer = " ".join(pred_answers.strip().lower().split())
+
+        dist = levenshtein_distance(gt_answer, det_answer)
+        length = max(len(answer.upper()), len(pred_answers.upper()))
+        values.append(0.0 if length == 0 else float(dist) / float(length))
+
+    question_result = 1 - min(values)
+
+    if question_result < threshold:
+        question_result = 0
+
+    return question_result
+
+
+def 
get_best_matches_hungarian_munkers(anchor_list, matching_list): + + match_dict = {} + match_matrix = [] + for anchor_item in anchor_list: + NLS_dict = {} + NLS_list = [] + for matching_item in matching_list: + NLS = get_NLS([anchor_item], matching_item, threshold=0.5) + NLS_dict[str(matching_item) + " "] = NLS + NLS_list.append(NLS) + + match_dict[anchor_item] = NLS_dict + match_matrix.append(NLS_list) + + return match_dict, match_matrix + + +def get_NLSL(gt_list, pred_list): + if len(gt_list) < len(pred_list): + anchor_list, matching_list = gt_list, pred_list + + else: + anchor_list, matching_list = pred_list, gt_list + + match_dict, cost_matrix = get_best_matches_hungarian_munkers(anchor_list, matching_list) + num_answers = max(len(set(gt_list)), len(pred_list)) + + m = Munkres() + m_cost_matrix = make_cost_matrix(cost_matrix) + indexes = m.compute(m_cost_matrix) + values = [cost_matrix[row][column] for row, column in indexes] + NLSL = np.sum(values) / num_answers + + return NLSL + + +def validate_data(gtFilePath, submFilePath): + """ + Method validate_data: validates that all files in the results folder are correct (have the correct name contents). + Validates also that there are no missing files in the folder. + If some error detected, the method raises the error + """ + + gtJson = json.load(open(gtFilePath, "rb")) + submJson = json.load(open(submFilePath, "rb")) + + if not "data" in gtJson: + raise Exception("The GT file is not valid (no data key)") + + if not "dataset_name" in gtJson: + raise Exception("The GT file is not valid (no dataset_name key)") + + if gtJson["dataset_name"] != "DUDE Dataset": + raise Exception("The GT file is not valid dataset_name should be DUDE Dataset") + + if isinstance(submJson, list) == False: + raise Exception("The Det file is not valid (root item must be an array)") + + if len(submJson) != len(gtJson["data"]): + raise Exception( + "The Det file is not valid (invalid number of answers. Expected:" + + str(len(gtJson["data"])) + + " Found:" + + str(len(submJson)) + + ")" + ) + + gtQuestions = sorted([str(r["questionId"]) for r in gtJson["data"]]) + res_id_to_index = {str(r["questionId"]): ix for ix, r in enumerate(submJson)} + detQuestions = sorted([str(r["questionId"]) for r in submJson]) + + if (gtQuestions == detQuestions) == False: + print(len(gtQuestions), len(detQuestions)) + print(len(set(gtQuestions).intersection(detQuestions))) + print(gtQuestions[0], detQuestions[0]) + raise Exception("The Det file is not valid. Question IDs must match GT") + + for gtObject in gtJson["data"]: + + try: + q_id = str(gtObject["questionId"]) + res_ix = res_id_to_index[q_id] + + except: + raise Exception( + "The Det file is not valid. Question " + + str(gtObject["questionId"]) + + " not present" + ) + + else: + detObject = submJson[res_ix] + + if not "answers" in detObject: + raise Exception( + "Question " + str(gtObject["questionId"]) + " not valid (no answer key)" + ) + + +def evaluate_method(gtFilePath, submFilePath, evaluationParams): + """ + Method evaluate_method: evaluate method and returns the results + Results. Dictionary with the following values: + - method (required) Global method metrics. Ex: { 'Precision':0.8,'Recall':0.9 } + - samples (optional) Per sample metrics. 
Ex: {'sample1' : { 'Precision':0.8,'Recall':0.9 } , 'sample2' : { 'Precision':0.8,'Recall':0.9 } + """ + + show_scores_per_question_type = evaluationParams.answer_types + + gtJson = json.load(open(gtFilePath, "rb")) + submJson = json.load(open(submFilePath, "rb")) + + res_id_to_index = {str(r["questionId"]): ix for ix, r in enumerate(submJson)} + + perSampleMetrics = {} + + totalANLS = 0 + row = 0 + + if show_scores_per_question_type: + answerTypeTotalANLS = {x: 0 for x in answer_types.keys()} + answerTypeNumQuestions = {x: 0 for x in answer_types.keys()} + + for gtObject in gtJson["data"]: + q_id = str(gtObject["questionId"]) + res_ix = res_id_to_index[q_id] + detObject = submJson[res_ix] + + if q_id in question_ids_to_exclude: + question_result = 0 + info = "Question EXCLUDED from the result" + + else: + info = "" + + if "list" in gtObject["answers_variants"]: # or multiple predicted? + question_result = get_NLSL(gtObject["answers"], detObject["answers"]) + + else: + question_result = get_NLS( + gtObject["answers"], detObject["answers"], evaluationParams.anls_threshold + ) + + totalANLS += question_result + + if show_scores_per_question_type: + answer_type = gtObject["answer_type"] + answerTypeTotalANLS[answer_type] += question_result + answerTypeNumQuestions[answer_type] += 1 + + perSampleMetrics[str(gtObject["questionId"])] = { + "anls": question_result, + "question": gtObject["question"], + "gt_answer": gtObject["answers"], + "answer_prediction": detObject["answers"], + "answer_confidence": detObject.get("answers_confidence", -1), + "info": info, + } + row = row + 1 + + methodMetrics = { + "anls": 0 + if len(gtJson["data"]) == 0 + else totalANLS / (len(gtJson["data"]) - len(question_ids_to_exclude)) + } + + if evaluationParams.score_calibration: + # from ANLS, determine exact matches based on ANLS threshold + # ECE is calculated as a population statistic + # TODO(Jordy): hashing-based implementation for type-based calculation + y_correct, p_answers = [], [] + for q in perSampleMetrics: + m = perSampleMetrics[q] + y_correct.append(int(m["anls"] >= evaluationParams.anls_threshold)) + confidence = m["answer_confidence"] + if isinstance(confidence, list): + if len(confidence) > 1: + logging.warning("Mistaken list confidences, assuming first") + confidence = confidence[0] + if confidence == -1: # invalid so cannot evaluate ECE + break + p_answers.append(confidence) + + if len(y_correct) == len(perSampleMetrics): # checks all calculations valid + y_correct = [0 if x == 1 else 1 for x in y_correct] #since ECE expects class size vectors + y_correct = np.array(y_correct).astype(int) + p_answers = np.array(p_answers).astype(np.float32) + + metric = HF_evaluate.load("jordyvl/ece") + kwargs = dict( + n_bins=min(len(perSampleMetrics)-1, 100), + scheme="equal-mass" if len(set(p_answers)) != 1 else "equal-range", + bin_range=[0,1], + proxy="upper-edge", + p=1, + detail=False, + ) + + ece_result = metric.compute( + references=y_correct, predictions=np.expand_dims(p_answers, -1), **kwargs + ) + methodMetrics.update(ece_result) + + answer_types_ANLS = {} + + if show_scores_per_question_type: + for answer_type, answer_type_str in answer_types.items(): + answer_types_ANLS[answer_type_str] = ( + 0 + if len(gtJson["data"]) == 0 + else answerTypeTotalANLS[answer_type] / (answerTypeNumQuestions[answer_type]) + ) + + resDict = { + "result": methodMetrics, + "scores_by_types": {"anls_per_answer_type": answer_types_ANLS}, + "per_sample_result": perSampleMetrics, + } + return resDict + + +def 
display_results(results, show_scores_per_answer_type):
+    print("\nOverall ANLS: {:1.4f}\n".format(results["result"]["anls"]))
+    if "ECE" in results["result"]:
+        print("\nOverall ECE: {:1.4f}\n".format(results["result"]["ECE"]))
+
+    if show_scores_per_answer_type:
+        print("Answer type \t ANLS")
+        for answer_type in answer_types.values():
+            print(
+                "{:10s}\t{:1.4f}".format(
+                    answer_type, results["scores_by_types"]["anls_per_answer_type"][answer_type]
+                )
+            )
+
+    print("")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="DUDE evaluation script.")
+
+    parser.add_argument(
+        "-g", "--ground_truth", type=str, help="Path of the Ground Truth file.", required=True
+    )
+    parser.add_argument(
+        "-s",
+        "--submission_file",
+        type=str,
+        help="Path of your method's results file.",
+        required=True,
+    )
+
+    parser.add_argument(
+        "-t",
+        "--anls_threshold",
+        type=float,
+        default=0.5,
+        help="ANLS threshold to use (see the Scene-Text VQA paper for more info).",
+        required=False,
+    )
+    parser.add_argument(
+        "-a",
+        "--answer_types",
+        default=False,
+        action="store_true",
+        help="Get the scores broken down by answer type (requires the GT file to contain answer types).",
+        required=False,
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        help="Path to a directory where to copy the file 'results.json' that contains per-sample results.",
+        required=False,
+    )
+    parser.add_argument(
+        "-c",
+        "--score_calibration",
+        default=False,
+        action="store_true",
+        help="Measure ECE to score the calibration of answer confidences.",
+        required=False,
+    )
+    args = parser.parse_args()
+
+    # Validate the format of ground truth and submission files.
+    validate_data(args.ground_truth, args.submission_file)
+
+    # Evaluate method
+    results = evaluate_method(args.ground_truth, args.submission_file, args)
+
+    display_results(results, args.answer_types)
+
+    if args.output:
+        output_dir = args.output
+
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        resultsOutputname = os.path.join(output_dir, "results.json")
+        save_json(resultsOutputname, results)
+
+        print("All results including per-sample results have been correctly saved!")
diff --git a/gt/DUDE_demo_gt.json b/gt/DUDE_demo_gt.json
new file mode 100644
index 0000000..594405f
--- /dev/null
+++ b/gt/DUDE_demo_gt.json
@@ -0,0 +1,191 @@
+{
+    "dataset_name": "DUDE Dataset",
+    "dataset_version": "0.1",
+    "data": [
+        {
+            "questionId": "7afe94621751eb3584a4a9962bb7b1f0_489ae4ece55ed1af97a542845d3ba6d3",
+            "question": "Is there some type of handwritten text on each page of the document?",
+            "answers": [
+                "Yes"
+            ],
+            "answers_page_bounding_boxes": [],
+            "answers_variants": [],
+            "answer_type": "abstractive",
+            "docId": "7afe94621751eb3584a4a9962bb7b1f0",
+            "data_split": "train"
+        },
+        {
+            "questionId": "c228ef8e4149d1532e833f76fe6de2d2_b1816832f28da57fb863a185299e96b2",
+            "question": "How many instances of handwriting are in this document?",
+            "answers": [
+                "2"
+            ],
+            "answers_page_bounding_boxes": [],
+            "answers_variants": [],
+            "answer_type": "abstractive",
+            "docId": "c228ef8e4149d1532e833f76fe6de2d2",
+            "data_split": "train"
+        },
+        {
+            "questionId": "ca7c6dcd92fd573c8ad4cb79b0de525a_53fb47e11ac9c8c18ad9e53045d39039",
+            "question": "Does each page of the document has a logo on it?",
+            "answers": [
+                "No"
+            ],
+            "answers_page_bounding_boxes": [],
+            "answers_variants": [],
+            "answer_type": "abstractive",
+            "docId": "ca7c6dcd92fd573c8ad4cb79b0de525a",
+            "data_split": "train"
+        },
+        {
+            "questionId": "7afe94621751eb3584a4a9962bb7b1f0_4a863bed9b999dec79f6181f7a823e4a",
+            "question": "In what year was the marriage certificate signed/stamped.",
+            "answers": [
+                "1894"
+            ],
+            
"answers_page_bounding_boxes": [ + [ + { + "left": 185, + "top": 70, + "width": 299, + "height": 74, + "page": 0 + }, + { + "left": 953, + "top": 287, + "width": 138, + "height": 69, + "page": 0 + } + ] + ], + "answers_variants": [], + "answer_type": "extractive", + "docId": "7afe94621751eb3584a4a9962bb7b1f0", + "data_split": "train" + }, + { + "questionId": "c228ef8e4149d1532e833f76fe6de2d2_8cd51553579ec302b46b4d2837143cf8", + "question": "What tax year is the document for?", + "answers": [ + "1924" + ], + "answers_page_bounding_boxes": [ + [ + { + "left": 1028, + "top": 607, + "width": 205, + "height": 83, + "page": 0 + } + ] + ], + "answers_variants": [], + "answer_type": "extractive", + "docId": "c228ef8e4149d1532e833f76fe6de2d2", + "data_split": "train" + }, + { + "questionId": "95a82e8934b11e4dca41789f4a70caa3_7b58d9ec29ce6f75d82025df8e6aa91d", + "question": "What is the destination of Air France flight #1044?", + "answers": [ + "Moscow Sheremet, Russia - Terminal E - International" + ], + "answers_page_bounding_boxes": [ + [ + { + "left": 962, + "top": 413, + "width": 183, + "height": 223, + "page": 4 + } + ] + ], + "answers_variants": [], + "answer_type": "extractive", + "docId": "95a82e8934b11e4dca41789f4a70caa3", + "data_split": "train" + }, + { + "questionId": "32f61c097aa230d75bcb93248db87785_d966a58b08dfd5ba52166e5fa2f1ff67", + "question": "What percentage of young people don't take out travel insurance when they travel abroad because they think it is too espensive.", + "answers": [ + "42%" + ], + "answers_page_bounding_boxes": [ + [ + { + "left": 1282, + "top": 1465, + "width": 352, + "height": 193, + "page": 0 + } + ] + ], + "answers_variants": [], + "answer_type": "extractive", + "docId": "32f61c097aa230d75bcb93248db87785", + "data_split": "train" + }, + { + "questionId": "9a2f84b54389aa435486a84c89069942_75bdac76d12afee4a07c2a870dcfbbf5", + "question": "What does NARA stand for?", + "answers": [ + "National Archives and Records Administration" + ], + "answers_page_bounding_boxes": [ + [ + { + "left": 252, + "top": 892, + "width": 762, + "height": 61, + "page": 0 + } + ] + ], + "answers_variants": [ + "The National Archives and Records Administration" + ], + "answer_type": "extractive", + "docId": "9a2f84b54389aa435486a84c89069942", + "data_split": "train" + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_9dc0d59810e040b8ec7516667aa289a4", + "question": "who arrived in Hawaii in march 19 1998?", + "answers": [""], + "answers_page_bounding_boxes": [], + "answers_variants": [], + "answer_type": "not answerable", + "docId": "ff1549ddf50c82edd169dee714a38339", + "data_split": "train" + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_4fa9fe07c5b3f0680ef68da40c92e51a", + "question": "What millitary vehicle is the Vice Admiral in the image on page 3 driving?", + "answers": [""], + "answers_page_bounding_boxes": [], + "answers_variants": [], + "answer_type": "not answerable", + "docId": "ff1549ddf50c82edd169dee714a38339", + "data_split": "train" + }, + { + "questionId": "manual_test", + "question": "What are the flight numbers from going to Washington to Moscow?", + "answers": ["AF0055", "AF1044"], + "answers_page_bounding_boxes": [], + "answers_variants": [], + "answer_type": "extractive", + "docId": "ff1549ddf50c82edd169dee714a38339", + "data_split": "train" + } + ] +} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100755 index 0000000..480d89f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[tool.poetry] +name = 
"DUDEeval" +version = "0.1.0" +description = "standalone evaluation scripts for DUDE" +authors = ["Ruben ","Jordy "] + +readme = 'README.md' # Markdown files are supported +keywords = ['research', 'evaluation', 'VQA', 'calibration'] + +classifiers = [ + "Development Status :: 1 - Planning", + "Environment :: Web Environment", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] + +[tool.poetry.dependencies] +python = "^3.8" +numpy = "^1.24.1" +munkres = "^1.1.4" +evaluate = "^0.4.0" +scikit-learn = "^1.2.0" + +[tool.poetry.dev-dependencies] +pytest = "^3.10.1" +pytest-cov = "^2.9.0" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.masonry.api" diff --git a/scripts.sh b/scripts.sh new file mode 100644 index 0000000..627f345 --- /dev/null +++ b/scripts.sh @@ -0,0 +1,2 @@ +python3 evaluate.py -g=gt/DUDE_demo_gt.json -s=submissions/DUDE_demo_submission_perfect.json +python3 evaluate.py -g=gt/DUDE_demo_gt.json -s=submissions/DUDE_demo_submission_imperfect.json diff --git a/submissions/DUDE_demo_submission_imperfect.json b/submissions/DUDE_demo_submission_imperfect.json new file mode 100644 index 0000000..76ab2d2 --- /dev/null +++ b/submissions/DUDE_demo_submission_imperfect.json @@ -0,0 +1,60 @@ +[ + { + "questionId": "7afe94621751eb3584a4a9962bb7b1f0_489ae4ece55ed1af97a542845d3ba6d3", + "answers": ["No"], + "answers_confidence": [0.8] + }, + { + "questionId": "c228ef8e4149d1532e833f76fe6de2d2_b1816832f28da57fb863a185299e96b2", + "answers": ["3"], + "answers_confidence": [0.4] + }, + { + "questionId": "ca7c6dcd92fd573c8ad4cb79b0de525a_53fb47e11ac9c8c18ad9e53045d39039", + "answers": ["No"], + "answers_confidence": [0.99] + }, + { + "questionId": "7afe94621751eb3584a4a9962bb7b1f0_4a863bed9b999dec79f6181f7a823e4a", + "answers": ["1894"], + "answers_confidence": [0.6] + }, + { + "questionId": "c228ef8e4149d1532e833f76fe6de2d2_8cd51553579ec302b46b4d2837143cf8", + "answers": ["1924"], + "answers_confidence": [0.89] + }, + { + "questionId": "95a82e8934b11e4dca41789f4a70caa3_7b58d9ec29ce6f75d82025df8e6aa91d", + "answers": ["Moscow - Russia - Terminal E - International"], + "answers_confidence": [0.4] + }, + { + "questionId": "32f61c097aa230d75bcb93248db87785_d966a58b08dfd5ba52166e5fa2f1ff67", + "answers": ["41%"], + "answers_confidence": [0.75] + }, + { + "questionId": "9a2f84b54389aa435486a84c89069942_75bdac76d12afee4a07c2a870dcfbbf5", + "answers": ["National Archives & Records Administration"], + "answers_confidence": [0.95] + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_9dc0d59810e040b8ec7516667aa289a4", + "answers": [""], + "answers_confidence": [0.3] + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_4fa9fe07c5b3f0680ef68da40c92e51a", + "answers": [""], + "answers_confidence": [0.3] + }, + { + "questionId": "manual_test", + "answers": [ + "AF0055", + "AF1044" + ], + "answers_confidence": [0.7] + } +] diff --git a/submissions/DUDE_demo_submission_perfect.json b/submissions/DUDE_demo_submission_perfect.json new file mode 100644 index 0000000..b76d666 --- /dev/null +++ b/submissions/DUDE_demo_submission_perfect.json @@ -0,0 +1,60 @@ +[ + { + "questionId": "7afe94621751eb3584a4a9962bb7b1f0_489ae4ece55ed1af97a542845d3ba6d3", + "answers": ["Yes"], + "answers_confidence": [1] + }, + { + "questionId": "c228ef8e4149d1532e833f76fe6de2d2_b1816832f28da57fb863a185299e96b2", + "answers": ["2"], + "answers_confidence": [1] + }, + { + 
"questionId": "ca7c6dcd92fd573c8ad4cb79b0de525a_53fb47e11ac9c8c18ad9e53045d39039", + "answers": ["No"], + "answers_confidence": [1] + }, + { + "questionId": "7afe94621751eb3584a4a9962bb7b1f0_4a863bed9b999dec79f6181f7a823e4a", + "answers": ["1894"], + "answers_confidence": [1] + }, + { + "questionId": "c228ef8e4149d1532e833f76fe6de2d2_8cd51553579ec302b46b4d2837143cf8", + "answers": ["1924"], + "answers_confidence": [1] + }, + { + "questionId": "95a82e8934b11e4dca41789f4a70caa3_7b58d9ec29ce6f75d82025df8e6aa91d", + "answers": ["Moscow Sheremet, Russia - Terminal E - International"], + "answers_confidence": [1] + }, + { + "questionId": "32f61c097aa230d75bcb93248db87785_d966a58b08dfd5ba52166e5fa2f1ff67", + "answers": ["42%"], + "answers_confidence": [1] + }, + { + "questionId": "9a2f84b54389aa435486a84c89069942_75bdac76d12afee4a07c2a870dcfbbf5", + "answers": ["National Archives and Records Administration"], + "answers_confidence": [1] + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_9dc0d59810e040b8ec7516667aa289a4", + "answers": [""], + "answers_confidence": [1] + }, + { + "questionId": "ff1549ddf50c82edd169dee714a38339_4fa9fe07c5b3f0680ef68da40c92e51a", + "answers": [""], + "answers_confidence": [1] + }, + { + "questionId": "manual_test", + "answers": [ + "AF0055", + "AF1044" + ], + "answers_confidence": [1] + } +]