diff --git a/scripts/evaluation.py b/scripts/evaluation.py
index fb83d7d..bfe13f4 100644
--- a/scripts/evaluation.py
+++ b/scripts/evaluation.py
@@ -31,6 +31,13 @@ def create_subvolume_parser(parser_parent):
         required=True,
         help="Path to reference embeddings file"
     )
+
+    parser_parent.add_argument(
+        "--fbeta",
+        type=int,
+        default=1,
+        help="Beta for F-Statistic"
+    )
     return parser_parent
 
 def create_position_parser(parser_parent):
@@ -59,6 +66,13 @@ def create_position_parser(parser_parent):
         help="Path to json file with box sizes for each reference"
     )
 
+    parser_parent.add_argument(
+        "--fbeta",
+        type=int,
+        default=1,
+        help="Beta for F-Statistic"
+    )
+
     parser_parent.add_argument(
         "--optim",
         action='store_true',
@@ -97,6 +111,15 @@ def create_parser():
 
     return parser_parent
 
+
+def get_fbeta(recall, precision, beta=1):
+    if precision == 0 and recall == 0:
+        return 0
+
+    f_score = np.nan_to_num((1 + beta * beta) * precision * recall / (beta * beta * precision + recall))
+
+    return f_score
+
 def get_size_dict_json(pth: str) -> dict:
     import json
     with open(pth) as f:
@@ -159,7 +182,8 @@ def _add_size(df, size, size_dict = None) -> pd.DataFrame:
 
     return df
 
-def locate_positions_stats(locate_results, class_positions, iou_thresh):
+
+def locate_positions_stats(locate_results, class_positions, iou_thresh, fbeta=1):
     class_stats = {}
     locate_results_np = locate_results[["X", "Y", "Z", "width", "height", "depth"]].to_numpy()
     true_positive = 0
@@ -190,10 +214,7 @@ def locate_positions_stats(locate_results, class_positions, iou_thresh):
 
     recall = true_positive / (true_positive + false_negative)
     precision = true_positive / (true_positive + false_positive)
-    if precision == 0 and recall == 0:
-        f1_score = 0
-    else:
-        f1_score = np.nan_to_num(2 * precision * recall / (precision + recall))
+    f1_score = get_fbeta(recall, precision, beta=fbeta)
     class_stats["F1"] = float(f1_score)
     class_stats["Recall"] = recall
     class_stats["Precision"] = float(precision)
@@ -211,6 +232,7 @@ def __init__(self):
     @staticmethod
     def stats_class_wise(ground_truth: List[int],
                          predicted: List[int],
+                         fbeta=1
                          ) -> Dict[int,Dict[str,float]]:
 
         gt_classes = np.unique(ground_truth)
@@ -231,7 +253,7 @@ def stats_class_wise(ground_truth: List[int],
 
             recall = true_positives/(true_positives+false_negative)
             balanced_accuracy = (true_positive_rate+true_negative_rate)/2
-            f1_score = 2*precision*recall/(precision+recall)
+            f1_score = get_fbeta(recall, precision, beta=fbeta)
             stats[gtclass]["F1"] = float(f1_score)
             stats[gtclass]["Recall"] = float(recall)
             stats[gtclass]["Precision"] = float(precision)
@@ -243,7 +265,8 @@ def stats_class_wise(ground_truth: List[int],
     @staticmethod
     def evaluate(probabilties: pd.DataFrame,
                  pdb_id_converter : Dict,
-                 filename_pdb_extactor: Callable[[str], str]
+                 filename_pdb_extactor: Callable[[str], str],
+                 fbeta=1
                  ) -> Dict[int,Dict[str,float]]:
         '''
         Calculates the statistics. Returns a dictonary with statistics for each class (identified by an integer).
@@ -261,7 +284,8 @@ def evaluate(probabilties: pd.DataFrame,
             except KeyError:
                 print(f"No reference for class {pdb_str} not found. Skip it")
             volume_ground_truth_classes.append(vclass)
-        stats = SubvolumeEvaluator.stats_class_wise(ground_truth=volume_ground_truth_classes,predicted=probabilties["predicted_class"])
+        stats = SubvolumeEvaluator.stats_class_wise(ground_truth=volume_ground_truth_classes,
+                                                    predicted=probabilties["predicted_class"], fbeta=fbeta)
         stats["VOID"] = np.sum(probabilties["predicted_class"]==-1)
 
         return stats
@@ -328,12 +352,14 @@ def shrec_pdb_class_id_converter(reference_embeddings: List[str]):
 
 class LocateOptimEvaluator():
 
-    def __init__(self, positions_path: str, locate_results_path: str, sizes_pth: str, stepsize_optim_similarity: float = 0.05, size=37):
+    def __init__(self, positions_path: str, locate_results_path: str, sizes_pth: str,
+                 stepsize_optim_similarity: float = 0.05, size=37, fbeta=1):
         self.positions_path = positions_path
         self.locate_results_path = locate_results_path
         self.size = size
         self.iou_thresh = 0.6
         self.size_dict= None
+        self.fbeta = fbeta
         if sizes_pth:
             self.size_dict = get_size_dict_json(sizes_pth)
         self.stepsize_optim_similarity = stepsize_optim_similarity
@@ -372,7 +398,8 @@ def get_stats(self, df, positions):
             df["class"] = class_name
             df = _add_size(df, self.size, self.size_dict)
 
-            stats = locate_positions_stats(locate_results=df, class_positions=class_positions, iou_thresh=self.iou_thresh)
+            stats = locate_positions_stats(locate_results=df, class_positions=class_positions, iou_thresh=self.iou_thresh,
+                                           fbeta=self.fbeta)
             return stats
 
     def optim(self,locate_results, positions):
@@ -492,12 +519,13 @@ def run(self) -> Dict:
 
 
 class LocateEvaluator():
 
-    def __init__(self, positions_path: str, locate_results_path: str):
+    def __init__(self, positions_path: str, locate_results_path: str, fbeta=1):
         self.positions_path = positions_path
         self.locate_results_path = locate_results_path
         self.size = 37
         self.iou_thresh = 0.6
         self.size_dict = get_size_dict()
+        self.fbeta = fbeta
 
 
     def run(self) -> Dict:
@@ -523,7 +551,7 @@ def run(self) -> Dict:
 
             locate_results.columns =[["X","Y","Z","class"]]
             locate_results = _add_size(locate_results, self.size, self.size_dict)
 
-            class_stats = locate_positions_stats(locate_results, class_positions, self.iou_thresh)
+            class_stats = locate_positions_stats(locate_results, class_positions, self.iou_thresh, self.fbeta)
 
             stats[class_name] = class_stats
@@ -591,7 +619,9 @@ def _main_():
             references,
         )
 
-        stats = SubvolumeEvaluator.evaluate(probs, pdb_id_converter, filename_pdb_extactor=SubvolumeEvaluator.extract_pdb_from_filename)
+        stats = SubvolumeEvaluator.evaluate(probs, pdb_id_converter,
+                                            filename_pdb_extactor=SubvolumeEvaluator.extract_pdb_from_filename,
+                                            fbeta=args.fbeta)
         SubvolumeEvaluator.print_stats(stats, pdb_id_converter,output_path=os.path.dirname(probabilties_path))
 
     if "positions" in sys.argv[1]:
@@ -609,7 +639,8 @@ def _main_():
                                              locate_results_path=locate_path,
                                              sizes_pth=sizes_pth,
                                              stepsize_optim_similarity=args.stepsize_optim_similarity,
-                                             size=size)
+                                             size=size,
+                                             fbeta=args.fbeta)
         else:
             evaluator = LocateEvaluator(positions_path=positions_path, locate_results_path=locate_path)
         stats = evaluator.run()