From 3243d39074a4e5ef11dab63e6c05a71ca0d96030 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Wed, 20 Nov 2024 20:50:07 +0100 Subject: [PATCH 1/9] create ClassificationInput and ClassificationConfig to share code with RecallMetric. Adapt confusionStats and Precision with generalized code. --- .../burn-train/src/learner/classification.rs | 18 +- crates/burn-train/src/lib.rs | 14 +- .../burn-train/src/metric/classification.rs | 48 ++ .../burn-train/src/metric/confusion_stats.rs | 416 ++++++++---------- crates/burn-train/src/metric/precision.rs | 86 +--- 5 files changed, 281 insertions(+), 301 deletions(-) diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index 381cd3a96c..354a329e47 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,4 +1,6 @@ -use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput}; +use crate::metric::{ + classification::ClassificationInput, AccuracyInput, Adaptor, HammingScoreInput, LossInput, +}; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; @@ -27,16 +29,16 @@ impl Adaptor> for ClassificationOutput { } } -impl Adaptor> for ClassificationOutput { - fn adapt(&self) -> PrecisionInput { +impl Adaptor> for ClassificationOutput { + fn adapt(&self) -> ClassificationInput { let [_, num_classes] = self.output.dims(); if num_classes > 1 { - PrecisionInput::new( + ClassificationInput::new( self.output.clone(), self.targets.clone().one_hot(num_classes).bool(), ) } else { - PrecisionInput::new( + ClassificationInput::new( self.output.clone(), self.targets.clone().unsqueeze_dim(1).bool(), ) @@ -69,8 +71,8 @@ impl Adaptor> for MultiLabelClassificationOutput { } } -impl Adaptor> for MultiLabelClassificationOutput { - fn adapt(&self) -> PrecisionInput { - PrecisionInput::new(self.output.clone(), self.targets.clone().bool()) +impl Adaptor> for 
MultiLabelClassificationOutput { + fn adapt(&self) -> ClassificationInput { + ClassificationInput::new(self.output.clone(), self.targets.clone().bool()) } } diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index 24337498c5..d47420750d 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -29,6 +29,7 @@ pub(crate) type TestBackend = burn_ndarray::NdArray; #[cfg(test)] pub(crate) mod tests { + use crate::metric::classification::ClassificationConfig; use crate::TestBackend; use burn_core::{prelude::Tensor, tensor::Bool}; use std::default::Default; @@ -36,13 +37,24 @@ pub(crate) mod tests { /// Probability of tp before adding errors pub const THRESHOLD: f64 = 0.5; - #[derive(Debug)] + #[derive(Debug, Default)] pub enum ClassificationType { + #[default] Binary, Multiclass, Multilabel, } + impl ClassificationType { + pub fn from_classification_config(config: &ClassificationConfig) -> Self { + match config { + ClassificationConfig::Binary { .. } => ClassificationType::Binary, + ClassificationConfig::Multiclass { .. } => ClassificationType::Multiclass, + ClassificationConfig::Multilabel { .. } => ClassificationType::Multilabel, + } + } + } + /// Sample x Class shaped matrix for use in /// classification metrics testing pub fn dummy_classification_input( diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs index 1eb51a85d0..3b3e19df61 100644 --- a/crates/burn-train/src/metric/classification.rs +++ b/crates/burn-train/src/metric/classification.rs @@ -1,3 +1,6 @@ +use burn_core::prelude::{Backend, Bool, Tensor}; +use std::num::NonZeroUsize; + /// The reduction strategy for classification metrics. 
#[derive(Copy, Clone, Default)] pub enum ClassReduction { @@ -7,3 +10,48 @@ pub enum ClassReduction { #[default] Macro, } + +/// Input for classification metrics +#[derive(new, Debug, Clone)] +pub struct ClassificationInput { + /// Sample x Class Non thresholded normalized predictions. + pub predictions: Tensor, + /// Sample x Class one-hot encoded target. + pub targets: Tensor, +} + +impl From> for (Tensor, Tensor) { + fn from(input: ClassificationInput) -> Self { + (input.predictions, input.targets) + } +} + +impl From<(Tensor, Tensor)> for ClassificationInput { + fn from(value: (Tensor, Tensor)) -> Self { + Self::new(value.0, value.1) + } +} + +pub enum ClassificationConfig { + Binary { + threshold: f64, + class_reduction: ClassReduction, + }, + Multiclass { + top_k: NonZeroUsize, + class_reduction: ClassReduction, + }, + Multilabel { + threshold: f64, + class_reduction: ClassReduction, + }, +} + +impl Default for ClassificationConfig { + fn default() -> Self { + Self::Binary { + threshold: 0.5, + class_reduction: Default::default(), + } + } +} diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index cdb01b1721..db80976db7 100644 --- a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -1,7 +1,6 @@ -use super::classification::ClassReduction; +use super::classification::{ClassReduction, ClassificationConfig}; use burn_core::prelude::{Backend, Bool, Int, Tensor}; use std::fmt::{self, Debug}; -use std::num::NonZeroUsize; #[derive(Clone)] pub struct ConfusionStats { @@ -34,21 +33,26 @@ impl ConfusionStats { pub fn new( predictions: Tensor, targets: Tensor, - threshold: Option, - top_k: Option, - class_reduction: ClassReduction, + config: &ClassificationConfig, ) -> Self { - let prediction_mask = match (threshold, top_k) { - (Some(threshold), None) => { - predictions.greater_elem(threshold) - }, - (None, Some(top_k)) => { + let (prediction_mask, 
class_reduction) = match config { + ClassificationConfig::Binary { + threshold, + class_reduction, + } + | ClassificationConfig::Multilabel { + threshold, + class_reduction, + } => (predictions.greater_elem(*threshold), *class_reduction), + ClassificationConfig::Multiclass { + top_k, + class_reduction, + } => { let mask = predictions.zeros_like(); let indexes = predictions.argsort_descending(1).narrow(1, 0, top_k.get()); let values = indexes.ones_like().float(); - mask.scatter(1, indexes, values).bool() + (mask.scatter(1, indexes, values).bool(), *class_reduction) } - _ => panic!("Either threshold (for binary or multilabel) or top_k (for multiclass) must be set."), }; Self { confusion_classes: prediction_mask.int() + targets.int() * 2, @@ -107,245 +111,201 @@ impl ConfusionStats { #[cfg(test)] mod tests { - use super::{ - ClassReduction::{self, *}, - ConfusionStats, - }; - use crate::tests::{ - dummy_classification_input, - ClassificationType::{self, *}, - THRESHOLD, - }; + use super::ConfusionStats; + use crate::metric::classification::{ClassReduction::*, ClassificationConfig}; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; use burn_core::prelude::TensorData; - use rstest::rstest; + use rstest::{fixture, rstest}; use std::num::NonZeroUsize; - #[rstest] - #[should_panic] - #[case::both_some(Some(THRESHOLD), Some(1))] - #[should_panic] - #[case::both_none(None, None)] - fn test_exclusive_threshold_top_k( - #[case] threshold: Option, - #[case] top_k: Option, - ) { - let (predictions, targets) = dummy_classification_input(&Binary).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - Micro, - ); + // binary classification results are the same independently of class_reduction + #[fixture] + #[once] + fn binary_config() -> ClassificationConfig { + ClassificationConfig::Binary { + threshold: THRESHOLD, + class_reduction: Default::default(), + } + } + #[fixture] + #[once] + fn 
multiclass_config_k1_micro() -> ClassificationConfig { + ClassificationConfig::Multiclass { + top_k: NonZeroUsize::new(1).unwrap(), + class_reduction: Micro, + } + } + #[fixture] + #[once] + fn multiclass_config_k1_macro() -> ClassificationConfig { + ClassificationConfig::Multiclass { + top_k: NonZeroUsize::new(1).unwrap(), + class_reduction: Macro, + } + } + #[fixture] + #[once] + fn multiclass_config_k2_micro() -> ClassificationConfig { + ClassificationConfig::Multiclass { + top_k: NonZeroUsize::new(2).unwrap(), + class_reduction: Micro, + } + } + #[fixture] + #[once] + fn multiclass_config_k2_macro() -> ClassificationConfig { + ClassificationConfig::Multiclass { + top_k: NonZeroUsize::new(2).unwrap(), + class_reduction: Macro, + } + } + #[fixture] + #[once] + fn multilabel_config_micro() -> ClassificationConfig { + ClassificationConfig::Multilabel { + threshold: THRESHOLD, + class_reduction: Micro, + } + } + #[fixture] + #[once] + fn multilabel_config_macro() -> ClassificationConfig { + ClassificationConfig::Multilabel { + threshold: THRESHOLD, + class_reduction: Macro, + } } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [3].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [5].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 2, 1].into())] - fn test_true_positive( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, - #[case] expected: Vec, - ) { - let (predictions, targets) = 
dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .true_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_micro(binary_config(), [1].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [3].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 1, 1].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [4].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 1, 1].into())] + #[case::multilabel_micro(multilabel_config_micro(), [5].into())] + #[case::multilabel_macro(multilabel_config_macro(), [2, 2, 1].into())] + fn test_true_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .true_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [8].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 3, 3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [1, 1, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [0, 2, 1].into())] - fn test_true_negative( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, - #[case] expected: Vec, - ) { - let 
(predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .true_negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_macro(binary_config(), [2].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [8].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 3, 3].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [4].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [1, 1, 2].into())] + #[case::multilabel_micro(multilabel_config_micro(), [3].into())] + #[case::multilabel_macro(multilabel_config_macro(), [0, 2, 1].into())] + fn test_true_negative(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .true_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 0].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [6].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 3, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 1, 1].into())] - fn test_false_positive( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] top_k: Option, - #[case] 
expected: Vec, - ) { - let (predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .false_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_macro(binary_config(), [1].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [2].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 1, 0].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [6].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 3, 1].into())] + #[case::multilabel_micro(multilabel_config_micro(), [3].into())] + #[case::multilabel_macro(multilabel_config_macro(), [1, 1, 1].into())] + fn test_false_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .false_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 0, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [1].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [0, 0, 1].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [4].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 0, 2].into())] - fn test_false_negatives( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: Option, - #[case] 
top_k: Option, - #[case] expected: Vec, - ) { - let (predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .false_negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_macro(binary_config(), [1].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [2].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 0, 1].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [1].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [0, 0, 1].into())] + #[case::multilabel_micro(multilabel_config_micro(), [4].into())] + #[case::multilabel_macro(multilabel_config_macro(), [2, 0, 2].into())] + fn test_false_negatives(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .false_negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 1, 2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [9].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [4, 2, 3].into())] - fn test_positive( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: 
Option, - #[case] top_k: Option, - #[case] expected: Vec, - ) { - let (predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_micro(binary_config(), [2].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [5].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 1, 2].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [5].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 1, 2].into())] + #[case::multilabel_micro(multilabel_config_micro(), [9].into())] + #[case::multilabel_macro(multilabel_config_macro(), [4, 2, 3].into())] + fn test_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [3].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [3, 4, 3].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [3, 4, 3].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [6].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 3, 2].into())] - fn test_negative( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] threshold: 
Option, - #[case] top_k: Option, - #[case] expected: Vec, - ) { - let (predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .negative() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_micro(binary_config(), [3].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [10].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [3, 4, 3].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [10].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [3, 4, 3].into())] + #[case::multilabel_micro(multilabel_config_micro(), [6].into())] + #[case::multilabel_macro(multilabel_config_macro(), [1, 3, 2].into())] + fn test_negative(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .negative() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } #[rstest] - #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())] - #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 2, 1].into())] - #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())] - #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [4, 4, 2].into())] - #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [8].into())] - #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [3, 3, 2].into())] - fn test_predicted_positive( - #[case] classification_type: ClassificationType, - #[case] class_reduction: ClassReduction, - #[case] 
threshold: Option, - #[case] top_k: Option, - #[case] expected: Vec, - ) { - let (predictions, targets) = dummy_classification_input(&classification_type).into(); - ConfusionStats::new( - predictions, - targets, - threshold, - top_k.map(NonZeroUsize::new).flatten(), - class_reduction, - ) - .predicted_positive() - .int() - .into_data() - .assert_eq(&TensorData::from(expected.as_slice()), true); + #[case::binary_micro(binary_config(), [2].into())] + #[case::multiclass_micro(multiclass_config_k1_micro(), [5].into())] + #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 2, 1].into())] + #[case::multiclass_micro(multiclass_config_k2_micro(), [10].into())] + #[case::multiclass_macro(multiclass_config_k2_macro(), [4, 4, 2].into())] + #[case::multilabel_micro(multilabel_config_micro(), [8].into())] + #[case::multilabel_macro(multilabel_config_macro(), [3, 3, 2].into())] + fn test_predicted_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { + let (predictions, targets) = + dummy_classification_input(&ClassificationType::from_classification_config(&config)) + .into(); + ConfusionStats::new(predictions, targets, &config) + .predicted_positive() + .int() + .into_data() + .assert_eq(&TensorData::from(expected.as_slice()), true); } } diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 0b9efb3d8d..12ea7dee6f 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -1,56 +1,24 @@ use super::{ - classification::ClassReduction, + classification::{ClassReduction, ClassificationConfig}, confusion_stats::ConfusionStats, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, }; +use crate::metric::classification::ClassificationInput; use burn_core::{ prelude::{Backend, Tensor}, - tensor::{cast::ToElement, Bool}, + tensor::cast::ToElement, }; use core::marker::PhantomData; use std::num::NonZeroUsize; -/// Input for 
precision metric. -#[derive(new, Debug, Clone)] -pub struct PrecisionInput { - /// Sample x Class Non thresholded normalized predictions. - pub predictions: Tensor, - /// Sample x Class one-hot encoded target. - pub targets: Tensor, -} - -impl From> for (Tensor, Tensor) { - fn from(input: PrecisionInput) -> Self { - (input.predictions, input.targets) - } -} - -impl From<(Tensor, Tensor)> for PrecisionInput { - fn from(value: (Tensor, Tensor)) -> Self { - Self::new(value.0, value.1) - } -} - -enum PrecisionConfig { - Binary { threshold: f64 }, - Multiclass { top_k: NonZeroUsize }, - Multilabel { threshold: f64 }, -} - -impl Default for PrecisionConfig { - fn default() -> Self { - Self::Binary { threshold: 0.5 } - } -} - ///The Precision Metric #[derive(Default)] pub struct PrecisionMetric { state: NumericMetricState, _b: PhantomData, class_reduction: ClassReduction, - config: PrecisionConfig, + config: ClassificationConfig, } impl PrecisionMetric { @@ -62,7 +30,10 @@ impl PrecisionMetric { #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self { - config: PrecisionConfig::Binary { threshold }, + config: ClassificationConfig::Binary { + threshold, + class_reduction: Default::default(), + }, ..Default::default() } } @@ -73,10 +44,11 @@ impl PrecisionMetric { /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). #[allow(dead_code)] - pub fn multiclass(top_k: usize) -> Self { + pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { - config: PrecisionConfig::Multiclass { + config: ClassificationConfig::Multiclass { top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + class_reduction, }, ..Default::default() } @@ -88,20 +60,16 @@ impl PrecisionMetric { /// /// * `threshold` - The threshold to transform a probability into a binary prediction. 
#[allow(dead_code)] - pub fn multilabel(threshold: f64) -> Self { + pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { - config: PrecisionConfig::Multilabel { threshold }, + config: ClassificationConfig::Multilabel { + threshold, + class_reduction, + }, ..Default::default() } } - /// Sets the class reduction method. - #[allow(dead_code)] - pub fn with_class_reduction(mut self, class_reduction: ClassReduction) -> Self { - self.class_reduction = class_reduction; - self - } - fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { use ClassReduction::*; let avg_tensor = match self.class_reduction { @@ -122,21 +90,13 @@ impl PrecisionMetric { impl Metric for PrecisionMetric { const NAME: &'static str = "Precision"; - type Input = PrecisionInput; + type Input = ClassificationInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { let (predictions, targets) = input.clone().into(); let [sample_size, _] = input.predictions.dims(); - let (threshold, top_k) = match self.config { - PrecisionConfig::Binary { threshold } | PrecisionConfig::Multilabel { threshold } => { - (Some(threshold), None) - } - PrecisionConfig::Multiclass { top_k } => (None, Some(top_k)), - }; - - let cf_stats = - ConfusionStats::new(predictions, targets, threshold, top_k, self.class_reduction); + let cf_stats = ConfusionStats::new(predictions, targets, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive()); @@ -169,15 +129,13 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_micro(Micro, THRESHOLD, 0.5)] - #[case::binary_macro(Macro, THRESHOLD, 0.5)] + #[case::binary_macro(THRESHOLD, 0.5)] fn test_binary_precision( - #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Binary).into(); - let mut metric = 
PrecisionMetric::binary(threshold).with_class_reduction(class_reduction); + let mut metric = PrecisionMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) @@ -194,7 +152,7 @@ mod tests { #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multiclass).into(); - let mut metric = PrecisionMetric::multiclass(top_k).with_class_reduction(class_reduction); + let mut metric = PrecisionMetric::multiclass(top_k, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) @@ -210,7 +168,7 @@ mod tests { ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); let mut metric = - PrecisionMetric::multilabel(threshold).with_class_reduction(class_reduction); + PrecisionMetric::multilabel(threshold, class_reduction); let _entry = metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) From d99fe2804120746451ba9d0b9a4b19b1f96b99d2 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Thu, 21 Nov 2024 00:03:05 +0100 Subject: [PATCH 2/9] adjust comments in dummy data. implement recall. optimize imports and rename test function properly in precision. 
--- crates/burn-train/src/lib.rs | 20 ++- crates/burn-train/src/metric/mod.rs | 4 + crates/burn-train/src/metric/precision.rs | 13 +- crates/burn-train/src/metric/recall.rs | 170 ++++++++++++++++++++++ 4 files changed, 190 insertions(+), 17 deletions(-) create mode 100644 crates/burn-train/src/metric/recall.rs diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index d47420750d..7401f4c4fc 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -63,12 +63,11 @@ pub(crate) mod tests { match classification_type { ClassificationType::Binary => { ( - Tensor::from_data( - [[0.3], [0.2], [0.7], [0.1], [0.55]], - //[[0], [0], [1], [0], [1]] with threshold=0.5 - &Default::default(), - ), + Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()), + // targets Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()), + // predictions @ threshold=0.5 + // [[0], [0], [1], [0], [1]] ) } ClassificationType::Multiclass => { @@ -81,12 +80,15 @@ pub(crate) mod tests { [0.1, 0.15, 0.8], [0.9, 0.03, 0.07], ], - //[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1 - //[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2 &Default::default(), ), Tensor::from_data( + // targets [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]], + // predictions @ top_k=1 + // [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] + // predictions @ top_k=2 + // [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] &Default::default(), ), ) @@ -101,11 +103,13 @@ pub(crate) mod tests { [0.7, 0.5, 0.9], [1.0, 0.3, 0.2], ], - //[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5 &Default::default(), ), + // targets Tensor::from_data( [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]], + // predictions @ threshold=0.5 + // [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] &Default::default(), ), ) diff --git a/crates/burn-train/src/metric/mod.rs 
b/crates/burn-train/src/metric/mod.rs index 2b8d9cd801..189f3d2bb6 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -52,3 +52,7 @@ mod confusion_stats; mod precision; #[cfg(feature = "metrics")] pub use precision::*; +#[cfg(feature = "metrics")] +mod recall; +#[cfg(feature = "metrics")] +pub use recall::*; diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 12ea7dee6f..e100dab063 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -1,10 +1,9 @@ use super::{ - classification::{ClassReduction, ClassificationConfig}, + classification::{ClassReduction, ClassificationConfig, ClassificationInput}, confusion_stats::ConfusionStats, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, }; -use crate::metric::classification::ClassificationInput; use burn_core::{ prelude::{Backend, Tensor}, tensor::cast::ToElement, @@ -130,10 +129,7 @@ mod tests { #[rstest] #[case::binary_macro(THRESHOLD, 0.5)] - fn test_binary_precision( - #[case] threshold: f64, - #[case] expected: f64, - ) { + fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = PrecisionMetric::binary(threshold); let _entry = metric.update(&input, &MetricMetadata::fake()); @@ -161,14 +157,13 @@ mod tests { #[rstest] #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)] #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)] - fn test_precision( + fn test_multilabel_precision( #[case] class_reduction: ClassReduction, #[case] threshold: f64, #[case] expected: f64, ) { let input = dummy_classification_input(&ClassificationType::Multilabel).into(); - let mut metric = - PrecisionMetric::multilabel(threshold, class_reduction); + let mut metric = PrecisionMetric::multilabel(threshold, class_reduction); let _entry = 
metric.update(&input, &MetricMetadata::fake()); TensorData::from([metric.value()]) .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs new file mode 100644 index 0000000000..aa5c55633c --- /dev/null +++ b/crates/burn-train/src/metric/recall.rs @@ -0,0 +1,170 @@ +use super::{ + classification::{ClassReduction, ClassificationConfig, ClassificationInput}, + confusion_stats::ConfusionStats, + state::{FormatOptions, NumericMetricState}, + Metric, MetricEntry, MetricMetadata, Numeric, +}; +use burn_core::{ + prelude::{Backend, Tensor}, + tensor::cast::ToElement, +}; +use core::marker::PhantomData; +use std::num::NonZeroUsize; + +///The Recall Metric +#[derive(Default)] +pub struct RecallMetric { + state: NumericMetricState, + _b: PhantomData, + class_reduction: ClassReduction, + config: ClassificationConfig, +} + +impl RecallMetric { + /// Recall metric for binary classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn binary(threshold: f64) -> Self { + Self { + config: ClassificationConfig::Binary { + threshold, + class_reduction: Default::default(), + }, + ..Default::default() + } + } + + /// Recall metric for multiclass classification. + /// + /// # Arguments + /// + /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + #[allow(dead_code)] + pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationConfig::Multiclass { + top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + class_reduction, + }, + ..Default::default() + } + } + + /// Recall metric for multi-label classification. + /// + /// # Arguments + /// + /// * `threshold` - The threshold to transform a probability into a binary prediction.
+ #[allow(dead_code)] + pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationConfig::Multilabel { + threshold, + class_reduction, + }, + ..Default::default() + } + } + + fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { + use ClassReduction::*; + let avg_tensor = match self.class_reduction { + Micro => aggregated_metric, + Macro => { + if aggregated_metric.contains_nan().any().into_scalar() { + let nan_mask = aggregated_metric.is_nan(); + aggregated_metric = aggregated_metric + .clone() + .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + } + aggregated_metric.mean() + } + }; + avg_tensor.into_scalar().to_f64() + } +} + +impl Metric for RecallMetric { + const NAME: &'static str = "Recall"; + type Input = ClassificationInput; + + fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let (predictions, targets) = input.clone().into(); + let [sample_size, _] = input.predictions.dims(); + + let cf_stats = ConfusionStats::new(predictions, targets, &self.config); + let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive()); + + self.state.update( + 100.0 * metric, + sample_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for RecallMetric { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + Metric, MetricMetadata, Numeric, RecallMetric, + }; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use burn_core::tensor::TensorData; + use rstest::rstest; + + #[rstest] + #[case::binary_macro(THRESHOLD, 0.5)] + fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) { + let input = dummy_classification_input(&ClassificationType::Binary).into(); + let mut metric = RecallMetric::binary(threshold); + let _entry = metric.update(&input, 
&MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)] + #[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)] + #[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)] + #[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)] + fn test_multiclass_recall( + #[case] class_reduction: ClassReduction, + #[case] top_k: usize, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multiclass).into(); + let mut metric = RecallMetric::multiclass(top_k, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)] + #[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)] + fn test_multilabel_recall( + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multilabel).into(); + let mut metric = RecallMetric::multilabel(threshold, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } +} From 0880a6c2aab0bfcdfc0fae37fd80eaedeecb7fcc Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Thu, 21 Nov 2024 00:04:49 +0100 Subject: [PATCH 3/9] update book --- burn-book/src/building-blocks/metric.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index bdaef635d0..e029aca708 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -4,10 +4,11 @@ When working with the learner, you have the option to record metrics that will b 
throughout the training process. We currently offer a restricted range of metrics. | Metric | Description | -| ---------------- | ------------------------------------------------------- | +|------------------|---------------------------------------------------------| | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | | Precision | Calculate precision in percentage | +| Recall | Calculate recall in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CPU Temperature | Fetch the temperature of CPUs | From 29014723d83d135ef42b45505b073777baee7437 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Sun, 8 Dec 2024 19:24:34 +0100 Subject: [PATCH 4/9] rename ClassificationInput, compose ClassificationConfig as classification decision rule and class_reduction --- .../burn-train/src/learner/classification.rs | 18 +- crates/burn-train/src/lib.rs | 11 - .../burn-train/src/metric/classification.rs | 47 +-- .../burn-train/src/metric/confusion_stats.rs | 334 ++++++++++-------- crates/burn-train/src/metric/mod.rs | 1 + crates/burn-train/src/metric/precision.rs | 34 +- crates/burn-train/src/metric/recall.rs | 30 +- 7 files changed, 239 insertions(+), 236 deletions(-) diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index 354a329e47..40f2043d5a 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,5 +1,5 @@ -use crate::metric::{ - classification::ClassificationInput, AccuracyInput, Adaptor, HammingScoreInput, LossInput, +use crate::metric::{ ConfusionStatsInput, + AccuracyInput, Adaptor, HammingScoreInput, LossInput, }; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor}; @@ -29,16 +29,16 @@ impl Adaptor> for ClassificationOutput { } } -impl Adaptor> for ClassificationOutput { - fn 
adapt(&self) -> ClassificationInput { +impl Adaptor> for ClassificationOutput { + fn adapt(&self) -> ConfusionStatsInput { let [_, num_classes] = self.output.dims(); if num_classes > 1 { - ClassificationInput::new( + ConfusionStatsInput::new( self.output.clone(), self.targets.clone().one_hot(num_classes).bool(), ) } else { - ClassificationInput::new( + ConfusionStatsInput::new( self.output.clone(), self.targets.clone().unsqueeze_dim(1).bool(), ) @@ -71,8 +71,8 @@ impl Adaptor> for MultiLabelClassificationOutput { } } -impl Adaptor> for MultiLabelClassificationOutput { - fn adapt(&self) -> ClassificationInput { - ClassificationInput::new(self.output.clone(), self.targets.clone().bool()) +impl Adaptor> for MultiLabelClassificationOutput { + fn adapt(&self) -> ConfusionStatsInput { + ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool()) } } diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs index 7401f4c4fc..5c06748e29 100644 --- a/crates/burn-train/src/lib.rs +++ b/crates/burn-train/src/lib.rs @@ -29,7 +29,6 @@ pub(crate) type TestBackend = burn_ndarray::NdArray; #[cfg(test)] pub(crate) mod tests { - use crate::metric::classification::ClassificationConfig; use crate::TestBackend; use burn_core::{prelude::Tensor, tensor::Bool}; use std::default::Default; @@ -45,16 +44,6 @@ pub(crate) mod tests { Multilabel, } - impl ClassificationType { - pub fn from_classification_config(config: &ClassificationConfig) -> Self { - match config { - ClassificationConfig::Binary { .. } => ClassificationType::Binary, - ClassificationConfig::Multiclass { .. } => ClassificationType::Multiclass, - ClassificationConfig::Multilabel { .. 
} => ClassificationType::Multilabel, - } - } - } - /// Sample x Class shaped matrix for use in /// classification metrics testing pub fn dummy_classification_input( diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs index 3b3e19df61..fde5a491b0 100644 --- a/crates/burn-train/src/metric/classification.rs +++ b/crates/burn-train/src/metric/classification.rs @@ -1,4 +1,3 @@ -use burn_core::prelude::{Backend, Bool, Tensor}; use std::num::NonZeroUsize; /// The reduction strategy for classification metrics. @@ -11,47 +10,19 @@ pub enum ClassReduction { Macro, } -/// Input for classification metrics -#[derive(new, Debug, Clone)] -pub struct ClassificationInput { - /// Sample x Class Non thresholded normalized predictions. - pub predictions: Tensor, - /// Sample x Class one-hot encoded target. - pub targets: Tensor, +#[derive(Default)] +pub struct ClassificationMetricConfig { + pub decision_rule: ClassificationDecisionRule, + pub class_reduction: ClassReduction, } -impl From> for (Tensor, Tensor) { - fn from(input: ClassificationInput) -> Self { - (input.predictions, input.targets) - } -} - -impl From<(Tensor, Tensor)> for ClassificationInput { - fn from(value: (Tensor, Tensor)) -> Self { - Self::new(value.0, value.1) - } -} - -pub enum ClassificationConfig { - Binary { - threshold: f64, - class_reduction: ClassReduction, - }, - Multiclass { - top_k: NonZeroUsize, - class_reduction: ClassReduction, - }, - Multilabel { - threshold: f64, - class_reduction: ClassReduction, - }, +pub enum ClassificationDecisionRule { + Threshold(f64), + TopK(NonZeroUsize), } -impl Default for ClassificationConfig { +impl Default for ClassificationDecisionRule { fn default() -> Self { - Self::Binary { - threshold: 0.5, - class_reduction: Default::default(), - } + Self::Threshold(0.5) } } diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index db80976db7..e806510446 100644 --- 
a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -1,7 +1,30 @@ -use super::classification::{ClassReduction, ClassificationConfig}; +use super::classification::{ + ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig, +}; use burn_core::prelude::{Backend, Bool, Int, Tensor}; use std::fmt::{self, Debug}; +/// Input for classification metrics +#[derive(new, Debug, Clone)] +pub struct ConfusionStatsInput { + /// Sample x Class Non thresholded normalized predictions. + pub predictions: Tensor, + /// Sample x Class one-hot encoded target. + pub targets: Tensor, +} + +impl From> for (Tensor, Tensor) { + fn from(input: ConfusionStatsInput) -> Self { + (input.predictions, input.targets) + } +} + +impl From<(Tensor, Tensor)> for ConfusionStatsInput { + fn from(value: (Tensor, Tensor)) -> Self { + Self::new(value.0, value.1) + } +} + #[derive(Clone)] pub struct ConfusionStats { confusion_classes: Tensor, @@ -30,33 +53,26 @@ impl Debug for ConfusionStats { impl ConfusionStats { /// Expects `predictions` to be normalized. 
- pub fn new( - predictions: Tensor, - targets: Tensor, - config: &ClassificationConfig, - ) -> Self { - let (prediction_mask, class_reduction) = match config { - ClassificationConfig::Binary { - threshold, - class_reduction, + pub fn new(input: &ConfusionStatsInput, config: &ClassificationMetricConfig) -> Self { + let prediction_mask = match config.decision_rule { + ClassificationDecisionRule::Threshold(threshold) => { + input.predictions.clone().greater_elem(threshold) } - | ClassificationConfig::Multilabel { - threshold, - class_reduction, - } => (predictions.greater_elem(*threshold), *class_reduction), - ClassificationConfig::Multiclass { - top_k, - class_reduction, - } => { - let mask = predictions.zeros_like(); - let indexes = predictions.argsort_descending(1).narrow(1, 0, top_k.get()); + ClassificationDecisionRule::TopK(top_k) => { + let mask = input.predictions.zeros_like(); + let indexes = + input + .predictions + .clone() + .argsort_descending(1) + .narrow(1, 0, top_k.get()); let values = indexes.ones_like().float(); - (mask.scatter(1, indexes, values).bool(), *class_reduction) + mask.scatter(1, indexes, values).bool() } }; Self { - confusion_classes: prediction_mask.int() + targets.int() * 2, - class_reduction, + confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2, + class_reduction: config.class_reduction, } } @@ -111,84 +127,86 @@ impl ConfusionStats { #[cfg(test)] mod tests { - use super::ConfusionStats; - use crate::metric::classification::{ClassReduction::*, ClassificationConfig}; - use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use super::{ConfusionStats, ConfusionStatsInput}; + use crate::{ + metric::classification::{ + ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig, + }, + tests::{dummy_classification_input, ClassificationType, THRESHOLD}, + TestBackend, + }; use burn_core::prelude::TensorData; use rstest::{fixture, rstest}; use std::num::NonZeroUsize; - // binary 
classification results are the same independently of class_reduction - #[fixture] - #[once] - fn binary_config() -> ClassificationConfig { - ClassificationConfig::Binary { - threshold: THRESHOLD, - class_reduction: Default::default(), + fn top_k_config( + top_k: NonZeroUsize, + class_reduction: ClassReduction, + ) -> ClassificationMetricConfig { + ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::TopK(top_k), + class_reduction, } } #[fixture] #[once] - fn multiclass_config_k1_micro() -> ClassificationConfig { - ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(1).unwrap(), - class_reduction: Micro, - } + fn top_k_config_k1_micro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro) } + #[fixture] #[once] - fn multiclass_config_k1_macro() -> ClassificationConfig { - ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(1).unwrap(), - class_reduction: Macro, - } + fn top_k_config_k1_macro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro) } #[fixture] #[once] - fn multiclass_config_k2_micro() -> ClassificationConfig { - ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(2).unwrap(), - class_reduction: Micro, - } + fn top_k_config_k2_micro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro) } #[fixture] #[once] - fn multiclass_config_k2_macro() -> ClassificationConfig { - ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(2).unwrap(), - class_reduction: Macro, + fn top_k_config_k2_macro() -> ClassificationMetricConfig { + top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro) + } + + fn threshold_config( + threshold: f64, + class_reduction: ClassReduction, + ) -> ClassificationMetricConfig { + ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::Threshold(threshold), + class_reduction, } } #[fixture] #[once] - fn 
multilabel_config_micro() -> ClassificationConfig { - ClassificationConfig::Multilabel { - threshold: THRESHOLD, - class_reduction: Micro, - } + fn threshold_config_micro() -> ClassificationMetricConfig { + threshold_config(THRESHOLD, ClassReduction::Micro) } #[fixture] #[once] - fn multilabel_config_macro() -> ClassificationConfig { - ClassificationConfig::Multilabel { - threshold: THRESHOLD, - class_reduction: Macro, - } + fn threshold_config_macro() -> ClassificationMetricConfig { + threshold_config(THRESHOLD, ClassReduction::Macro) } #[rstest] - #[case::binary_micro(binary_config(), [1].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [3].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 1, 1].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [4].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 1, 1].into())] - #[case::multilabel_micro(multilabel_config_micro(), [5].into())] - #[case::multilabel_macro(multilabel_config_macro(), [2, 2, 1].into())] - fn test_true_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())] + 
#[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())] + fn test_true_positive( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .true_positive() .int() .into_data() @@ -196,18 +214,22 @@ mod tests { } #[rstest] - #[case::binary_macro(binary_config(), [2].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [8].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 3, 3].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [4].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [1, 1, 2].into())] - #[case::multilabel_micro(multilabel_config_micro(), [3].into())] - #[case::multilabel_macro(multilabel_config_macro(), [0, 2, 1].into())] - fn test_true_negative(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 
1].into())] + fn test_true_negative( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .true_negative() .int() .into_data() @@ -215,18 +237,22 @@ mod tests { } #[rstest] - #[case::binary_macro(binary_config(), [1].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [2].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 1, 0].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [6].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 3, 1].into())] - #[case::multilabel_micro(multilabel_config_micro(), [3].into())] - #[case::multilabel_macro(multilabel_config_macro(), [1, 1, 1].into())] - fn test_false_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())] + fn test_false_positive( + #[case] classification_type: ClassificationType, + 
#[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .false_positive() .int() .into_data() @@ -234,18 +260,22 @@ mod tests { } #[rstest] - #[case::binary_macro(binary_config(), [1].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [2].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [1, 0, 1].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [1].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [0, 0, 1].into())] - #[case::multilabel_micro(multilabel_config_micro(), [4].into())] - #[case::multilabel_macro(multilabel_config_macro(), [2, 0, 2].into())] - fn test_false_negatives(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())] + fn test_false_negatives( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: 
ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .false_negative() .int() .into_data() @@ -253,18 +283,22 @@ mod tests { } #[rstest] - #[case::binary_micro(binary_config(), [2].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [5].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 1, 2].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [5].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [2, 1, 2].into())] - #[case::multilabel_micro(multilabel_config_micro(), [9].into())] - #[case::multilabel_macro(multilabel_config_macro(), [4, 2, 3].into())] - fn test_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())] + fn test_positive( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + 
ConfusionStats::new(&input, &config) .positive() .int() .into_data() @@ -272,18 +306,22 @@ mod tests { } #[rstest] - #[case::binary_micro(binary_config(), [3].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [10].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [3, 4, 3].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [10].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [3, 4, 3].into())] - #[case::multilabel_micro(multilabel_config_micro(), [6].into())] - #[case::multilabel_macro(multilabel_config_macro(), [1, 3, 2].into())] - fn test_negative(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())] + fn test_negative( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .negative() .int() .into_data() @@ -291,18 +329,22 @@ mod tests { } 
#[rstest] - #[case::binary_micro(binary_config(), [2].into())] - #[case::multiclass_micro(multiclass_config_k1_micro(), [5].into())] - #[case::multiclass_macro(multiclass_config_k1_macro(), [2, 2, 1].into())] - #[case::multiclass_micro(multiclass_config_k2_micro(), [10].into())] - #[case::multiclass_macro(multiclass_config_k2_macro(), [4, 4, 2].into())] - #[case::multilabel_micro(multilabel_config_micro(), [8].into())] - #[case::multilabel_macro(multilabel_config_macro(), [3, 3, 2].into())] - fn test_predicted_positive(#[case] config: ClassificationConfig, #[case] expected: Vec) { - let (predictions, targets) = - dummy_classification_input(&ClassificationType::from_classification_config(&config)) - .into(); - ConfusionStats::new(predictions, targets, &config) + #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())] + #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())] + #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())] + #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())] + #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())] + #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())] + fn test_predicted_positive( + #[case] classification_type: ClassificationType, + #[case] config: ClassificationMetricConfig, + #[case] expected: Vec, + ) { + let input: ConfusionStatsInput = + dummy_classification_input(&classification_type).into(); + ConfusionStats::new(&input, &config) .predicted_positive() .int() .into_data() diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 
189f3d2bb6..ae35ba5221 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -48,6 +48,7 @@ pub(crate) mod classification; #[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; mod confusion_stats; +pub use confusion_stats::ConfusionStatsInput; #[cfg(feature = "metrics")] mod precision; #[cfg(feature = "metrics")] diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index e100dab063..392861feb5 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -1,6 +1,6 @@ use super::{ - classification::{ClassReduction, ClassificationConfig, ClassificationInput}, - confusion_stats::ConfusionStats, + classification::{ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, }; @@ -16,8 +16,7 @@ use std::num::NonZeroUsize; pub struct PrecisionMetric { state: NumericMetricState, _b: PhantomData, - class_reduction: ClassReduction, - config: ClassificationConfig, + config: ClassificationMetricConfig, } impl PrecisionMetric { @@ -29,9 +28,9 @@ impl PrecisionMetric { #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self { - config: ClassificationConfig::Binary { - threshold, - class_reduction: Default::default(), + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::Threshold(threshold), + ..Default::default() }, ..Default::default() } @@ -45,8 +44,10 @@ impl PrecisionMetric { #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { - config: ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must 
be non-zero"), + ), class_reduction, }, ..Default::default() @@ -57,12 +58,12 @@ impl PrecisionMetric { /// /// # Arguments /// - /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `threshold` - The threshold to transform a probability into a binary value. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { - config: ClassificationConfig::Multilabel { - threshold, + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::Threshold(threshold), class_reduction, }, ..Default::default() @@ -70,8 +71,8 @@ impl PrecisionMetric { } fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { - use ClassReduction::*; - let avg_tensor = match self.class_reduction { + use ClassReduction::{Macro, Micro}; + let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric.contains_nan().any().into_scalar() { @@ -89,13 +90,12 @@ impl PrecisionMetric { impl Metric for PrecisionMetric { const NAME: &'static str = "Precision"; - type Input = ClassificationInput; + type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - let (predictions, targets) = input.clone().into(); let [sample_size, _] = input.predictions.dims(); - let cf_stats = ConfusionStats::new(predictions, targets, &self.config); + let cf_stats = ConfusionStats::new(input, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive()); diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index aa5c55633c..2e57b8b7f0 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -1,6 +1,6 @@ use super::{ - classification::{ClassReduction, ClassificationConfig, ClassificationInput}, - confusion_stats::ConfusionStats, + classification::{ClassReduction, 
ClassificationDecisionRule, ClassificationMetricConfig}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, }; @@ -16,8 +16,7 @@ use std::num::NonZeroUsize; pub struct RecallMetric { state: NumericMetricState, _b: PhantomData, - class_reduction: ClassReduction, - config: ClassificationConfig, + config: ClassificationMetricConfig, } impl RecallMetric { @@ -29,9 +28,9 @@ impl RecallMetric { #[allow(dead_code)] pub fn binary(threshold: f64) -> Self { Self { - config: ClassificationConfig::Binary { - threshold, - class_reduction: Default::default(), + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::Threshold(threshold), + ..Default::default() }, ..Default::default() } @@ -45,8 +44,10 @@ impl RecallMetric { #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { - config: ClassificationConfig::Multiclass { - top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + ), class_reduction, }, ..Default::default() @@ -61,8 +62,8 @@ impl RecallMetric { #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { - config: ClassificationConfig::Multilabel { - threshold, + config: ClassificationMetricConfig { + decision_rule: ClassificationDecisionRule::Threshold(threshold), class_reduction, }, ..Default::default() @@ -71,7 +72,7 @@ impl RecallMetric { fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { use ClassReduction::*; - let avg_tensor = match self.class_reduction { + let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { if aggregated_metric.contains_nan().any().into_scalar() { @@ -89,13 +90,12 @@ impl RecallMetric { impl Metric for RecallMetric { 
const NAME: &'static str = "Recall"; - type Input = ClassificationInput; + type Input = ConfusionStatsInput; fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { - let (predictions, targets) = input.clone().into(); let [sample_size, _] = input.predictions.dims(); - let cf_stats = ConfusionStats::new(predictions, targets, &self.config); + let cf_stats = ConfusionStats::new(input, &self.config); let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive()); self.state.update( From f6514c1d21e9584b71f72eb18a5c8691e01af33d Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Sun, 8 Dec 2024 19:56:23 +0100 Subject: [PATCH 5/9] format --- crates/burn-train/src/learner/classification.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs index 72c23f25bd..14551169cb 100644 --- a/crates/burn-train/src/learner/classification.rs +++ b/crates/burn-train/src/learner/classification.rs @@ -1,5 +1,5 @@ -use crate::metric::{ processor::ItemLazy, ConfusionStatsInput, - AccuracyInput, Adaptor, HammingScoreInput, LossInput, +use crate::metric::{ + processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput, }; use burn_core::tensor::backend::Backend; use burn_core::tensor::{Int, Tensor, Transaction}; From f383ba3ff2bd022783ab5630272f648664348029 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Sun, 8 Dec 2024 20:04:56 +0100 Subject: [PATCH 6/9] no glob import --- crates/burn-train/src/metric/confusion_stats.rs | 2 +- crates/burn-train/src/metric/recall.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index e806510446..5b3f6b071e 100644 --- a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -81,7 +81,7 @@ 
impl ConfusionStats { sample_class_mask: Tensor, class_reduction: ClassReduction, ) -> Tensor { - use ClassReduction::*; + use ClassReduction::{Micro, Macro}; match class_reduction { Micro => sample_class_mask.float().sum(), Macro => sample_class_mask.float().sum_dim(0).squeeze(0), diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index 2e57b8b7f0..337f8167df 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -71,7 +71,7 @@ impl RecallMetric { } fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { - use ClassReduction::*; + use ClassReduction::{Macro, Micro}; let avg_tensor = match self.config.class_reduction { Micro => aggregated_metric, Macro => { From d7c3e90237cbaacdf34338ada8f117f09b266bcb Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Sun, 8 Dec 2024 20:12:41 +0100 Subject: [PATCH 7/9] add comment and fmt --- crates/burn-train/src/metric/confusion_stats.rs | 2 +- crates/burn-train/src/metric/precision.rs | 1 + crates/burn-train/src/metric/recall.rs | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index 5b3f6b071e..8ebc945723 100644 --- a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -81,7 +81,7 @@ impl ConfusionStats { sample_class_mask: Tensor, class_reduction: ClassReduction, ) -> Tensor { - use ClassReduction::{Micro, Macro}; + use ClassReduction::{Macro, Micro}; match class_reduction { Micro => sample_class_mask.float().sum(), Macro => sample_class_mask.float().sum_dim(0).squeeze(0), diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 392861feb5..1dec7b8fa9 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -30,6 +30,7 @@ impl PrecisionMetric { Self { config: 
ClassificationMetricConfig { decision_rule: ClassificationDecisionRule::Threshold(threshold), + // binary classification results are the same independently of class_reduction ..Default::default() }, ..Default::default() diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index 337f8167df..cbc009502e 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -30,6 +30,7 @@ impl RecallMetric { Self { config: ClassificationMetricConfig { decision_rule: ClassificationDecisionRule::Threshold(threshold), + // binary classification results are the same independently of class_reduction ..Default::default() }, ..Default::default() From 924599bc4fb82fdb73728924b4e5eb451e8e2afd Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Wed, 18 Dec 2024 18:19:03 +0100 Subject: [PATCH 8/9] rename struct slight reorganization add comments --- .../burn-train/src/metric/classification.rs | 28 ++++++++++--------- .../burn-train/src/metric/confusion_stats.rs | 20 +++++-------- crates/burn-train/src/metric/precision.rs | 8 +++--- crates/burn-train/src/metric/recall.rs | 8 +++--- 4 files changed, 30 insertions(+), 34 deletions(-) diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs index fde5a491b0..5fe6f5e270 100644 --- a/crates/burn-train/src/metric/classification.rs +++ b/crates/burn-train/src/metric/classification.rs @@ -1,28 +1,30 @@ use std::num::NonZeroUsize; -/// The reduction strategy for classification metrics. -#[derive(Copy, Clone, Default)] -pub enum ClassReduction { - /// Computes the statistics over all classes before averaging - Micro, - /// Computes the statistics independently for each class before averaging - #[default] - Macro, -} - +/// Necessary data for classification metrics. 
#[derive(Default)] pub struct ClassificationMetricConfig { - pub decision_rule: ClassificationDecisionRule, + pub decision_rule: DecisionRule, pub class_reduction: ClassReduction, } -pub enum ClassificationDecisionRule { +/// The prediction decision rule for classification metrics. +pub enum DecisionRule { Threshold(f64), TopK(NonZeroUsize), } -impl Default for ClassificationDecisionRule { +impl Default for DecisionRule { fn default() -> Self { Self::Threshold(0.5) } } + +/// The reduction strategy for classification metrics. +#[derive(Copy, Clone, Default)] +pub enum ClassReduction { + /// Computes the statistics over all classes before averaging + Micro, + /// Computes the statistics independently for each class before averaging + #[default] + Macro, +} diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs index 8ebc945723..3a3047a428 100644 --- a/crates/burn-train/src/metric/confusion_stats.rs +++ b/crates/burn-train/src/metric/confusion_stats.rs @@ -1,10 +1,8 @@ -use super::classification::{ - ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig, -}; +use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}; use burn_core::prelude::{Backend, Bool, Int, Tensor}; use std::fmt::{self, Debug}; -/// Input for classification metrics +/// Input for [ConfusionStats] #[derive(new, Debug, Clone)] pub struct ConfusionStatsInput { /// Sample x Class Non thresholded normalized predictions. @@ -55,10 +53,8 @@ impl ConfusionStats { /// Expects `predictions` to be normalized. 
pub fn new(input: &ConfusionStatsInput, config: &ClassificationMetricConfig) -> Self { let prediction_mask = match config.decision_rule { - ClassificationDecisionRule::Threshold(threshold) => { - input.predictions.clone().greater_elem(threshold) - } - ClassificationDecisionRule::TopK(top_k) => { + DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold), + DecisionRule::TopK(top_k) => { let mask = input.predictions.zeros_like(); let indexes = input @@ -129,9 +125,7 @@ impl ConfusionStats { mod tests { use super::{ConfusionStats, ConfusionStatsInput}; use crate::{ - metric::classification::{ - ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig, - }, + metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, tests::{dummy_classification_input, ClassificationType, THRESHOLD}, TestBackend, }; @@ -144,7 +138,7 @@ mod tests { class_reduction: ClassReduction, ) -> ClassificationMetricConfig { ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::TopK(top_k), + decision_rule: DecisionRule::TopK(top_k), class_reduction, } } @@ -175,7 +169,7 @@ mod tests { class_reduction: ClassReduction, ) -> ClassificationMetricConfig { ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::Threshold(threshold), + decision_rule: DecisionRule::Threshold(threshold), class_reduction, } } diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 1dec7b8fa9..067261cbdf 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -1,5 +1,5 @@ use super::{ - classification::{ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig}, + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, @@ -29,7 +29,7 @@ impl 
PrecisionMetric { pub fn binary(threshold: f64) -> Self { Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::Threshold(threshold), + decision_rule: DecisionRule::Threshold(threshold), // binary classification results are the same independently of class_reduction ..Default::default() }, @@ -46,7 +46,7 @@ impl PrecisionMetric { pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::TopK( + decision_rule: DecisionRule::TopK( NonZeroUsize::new(top_k).expect("top_k must be non-zero"), ), class_reduction, @@ -64,7 +64,7 @@ impl PrecisionMetric { pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::Threshold(threshold), + decision_rule: DecisionRule::Threshold(threshold), class_reduction, }, ..Default::default() diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index cbc009502e..8ce4351396 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -1,5 +1,5 @@ use super::{ - classification::{ClassReduction, ClassificationDecisionRule, ClassificationMetricConfig}, + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, confusion_stats::{ConfusionStats, ConfusionStatsInput}, state::{FormatOptions, NumericMetricState}, Metric, MetricEntry, MetricMetadata, Numeric, @@ -29,7 +29,7 @@ impl RecallMetric { pub fn binary(threshold: f64) -> Self { Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::Threshold(threshold), + decision_rule: DecisionRule::Threshold(threshold), // binary classification results are the same independently of class_reduction ..Default::default() }, @@ -46,7 +46,7 @@ impl RecallMetric { pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { 
Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::TopK( + decision_rule: DecisionRule::TopK( NonZeroUsize::new(top_k).expect("top_k must be non-zero"), ), class_reduction, @@ -64,7 +64,7 @@ impl RecallMetric { pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { config: ClassificationMetricConfig { - decision_rule: ClassificationDecisionRule::Threshold(threshold), + decision_rule: DecisionRule::Threshold(threshold), class_reduction, }, ..Default::default() From 6a63633f96658d03ef37c0e19f76e07aba67c5a8 Mon Sep 17 00:00:00 2001 From: Tiago Sanona Date: Thu, 19 Dec 2024 23:45:01 +0100 Subject: [PATCH 9/9] add documentation to DecisionRule's branches --- crates/burn-train/src/metric/classification.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs index 5fe6f5e270..1caa2dabea 100644 --- a/crates/burn-train/src/metric/classification.rs +++ b/crates/burn-train/src/metric/classification.rs @@ -9,7 +9,9 @@ pub struct ClassificationMetricConfig { /// The prediction decision rule for classification metrics. pub enum DecisionRule { + /// Consider a class predicted if its probability exceeds the threshold. Threshold(f64), + /// Consider a class predicted correctly if it is within the top k predicted classes based on scores. TopK(NonZeroUsize), }