diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index bdaef635d0..e029aca708 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -4,10 +4,11 @@ When working with the learner, you have the option to record metrics that will b throughout the training process. We currently offer a restricted range of metrics. | Metric | Description | -| ---------------- | ------------------------------------------------------- | +|------------------|---------------------------------------------------------| | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | | Precision | Calculate precision in percentage | +| Recall | Calculate recall in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CPU Temperature | Fetch the temperature of CPUs | diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index 3eabc2a09b..14551169cb 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,5 +1,6 @@ -use crate::metric::processor::ItemLazy; -use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput}; +use crate::metric::{ + processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, +}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor, Transaction}; use burn_ndarray::NdArray; @@ -51,16 +52,16 @@ impl Adaptor> for ClassificationOutput { } } -impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> PrecisionInput { +impl Adaptor> for ClassificationOutput { + fn adapt(&self) -> ConfusionStatsInput { let [_, num_classes] = self.output.dims(); if num_classes > 1 { - PrecisionInput::new( + ConfusionStatsInput::new( self.output.clone(), 
self.targets.clone().one_hot(num_classes).bool(), ) } else { - PrecisionInput::new( + ConfusionStatsInput::new( self.output.clone(), self.targets.clone().unsqueeze_dim(1).bool(), ) @@ -115,8 +116,8 @@ impl Adaptor> for MultiLabelClassificationOutput { } } -impl Adaptor> for MultiLabelClassificationOutput { - fn adapt(&self) -> PrecisionInput { - PrecisionInput::new(self.output.clone(), self.targets.clone().bool()) +impl Adaptor> for MultiLabelClassificationOutput { + fn adapt(&self) -> ConfusionStatsInput { + ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool()) } } diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index 24337498c5..5c06748e29 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -36,8 +36,9 @@ pub(crate) mod tests { /// Probability of tp before adding errors pub const THRESHOLD: f64 = 0.5; - #[derive(Debug)] + #[derive(Debug, Default)] pub enum ClassificationType { + #[default] Binary, Multiclass, Multilabel, @@ -51,12 +52,11 @@ pub(crate) mod tests { match classification_type { ClassificationType::Binary => { ( - Tensor::from_data( - [[0.3], [0.2], [0.7], [0.1], [0.55]], - //[[0], [0], [1], [0], [1]] with threshold=0.5 - &Default::default(), - ), + Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()), + // targets Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()), + // predictions @ threshold=0.5 + // [[0], [0], [1], [0], [1]] ) } ClassificationType::Multiclass => { @@ -69,12 +69,15 @@ pub(crate) mod tests { [0.1, 0.15, 0.8], [0.9, 0.03, 0.07], ], - //[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1 - //[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2 &Default::default(), ), Tensor::from_data( + // targets [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]], + // predictions @ top_k=1 + // [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] + // predictions @ top_k=2 + // [[1, 1, 0], [1, 1, 
0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] &Default::default(), ), ) @@ -89,11 +92,13 @@ pub(crate) mod tests { [0.7, 0.5, 0.9], [1.0, 0.3, 0.2], ], - //[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5 &Default::default(), ), + // targets Tensor::from_data( [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]], + // predictions @ threshold=0.5 + // [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] &Default::default(), ), ) diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs index 1eb51a85d0..1caa2dabea 100644 --- a/crates/burn-train/src/metric/classification.rs +++ b/crates/burn-train/src/metric/classification.rs @@ -1,3 +1,26 @@ +use std::num::NonZeroUsize; + +/// Necessary data for classification metrics. +#[derive(Default)] +pub struct ClassificationMetricConfig { + pub decision_rule: DecisionRule, + pub class_reduction: ClassReduction, +} + +/// The prediction decision rule for classification metrics. +pub enum DecisionRule { + /// Consider a class predicted if its probability exceeds the threshold. + Threshold(f64), + /// Consider a class predicted correctly if it is within the top k predicted classes based on scores. + TopK(NonZeroUsize), +} + +impl Default for DecisionRule { + fn default() -> Self { + Self::Threshold(0.5) + } +} + /// The reduction strategy for classification metrics. 
#[derive(Copy, Clone, Default)] pub enum ClassReduction { diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index b0248c698b..3a3047a428 100644 --- a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -1,7 +1,27 @@ -use super::classification::ClassReduction; +use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}; use burn_core::prelude::{Backend, Bool, Int, Tensor}; use std::fmt::{self, Debug}; -use std::num::NonZeroUsize; + +/// Input for [ConfusionStats] +#[derive(new, Debug, Clone)] +pub struct ConfusionStatsInput { + /// Sample x Class Non thresholded normalized predictions. + pub predictions: Tensor, + /// Sample x Class one-hot encoded target. + pub targets: Tensor, +} + +impl From> for (Tensor, Tensor) { + fn from(input: ConfusionStatsInput) -> Self { + (input.predictions, input.targets) + } +} + +impl From<(Tensor, Tensor)> for ConfusionStatsInput { + fn from(value: (Tensor, Tensor)) -> Self { + Self::new(value.0, value.1) + } +} #[derive(Clone)] pub struct ConfusionStats { @@ -31,28 +51,24 @@ impl Debug for ConfusionStats { impl ConfusionStats { /// Expects `predictions` to be normalized. 
- pub fn new( - predictions: Tensor, - targets: Tensor, - threshold: Option, - top_k: Option, - class_reduction: ClassReduction, - ) -> Self { - let prediction_mask = match (threshold, top_k) { - (Some(threshold), None) => { - predictions.greater_elem(threshold) - }, - (None, Some(top_k)) => { - let mask = predictions.zeros_like(); - let indexes = predictions.argsort_descending(1).narrow(1, 0, top_k.get()); + pub fn new(input: &ConfusionStatsInput, config: &ClassificationMetricConfig) -> Self { + let prediction_mask = match config.decision_rule { + DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold), + DecisionRule::TopK(top_k) => { + let mask = input.predictions.zeros_like(); + let indexes = + input + .predictions + .clone() + .argsort_descending(1) + .narrow(1, 0, top_k.get()); let values = indexes.ones_like().float(); mask.scatter(1, indexes, values).bool() } - _ => panic!("Either threshold (for binary or multilabel) or top_k (for multiclass) must be set."), }; Self { - confusion_classes: prediction_mask.int() + targets.int() * 2, - class_reduction, + confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2, + class_reduction: config.class_reduction, } } @@ -61,7 +77,7 @@ impl ConfusionStats { sample_class_mask: Tensor, class_reduction: ClassReduction, ) -> Tensor { - use ClassReduction::*; + use ClassReduction::{Macro, Micro}; match class_reduction { Micro => sample_class_mask.float().sum(), Macro => sample_class_mask.float().sum_dim(0).squeeze(0), @@ -107,245 +123,225 @@ impl ConfusionStats { #[cfg(test)] mod tests { - use super::{ - ClassReduction::{self, *}, - ConfusionStats, - }; - use crate::tests::{ - dummy_classification_input, - ClassificationType::{self, *}, - THRESHOLD, + use super::{ConfusionStats, ConfusionStatsInput}; + use crate::{ + metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, + tests::{dummy_classification_input, ClassificationType, THRESHOLD}, + 
TestBackend, }; use burn_core::prelude::TensorData; - use rstest::rstest; + use rstest::{fixture, rstest}; use std::num::NonZeroUsize; - #[rstest] - #[should_panic] - #[case::both_some(Some(THRESHOLD), Some(1))] - #[should_panic] - #[case::both_none(None, None)] - fn test_exclusive_threshold_top_k( - #[case] threshold: Option, - #[case] top_k: Option, - ) { - let (predictions, targets) = dummy_classification_input(&Binary); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - Micro, - ); + fn top_k_config( + top_k: NonZeroUsize, + class_reduction: ClassReduction, + ) -> ClassificationMetricConfig { + ClassificationMetricConfig { + decision_rule: DecisionRule::TopK(top_k), + class_reduction, + } + } + #[fixture] + #[once] + fn top_k_config_k1_micro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro) + } + + #[fixture] + #[once] + fn top_k_config_k1_macro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro) + } + #[fixture] + #[once] + fn top_k_config_k2_micro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro) + } + #[fixture] + #[once] + fn top_k_config_k2_macro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro) + } + + fn threshold_config( + threshold: f64, + class_reduction: ClassReduction, + ) -> ClassificationMetricConfig { + ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + class_reduction, + } + } + #[fixture] + #[once] + fn threshold_config_micro() -> ClassificationMetricConfig { + threshold_config(THRESHOLD, ClassReduction::Micro) + } + #[fixture] + #[once] + fn threshold_config_macro() -> ClassificationMetricConfig { + threshold_config(THRESHOLD, ClassReduction::Macro) } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - 
#[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [3].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [5].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 2, 1].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())] fn test_true_positive( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .true_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + 
ConfusionStats::new(&input, &config) + .true_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [8].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 3, 3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [1, 1, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [0, 2, 1].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())] fn test_true_negative( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - 
class_reduction, - ) - .true_negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .true_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 0].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [6].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 3, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 1, 1].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())] fn test_false_positive( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, 
#[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .false_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .false_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 0, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [1].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [0, 0, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [4].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 0, 2].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())] fn 
test_false_negatives( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .false_negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .false_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 1, 2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [9].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [4, 2, 3].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 
2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())] fn test_positive( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [3].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [3, 4, 3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [3, 4, 3].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [6].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 3, 2].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 
3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())] fn test_negative( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 2, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [4, 4, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [8].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [3, 3, 2].into())] + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), 
[2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())] fn test_predicted_positive( #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, + #[case] config: ClassificationMetricConfig, #[case] expected: Vec, ) { - let (predictions, targets) = dummy_classification_input(&classification_type); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.and_then(NonZeroUsize::new), - class_reduction, - ) - .predicted_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) + .predicted_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } } diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 33049f9a46..e6358e3023 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -51,7 +51,12 @@ pub(crate) mod classification; #[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; mod confusion_stats; +pub use confusion_stats::ConfusionStatsInput; #[cfg(feature = "metrics")] mod precision; #[cfg(feature = "metrics")] pub use precision::*; +#[cfg(feature = "metrics")] +mod recall; +#[cfg(feature = "metrics")] +pub use 
recall::*; diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 0b9efb3d8d..067261cbdf 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -1,56 +1,22 @@ use super::{ - classification::ClassReduction, - confusion_stats::ConfusionStats, + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, }; use burn_core::{ prelude::{Backend, Tensor}, - tensor::{cast::ToElement, Bool}, + tensor::cast::ToElement, }; use core::marker::PhantomData; use std::num::NonZeroUsize; -/// Input for precision metric. -#[derive(new, Debug, Clone)] -pub struct PrecisionInput { - /// Sample x Class Non thresholded normalized predictions. - pub predictions: Tensor, - /// Sample x Class one-hot encoded target. - pub targets: Tensor, -} - -impl From> for (Tensor, Tensor) { - fn from(input: PrecisionInput) -> Self { - (input.predictions, input.targets) - } -} - -impl From<(Tensor, Tensor)> for PrecisionInput { - fn from(value: (Tensor, Tensor)) -> Self { - Self::new(value.0, value.1) - } -} - -enum PrecisionConfig { - Binary { threshold: f64 }, - Multiclass { top_k: NonZeroUsize }, - Multilabel { threshold: f64 }, -} - -impl Default for PrecisionConfig { - fn default() -> Self { - Self::Binary { threshold: 0.5 } - } -} - ///The Precision Metric #[derive(Default)] pub struct PrecisionMetric { state: NumericMetricState, _b: PhantomData, - class_reduction: ClassReduction, - config: PrecisionConfig, + config: ClassificationMetricConfig, } impl PrecisionMetric { @@ -62,7 +28,11 @@ impl PrecisionMetric { #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self { - config: PrecisionConfig::Binary { threshold }, + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + // binary 
classification results are the same independently of class_reduction + ..Default::default() + }, ..Default::default() } } @@ -73,10 +43,13 @@ impl PrecisionMetric { /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). #[allow(dead_code)] - pub fn multiclass(top_k: usize) -> Self { + pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { - config: PrecisionConfig::Multiclass { - top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + config: ClassificationMetricConfig { + decision_rule: DecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + ), + class_reduction, }, ..Default::default() } @@ -86,25 +59,21 @@ impl PrecisionMetric { /// /// # Arguments /// - /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `threshold` - The threshold to transform a probability into a binary value. #[allow(dead_code)] - pub fn multilabel(threshold: f64) -> Self { + pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { - config: PrecisionConfig::Multilabel { threshold }, + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + class_reduction, + }, ..Default::default() } } - /// Sets the class reduction method. 
- #[allow(dead_code)] - pub fn with_class_reduction(mut self, class_reduction: ClassReduction) -> Self { - self.class_reduction = class_reduction; - self - } - fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { - use ClassReduction::*; - let avg_tensor = match self.class_reduction { + use ClassReduction::{Macro, Micro}; + let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric.contains_nan().any().into_scalar() { @@ -122,21 +91,12 @@ impl PrecisionMetric { impl Metric for PrecisionMetric { const NAME: &'static str = "Precision"; - type Input = PrecisionInput; + type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - let (predictions, targets) = input.clone().into(); let [sample_size, _] = input.predictions.dims(); - let (threshold, top_k) = match self.config { - PrecisionConfig::Binary { threshold } | PrecisionConfig::Multilabel { threshold } => { - (Some(threshold), None) - } - PrecisionConfig::Multiclass { top_k } => (None, Some(top_k)), - }; - - let cf_stats = - ConfusionStats::new(predictions, targets, threshold, top_k, self.class_reduction); + let cf_stats = ConfusionStats::new(input, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive()); @@ -169,15 +129,10 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_micro(Micro, THRESHOLD, 0.5)] - #[case::binary_macro(Macro, THRESHOLD, 0.5)] - fn test_binary_precision( - #[case] class_reduction: ClassReduction, - #[case] threshold: f64, - #[case] expected: f64, - ) { + #[case::binary_macro(THRESHOLD, 0.5)] + fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); - let mut metric = PrecisionMetric::binary(threshold).with_class_reduction(class_reduction); + let mut metric = PrecisionMetric::binary(threshold); let _entry = 
metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) @@ -194,7 +149,7 @@ mod tests { #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multiclass).into(); - let mut metric = PrecisionMetric::multiclass(top_k).with_class_reduction(class_reduction); + let mut metric = PrecisionMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) @@ -203,14 +158,13 @@ mod tests { #[rstest] #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)] #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)] - fn test_precision( + fn test_multilabel_precision( #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); - let mut metric = - PrecisionMetric::multilabel(threshold).with_class_reduction(class_reduction); + let mut metric = PrecisionMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs new file mode 100644 index 0000000000..8ce4351396 --- /dev/null +++ b/crates/burn-train/src/metric/recall.rs @@ -0,0 +1,171 @@ +use super::{ + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, + state::{FormatOptions, NumericMetricState}, + Metric, MetricEntry, MetricMetadata, Numeric, +}; +use burn_core::{ + prelude::{Backend, Tensor}, + tensor::cast::ToElement, +}; +use core::marker::PhantomData; +use std::num::NonZeroUsize; + +///The Recall Metric +#[derive(Default)] +pub struct 
RecallMetric { + state: NumericMetricState, + _b: PhantomData, + config: ClassificationMetricConfig, +} + +impl RecallMetric { + /// Recall metric for binary classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn binary(threshold: f64) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + // binary classification results are the same independently of class_reduction + ..Default::default() + }, + ..Default::default() + } + } + + /// Recall metric for multiclass classification. + /// + /// # Arguments + /// + /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + #[allow(dead_code)] + pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + ), + class_reduction, + }, + ..Default::default() + } + } + + /// Recall metric for multi-label classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction. 
+ #[allow(dead_code)] + pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + class_reduction, + }, + ..Default::default() + } + } + + fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { + use ClassReduction::{Macro, Micro}; + let avg_tensor = match self.config.class_reduction { + Micro => aggregated_metric, + Macro => { + if aggregated_metric.contains_nan().any().into_scalar() { + let nan_mask = aggregated_metric.is_nan(); + aggregated_metric = aggregated_metric + .clone() + .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + } + aggregated_metric.mean() + } + }; + avg_tensor.into_scalar().to_f64() + } +} + +impl Metric for RecallMetric { + const NAME: &'static str = "Recall"; + type Input = ConfusionStatsInput; + + fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let [sample_size, _] = input.predictions.dims(); + + let cf_stats = ConfusionStats::new(input, &self.config); + let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive()); + + self.state.update( + 100.0 * metric, + sample_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for RecallMetric { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + Metric, MetricMetadata, Numeric, RecallMetric, + }; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use burn_core::tensor::TensorData; + use rstest::rstest; + + #[rstest] + #[case::binary_macro(THRESHOLD, 0.5)] + fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) { + let input = dummy_classification_input(&ClassificationType::Binary).into(); + let mut metric = RecallMetric::binary(threshold); + let _entry = metric.update(&input, 
&MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)] + #[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)] + #[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)] + #[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)] + fn test_multiclass_recall( + #[case] class_reduction: ClassReduction, + #[case] top_k: usize, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multiclass).into(); + let mut metric = RecallMetric::multiclass(top_k, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)] + #[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)] + fn test_multilabel_recall( + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multilabel).into(); + let mut metric = RecallMetric::multilabel(threshold, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } +}