diff --git a/mean_average_precision/detection_map.py b/mean_average_precision/detection_map.py index a4475f2..4782652 100644 --- a/mean_average_precision/detection_map.py +++ b/mean_average_precision/detection_map.py @@ -157,3 +157,17 @@ def plot(self, interpolated=True): plt.suptitle("Mean average precision : {:0.2f}".format(sum(mean_average_precision)/len(mean_average_precision))) fig.tight_layout() + + def compute_map(self, interpolated=True): + """ + Compute average precision per class and the mAP score. + :param interpolated: will compute the interpolated curve + :return: tuple: (dict of class -> AP, mAP score) + """ + ap_by_class = {} + for class_idx in range(self.n_class): + precisions, recalls = self.compute_precision_recall_(class_idx, interpolated=interpolated) + average_precision = self.compute_ap(precisions, recalls) + ap_by_class[class_idx] = average_precision + mAP = sum(ap_by_class.values()) / len(ap_by_class) + return ap_by_class, mAP