Skip to content

Commit

Permalink
Add recall (tracel-ai#2518)
Browse files Browse the repository at this point in the history
* create ClassificationInput and ClassificationConfig to share code with RecallMetric. Adapt confusionStats and Precision with generalized code.

* adjust comments in dummy data. implement recall. optimize imports and rename test function properly in precision.

* update book

* rename ClassificationInput, compose ClassificationConfig as classification decision rule and class_reduction

* format

* no glob import

* add comment and fmt

* rename struct slight reorganization add comments

* add documentation to DecisionRule's branches

---------

Co-authored-by: Tiago Sanona <[email protected]>
  • Loading branch information
tsanona and Tiago Sanona authored Dec 20, 2024
1 parent 9edeb67 commit 676877a
Show file tree
Hide file tree
Showing 8 changed files with 459 additions and 303 deletions.
3 changes: 2 additions & 1 deletion burn-book/src/building-blocks/metric.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ When working with the learner, you have the option to record metrics that will b
throughout the training process. We currently offer a restricted range of metrics.

| Metric | Description |
| ---------------- | ------------------------------------------------------- |
|------------------|---------------------------------------------------------|
| Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Precision | Calculate precision in percentage |
| Recall | Calculate recall in percentage |
| AUROC | Calculate the area under curve of ROC in percentage |
| Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs |
Expand Down
19 changes: 10 additions & 9 deletions crates/burn-train/src/learner/classification.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::metric::processor::ItemLazy;
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput};
use crate::metric::{
processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, Transaction};
use burn_ndarray::NdArray;
Expand Down Expand Up @@ -51,16 +52,16 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
}
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> PrecisionInput<B> {
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> ConfusionStatsInput<B> {
let [_, num_classes] = self.output.dims();
if num_classes > 1 {
PrecisionInput::new(
ConfusionStatsInput::new(
self.output.clone(),
self.targets.clone().one_hot(num_classes).bool(),
)
} else {
PrecisionInput::new(
ConfusionStatsInput::new(
self.output.clone(),
self.targets.clone().unsqueeze_dim(1).bool(),
)
Expand Down Expand Up @@ -115,8 +116,8 @@ impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
}
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> PrecisionInput<B> {
PrecisionInput::new(self.output.clone(), self.targets.clone().bool())
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> ConfusionStatsInput<B> {
ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
}
}
23 changes: 14 additions & 9 deletions crates/burn-train/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ pub(crate) mod tests {
/// Probability of tp before adding errors
pub const THRESHOLD: f64 = 0.5;

#[derive(Debug)]
#[derive(Debug, Default)]
pub enum ClassificationType {
#[default]
Binary,
Multiclass,
Multilabel,
Expand All @@ -51,12 +52,11 @@ pub(crate) mod tests {
match classification_type {
ClassificationType::Binary => {
(
Tensor::from_data(
[[0.3], [0.2], [0.7], [0.1], [0.55]],
//[[0], [0], [1], [0], [1]] with threshold=0.5
&Default::default(),
),
Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
// targets
Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
// predictions @ threshold=0.5
// [[0], [0], [1], [0], [1]]
)
}
ClassificationType::Multiclass => {
Expand All @@ -69,12 +69,15 @@ pub(crate) mod tests {
[0.1, 0.15, 0.8],
[0.9, 0.03, 0.07],
],
//[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1
//[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2
&Default::default(),
),
Tensor::from_data(
// targets
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
// predictions @ top_k=1
// [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
// predictions @ top_k=2
// [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
&Default::default(),
),
)
Expand All @@ -89,11 +92,13 @@ pub(crate) mod tests {
[0.7, 0.5, 0.9],
[1.0, 0.3, 0.2],
],
//[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5
&Default::default(),
),
// targets
Tensor::from_data(
[[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
// predictions @ threshold=0.5
// [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
&Default::default(),
),
)
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-train/src/metric/classification.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
use std::num::NonZeroUsize;

/// Necessary data for classification metrics.
#[derive(Default)]
pub struct ClassificationMetricConfig {
pub decision_rule: DecisionRule,
pub class_reduction: ClassReduction,
}

/// The prediction decision rule for classification metrics.
pub enum DecisionRule {
/// Consider a class predicted if its probability exceeds the threshold.
Threshold(f64),
/// Consider a class predicted correctly if it is within the top k predicted classes based on scores.
TopK(NonZeroUsize),
}

impl Default for DecisionRule {
fn default() -> Self {
Self::Threshold(0.5)
}
}

/// The reduction strategy for classification metrics.
#[derive(Copy, Clone, Default)]
pub enum ClassReduction {
Expand Down
Loading

0 comments on commit 676877a

Please sign in to comment.