add fbeta to eval
thorstenwagner committed Jan 28, 2025
1 parent c6ce88b commit 62fbebb
Showing 1 changed file with 45 additions and 14 deletions.
59 changes: 45 additions & 14 deletions scripts/evaluation.py
@@ -31,6 +31,13 @@ def create_subvolume_parser(parser_parent):
required=True,
help="Path to reference embeddings file"
)

parser_parent.add_argument(
"--fbeta",
type=int,
default=1,
help="Beta for F-Statistic"
)
return parser_parent

def create_position_parser(parser_parent):
@@ -59,6 +66,13 @@ def create_position_parser(parser_parent):
help="Path to json file with box sizes for each reference"
)

parser_parent.add_argument(
"--fbeta",
type=int,
default=1,
help="Beta for F-Statistic"
)

parser_parent.add_argument(
"--optim",
action='store_true',
@@ -97,6 +111,15 @@ def create_parser():

return parser_parent


def get_fbeta(recall, precision, beta=1):
if precision == 0 and recall == 0:
return 0

f_score = np.nan_to_num((1 + beta * beta) * precision * recall / (beta * beta * precision + recall))

return f_score
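
For context, the new get_fbeta helper generalizes F1 to the F-beta score: beta > 1 weights recall more heavily, beta < 1 weights precision more heavily, and beta = 1 reduces to the familiar F1. A minimal standalone sketch of the same formula (the precision/recall values below are made up for illustration):

import numpy as np

def fbeta_score(recall, precision, beta=1):
    # Same formula as get_fbeta above:
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)
    if precision == 0 and recall == 0:
        return 0.0
    return float(np.nan_to_num(
        (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
    ))

recall, precision = 0.8, 0.5                    # hypothetical values
print(fbeta_score(recall, precision, beta=1))   # ~0.615, the usual F1
print(fbeta_score(recall, precision, beta=2))   # ~0.714, recall weighted more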

def get_size_dict_json(pth: str) -> dict:
import json
with open(pth) as f:
@@ -159,7 +182,8 @@ def _add_size(df, size, size_dict = None) -> pd.DataFrame:

return df

def locate_positions_stats(locate_results, class_positions, iou_thresh):

def locate_positions_stats(locate_results, class_positions, iou_thresh, fbeta=1):
class_stats = {}
locate_results_np = locate_results[["X", "Y", "Z", "width", "height", "depth"]].to_numpy()
true_positive = 0
@@ -190,10 +214,7 @@ def locate_positions_stats(locate_results, class_positions, iou_thresh):

recall = true_positive / (true_positive + false_negative)
precision = true_positive / (true_positive + false_positive)
if precision == 0 and recall == 0:
f1_score = 0
else:
f1_score = np.nan_to_num(2 * precision * recall / (precision + recall))
f1_score = get_fbeta(recall, precision, beta=fbeta)
class_stats["F1"] = float(f1_score)
class_stats["Recall"] = recall
class_stats["Precision"] = float(precision)
@@ -211,6 +232,7 @@ def __init__(self):
@staticmethod
def stats_class_wise(ground_truth: List[int],
predicted: List[int],
fbeta=1
) -> Dict[int,Dict[str,float]]:

gt_classes = np.unique(ground_truth)
@@ -231,7 +253,7 @@ def stats_class_wise(ground_truth: List[int],
recall = true_positives/(true_positives+false_negative)

balanced_accuracy = (true_positive_rate+true_negative_rate)/2
f1_score = 2*precision*recall/(precision+recall)
f1_score = get_fbeta(recall, precision, beta=fbeta)
stats[gtclass]["F1"] = float(f1_score)
stats[gtclass]["Recall"] = float(recall)
stats[gtclass]["Precision"] = float(precision)
@@ -243,7 +265,8 @@ def stats_class_wise(ground_truth: List[int],
@staticmethod
def evaluate(probabilties: pd.DataFrame,
pdb_id_converter : Dict,
filename_pdb_extactor: Callable[[str], str]
filename_pdb_extactor: Callable[[str], str],
fbeta=1
) -> Dict[int,Dict[str,float]]:
'''
Calculates the statistics. Returns a dictionary with statistics for each class (identified by an integer).
@@ -261,7 +284,8 @@ def evaluate(probabilties: pd.DataFrame,
except KeyError:
print(f"No reference for class {pdb_str} not found. Skip it")
volume_ground_truth_classes.append(vclass)
stats = SubvolumeEvaluator.stats_class_wise(ground_truth=volume_ground_truth_classes,predicted=probabilties["predicted_class"])
stats = SubvolumeEvaluator.stats_class_wise(ground_truth=volume_ground_truth_classes,
predicted=probabilties["predicted_class"], fbeta=fbeta)
stats["VOID"] = np.sum(probabilties["predicted_class"]==-1)
return stats
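
Both classification entry points now accept the beta value. A hedged sketch of calling stats_class_wise directly with the new keyword (the module name and the toy label lists are assumptions; only the fbeta keyword comes from this commit):

# assumes scripts/ is on the Python path
from evaluation import SubvolumeEvaluator

ground_truth = [0, 0, 1, 1, 2]   # hypothetical class labels
predicted    = [0, 1, 1, 1, 2]   # hypothetical predictions

# fbeta=2 treats recall as twice as important as precision; fbeta=1 keeps the old F1.
stats = SubvolumeEvaluator.stats_class_wise(ground_truth=ground_truth,
                                            predicted=predicted,
                                            fbeta=2)
print(stats)   # per-class dict with "F1", "Recall", "Precision", ...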

@@ -328,12 +352,14 @@ def shrec_pdb_class_id_converter(reference_embeddings: List[str]):

class LocateOptimEvaluator():

def __init__(self, positions_path: str, locate_results_path: str, sizes_pth: str, stepsize_optim_similarity: float = 0.05, size=37):
def __init__(self, positions_path: str, locate_results_path: str, sizes_pth: str,
stepsize_optim_similarity: float = 0.05, size=37, fbeta=1):
self.positions_path = positions_path
self.locate_results_path = locate_results_path
self.size = size
self.iou_thresh = 0.6
self.size_dict= None
self.fbeta = fbeta
if sizes_pth:
self.size_dict = get_size_dict_json(sizes_pth)
self.stepsize_optim_similarity = stepsize_optim_similarity
@@ -372,7 +398,8 @@ def get_stats(self, df, positions):
df["class"] = class_name
df = _add_size(df, self.size, self.size_dict)

stats = locate_positions_stats(locate_results=df, class_positions=class_positions, iou_thresh=self.iou_thresh)
stats = locate_positions_stats(locate_results=df, class_positions=class_positions, iou_thresh=self.iou_thresh,
fbeta=self.fbeta)
return stats

def optim(self,locate_results, positions):
@@ -492,12 +519,13 @@ def run(self) -> Dict:

class LocateEvaluator():

def __init__(self, positions_path: str, locate_results_path: str):
def __init__(self, positions_path: str, locate_results_path: str, fbeta=1):
self.positions_path = positions_path
self.locate_results_path = locate_results_path
self.size = 37
self.iou_thresh = 0.6
self.size_dict = get_size_dict()
self.fbeta = fbeta
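
The locate evaluators forward the same beta down to locate_positions_stats. A hedged usage sketch (the module name and file paths are placeholders, not taken from this diff):

from evaluation import LocateEvaluator

evaluator = LocateEvaluator(
    positions_path="path/to/ground_truth_positions",   # hypothetical path
    locate_results_path="path/to/locate_results",      # hypothetical path
    fbeta=2,   # recall treated as twice as important as precision
)
stats = evaluator.run()   # per-class stats; the F-beta value is stored under "F1"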


def run(self) -> Dict:
@@ -523,7 +551,7 @@ def run(self) -> Dict:
locate_results.columns =[["X","Y","Z","class"]]

locate_results = _add_size(locate_results, self.size, self.size_dict)
class_stats = locate_positions_stats(locate_results, class_positions, self.iou_thresh)
class_stats = locate_positions_stats(locate_results, class_positions, self.iou_thresh, self.fbeta)


stats[class_name] = class_stats
@@ -591,7 +619,9 @@ def _main_():
references,
)

stats = SubvolumeEvaluator.evaluate(probs, pdb_id_converter, filename_pdb_extactor=SubvolumeEvaluator.extract_pdb_from_filename)
stats = SubvolumeEvaluator.evaluate(probs, pdb_id_converter,
filename_pdb_extactor=SubvolumeEvaluator.extract_pdb_from_filename,
fbeta=args.fbeta)
SubvolumeEvaluator.print_stats(stats, pdb_id_converter,output_path=os.path.dirname(probabilties_path))

if "positions" in sys.argv[1]:
@@ -609,7 +639,8 @@ def _main_():
locate_results_path=locate_path,
sizes_pth=sizes_pth,
stepsize_optim_similarity=args.stepsize_optim_similarity,
size=size)
size=size,
fbeta=args.fbeta)
else:
evaluator = LocateEvaluator(positions_path=positions_path, locate_results_path=locate_path)
stats = evaluator.run()
