Skip to content

Commit 9545403

Browse files
committed
Changed file name and added balanced accuracy calculation
1 parent cc43c73 commit 9545403

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

evaluate-accuracy.py evaluate.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,12 @@ def upload_image(gl: Groundlight, detector: Detector, image: PIL) -> BinaryClass
6060
logger.info(f"Evaluating {len(dataset)} images on detector {detector.name} with delay {args.delay}.")
6161

6262
# Record the number of correct predictions
63-
# Also record the number of false positives and false negatives
64-
correct = 0
65-
total_processed = 0
63+
# Also record the number of TP, TN, FP, FN for calculating balanced accuracy, precision, and recall
64+
true_positives = 0
65+
true_negatives = 0
6666
false_positives = 0
6767
false_negatives = 0
68+
total_processed = 0
6869
average_confidence = 0
6970

7071
for image_name, label in tqdm(dataset.values):
@@ -79,11 +80,13 @@ def upload_image(gl: Groundlight, detector: Detector, image: PIL) -> BinaryClass
7980
image = PIL.Image.open(os.path.join(args.dataset, "images", image_name))
8081
result = upload_image(gl=gl, detector=detector, image=image)
8182

82-
if result.label == label:
83-
correct += 1
84-
elif result.label == "YES" and label == "NO":
83+
if result.label == "YES" and label == "YES":
84+
true_positives += 1
85+
elif result.label == "NO" and label == "NO":
86+
true_negatives += 1
87+
elif result.label == "YES" and label == "NO":
8588
false_positives += 1
86-
elif result.label == "NO" and label == "YES":
89+
elif result.label == "NO" and label == "YES":
8790
false_negatives += 1
8891

8992
average_confidence += result.confidence
@@ -92,15 +95,12 @@ def upload_image(gl: Groundlight, detector: Detector, image: PIL) -> BinaryClass
9295
time.sleep(args.delay)
9396

9497
# Calculate the accuracy, precision, and recall
95-
accuracy = correct / total_processed if total_processed > 0 else 0
96-
precision = correct / (correct + false_positives) if correct + false_positives > 0 else 0
97-
recall = correct / (correct + false_negatives) if correct + false_negatives > 0 else 0
98+
balanced_accuracy = (true_positives / (true_positives + false_negatives) + true_negatives / (true_negatives + false_positives)) / 2
99+
precision = true_positives / (true_positives + false_positives)
100+
recall = true_positives / (true_positives + false_negatives)
98101

99102
logger.info(f"Processed {total_processed} images.")
100-
logger.info(f"Correct: {correct}/{total_processed}")
101103
logger.info(f"Average Confidence: {average_confidence / total_processed:.2f}")
102-
logger.info(f"False Positives: {false_positives}")
103-
logger.info(f"False Negatives: {false_negatives}")
104-
logger.info(f"Accuracy: {accuracy:.2f}")
104+
logger.info(f"Balanced Accuracy: {balanced_accuracy:.2f}")
105105
logger.info(f"Precision: {precision:.2f}")
106106
logger.info(f"Recall: {recall:.2f}")

0 commit comments

Comments
 (0)