Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recall #2518

Merged
merged 11 commits into from
Dec 20, 2024
3 changes: 2 additions & 1 deletion burn-book/src/building-blocks/metric.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ When working with the learner, you have the option to record metrics that will b
throughout the training process. We currently offer a restricted range of metrics.

| Metric | Description |
| ---------------- | ------------------------------------------------------- |
|------------------|---------------------------------------------------------|
| Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Precision | Calculate precision in percentage |
| Recall | Calculate recall in percentage |
| AUROC | Calculate the area under curve of ROC in percentage |
| Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs |
Expand Down
19 changes: 10 additions & 9 deletions crates/burn-train/src/learner/classification.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::metric::processor::ItemLazy;
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput};
use crate::metric::{
processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput,
};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor, Transaction};
use burn_ndarray::NdArray;
Expand Down Expand Up @@ -51,16 +52,16 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
}
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> PrecisionInput<B> {
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
fn adapt(&self) -> ConfusionStatsInput<B> {
let [_, num_classes] = self.output.dims();
if num_classes > 1 {
PrecisionInput::new(
ConfusionStatsInput::new(
self.output.clone(),
self.targets.clone().one_hot(num_classes).bool(),
)
} else {
PrecisionInput::new(
ConfusionStatsInput::new(
self.output.clone(),
self.targets.clone().unsqueeze_dim(1).bool(),
)
Expand Down Expand Up @@ -115,8 +116,8 @@ impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
}
}

impl<B: Backend> Adaptor<PrecisionInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> PrecisionInput<B> {
PrecisionInput::new(self.output.clone(), self.targets.clone().bool())
impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
/// Adapts this multi-label output into the input expected by the
/// confusion-statistics-based metrics (precision, recall, ...):
/// the raw output scores are passed through unchanged, and the targets
/// are cast to booleans (assumed to already be a multi-hot 0/1 tensor —
/// NOTE(review): confirm against the producing dataloader).
fn adapt(&self) -> ConfusionStatsInput<B> {
ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
}
}
23 changes: 14 additions & 9 deletions crates/burn-train/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ pub(crate) mod tests {
/// Probability of tp before adding errors
pub const THRESHOLD: f64 = 0.5;

#[derive(Debug)]
#[derive(Debug, Default)]
pub enum ClassificationType {
#[default]
Binary,
Multiclass,
Multilabel,
Expand All @@ -51,12 +52,11 @@ pub(crate) mod tests {
match classification_type {
ClassificationType::Binary => {
(
Tensor::from_data(
[[0.3], [0.2], [0.7], [0.1], [0.55]],
//[[0], [0], [1], [0], [1]] with threshold=0.5
&Default::default(),
),
Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
// targets
Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
// predictions @ threshold=0.5
// [[0], [0], [1], [0], [1]]
)
}
ClassificationType::Multiclass => {
Expand All @@ -69,12 +69,15 @@ pub(crate) mod tests {
[0.1, 0.15, 0.8],
[0.9, 0.03, 0.07],
],
//[[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]] with top_k=1
//[[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]] with top_k=2
&Default::default(),
),
Tensor::from_data(
// targets
[[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
// predictions @ top_k=1
// [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
// predictions @ top_k=2
// [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
&Default::default(),
),
)
Expand All @@ -89,11 +92,13 @@ pub(crate) mod tests {
[0.7, 0.5, 0.9],
[1.0, 0.3, 0.2],
],
//[[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]] with threshold=0.5
&Default::default(),
),
// targets
Tensor::from_data(
[[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
// predictions @ threshold=0.5
// [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
&Default::default(),
),
)
Expand Down
23 changes: 23 additions & 0 deletions crates/burn-train/src/metric/classification.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
use std::num::NonZeroUsize;

/// Necessary data for classification metrics.
///
/// Bundles the settings shared by confusion-statistics-based metrics
/// (e.g. precision, recall). `Default` yields the default decision rule
/// and class reduction of the respective types.
#[derive(Default)]
pub struct ClassificationMetricConfig {
/// Rule used to turn continuous model scores into discrete predictions.
pub decision_rule: DecisionRule,
/// Strategy used to aggregate per-class results into a single value.
pub class_reduction: ClassReduction,
}

/// The prediction decision rule for classification metrics.
///
/// Determines how continuous model scores are converted into discrete
/// "predicted" classes before confusion statistics are computed.
// Derives added for API completeness: public config-style enums should be
// `Debug`-printable and comparable; both payloads (`f64`, `NonZeroUsize`)
// are `Copy`, matching the sibling `ClassReduction` which already derives
// `Copy, Clone`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DecisionRule {
    /// Consider a class predicted if its probability exceeds the threshold.
    Threshold(f64),
    /// Consider a class predicted correctly if it is within the top k predicted classes based on scores.
    TopK(NonZeroUsize),
}

impl Default for DecisionRule {
fn default() -> Self {
Self::Threshold(0.5)
}
}

/// The reduction strategy for classification metrics.
#[derive(Copy, Clone, Default)]
pub enum ClassReduction {
Expand Down
Loading
Loading