diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md
index bdaef635d0..e029aca708 100644
--- a/burn-book/src/building-blocks/metric.md
+++ b/burn-book/src/building-blocks/metric.md
@@ -4,10 +4,22 @@ When working with the learner, you have the option to record metrics that will b
 throughout the training process. We currently offer a restricted range of metrics.
 
 | Metric           | Description                                             |
-| ---------------- | ------------------------------------------------------- |
+|------------------|---------------------------------------------------------|
 | Accuracy         | Calculate the accuracy in percentage                    |
 | TopKAccuracy     | Calculate the top-k accuracy in percentage              |
 | Precision        | Calculate precision in percentage                       |
+| Recall           | Calculate recall in percentage                          |
 | AUROC            | Calculate the area under curve of ROC in percentage     |
 | Loss             | Output the loss used for the backward pass              |
 | CPU Temperature  | Fetch the temperature of CPUs                           |
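+
+For example, with a typical `LearnerBuilder` setup, a metric such as `Recall` is registered like
+any other numeric metric (a minimal sketch; adapt it to your own training configuration):
+
+```rust, ignore
+let learner = LearnerBuilder::new(artifact_dir)
+    .metric_train_numeric(RecallMetric::binary(0.5))
+    .metric_valid_numeric(RecallMetric::binary(0.5))
+    // ... other metrics, checkpointing, etc.
+    .build(model, optim, lr_scheduler);
+```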
diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs
index 3eabc2a09b..14551169cb 100644
--- a/crates/burn-train/src/learner/classification.rs
+++ b/crates/burn-train/src/learner/classification.rs
@@ -1,5 +1,6 @@
-use crate::metric::processor::ItemLazy;
-use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput, PrecisionInput};
+use crate::metric::{
+    processor::ItemLazy, AccuracyInput, Adaptor, ConfusionStatsInput, HammingScoreInput, LossInput,
+};
 use burn_core::tensor::backend::Backend;
 use burn_core::tensor::{Int, Tensor, Transaction};
 use burn_ndarray::NdArray;
@@ -51,16 +52,16 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
     }
 }
 
-impl<B: Backend> Adaptor<PrecisionInput<B>> for ClassificationOutput<B> {
-    fn adapt(&self) -> PrecisionInput<B> {
+impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for ClassificationOutput<B> {
+    fn adapt(&self) -> ConfusionStatsInput<B> {
         let [_, num_classes] = self.output.dims();
         if num_classes > 1 {
-            PrecisionInput::new(
+            ConfusionStatsInput::new(
                 self.output.clone(),
                 self.targets.clone().one_hot(num_classes).bool(),
             )
         } else {
-            PrecisionInput::new(
+            ConfusionStatsInput::new(
                 self.output.clone(),
                 self.targets.clone().unsqueeze_dim(1).bool(),
             )
@@ -115,8 +116,8 @@ impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
     }
 }
 
-impl<B: Backend> Adaptor<PrecisionInput<B>> for MultiLabelClassificationOutput<B> {
-    fn adapt(&self) -> PrecisionInput<B> {
-        PrecisionInput::new(self.output.clone(), self.targets.clone().bool())
+impl<B: Backend> Adaptor<ConfusionStatsInput<B>> for MultiLabelClassificationOutput<B> {
+    fn adapt(&self) -> ConfusionStatsInput<B> {
+        ConfusionStatsInput::new(self.output.clone(), self.targets.clone().bool())
     }
 }
diff --git a/crates/burn-train/src/lib.rs b/crates/burn-train/src/lib.rs
index 24337498c5..5c06748e29 100644
--- a/crates/burn-train/src/lib.rs
+++ b/crates/burn-train/src/lib.rs
@@ -36,8 +36,9 @@ pub(crate) mod tests {
     /// Probability of tp before adding errors
     pub const THRESHOLD: f64 = 0.5;
 
-    #[derive(Debug)]
+    #[derive(Debug, Default)]
     pub enum ClassificationType {
+        #[default]
         Binary,
         Multiclass,
         Multilabel,
@@ -51,12 +52,11 @@ pub(crate) mod tests {
         match classification_type {
             ClassificationType::Binary => {
                 (
-                    Tensor::from_data(
-                        [[0.3], [0.2], [0.7], [0.1], [0.55]],
-                        //[[0],   [0],   [1],   [0],   [1]] with threshold=0.5
-                        &Default::default(),
-                    ),
+                    Tensor::from_data([[0.3], [0.2], [0.7], [0.1], [0.55]], &Default::default()),
+                    // targets
                     Tensor::from_data([[0], [1], [0], [0], [1]], &Default::default()),
+                    // predictions @ threshold=0.5
+                    //                     [[0], [0], [1], [0], [1]]
                 )
             }
             ClassificationType::Multiclass => {
@@ -69,12 +69,15 @@ pub(crate) mod tests {
                             [0.1, 0.15, 0.8],
                             [0.9, 0.03, 0.07],
                         ],
-                        //[[0,   1,   0],   [0,   1,   0],    [1,   0,   0],    [0,   0,   1],    [1,   0,    0]] with top_k=1
-                        //[[1,   1,   0],   [1,   1,   0],    [1,   1,   0],    [0,   1,   1],    [1,   0,    1]] with top_k=2
                         &Default::default(),
                     ),
                     Tensor::from_data(
+                        // targets
                         [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]],
+                        // predictions @ top_k=1
+                        //   [[0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0,  0]]
+                        // predictions @ top_k=2
+                        //   [[1, 1, 0], [1, 1, 0], [1, 1, 0], [0, 1, 1], [1, 0,  1]]
                         &Default::default(),
                     ),
                 )
@@ -89,11 +92,13 @@ pub(crate) mod tests {
                             [0.7, 0.5, 0.9],
                             [1.0, 0.3, 0.2],
                         ],
-                        //[[0,   1,   1],   [0,   1,   0],    [1,   1,   0],   [1,   0,   1],   [1,   0,   0]] with threshold=0.5
                         &Default::default(),
                     ),
+                    // targets
                     Tensor::from_data(
                         [[1, 1, 0], [1, 0, 1], [1, 1, 1], [0, 0, 1], [1, 0, 0]],
+                        // predictions @ threshold=0.5
+                        //   [[0, 1, 1], [0, 1, 0], [1, 1, 0], [1, 0, 1], [1, 0, 0]]
                         &Default::default(),
                     ),
                 )
diff --git a/crates/burn-train/src/metric/classification.rs b/crates/burn-train/src/metric/classification.rs
index 1eb51a85d0..1caa2dabea 100644
--- a/crates/burn-train/src/metric/classification.rs
+++ b/crates/burn-train/src/metric/classification.rs
@@ -1,3 +1,28 @@
+use std::num::NonZeroUsize;
+
+/// Configuration for classification metrics.
+#[derive(Default)]
+pub struct ClassificationMetricConfig {
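+    /// The prediction decision rule used to convert scores into predicted classes.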
+    pub decision_rule: DecisionRule,
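+    /// The reduction strategy applied across classes.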
+    pub class_reduction: ClassReduction,
+}
+
+/// The prediction decision rule for classification metrics.
+pub enum DecisionRule {
+    /// Consider a class predicted if its probability exceeds the threshold.
+    Threshold(f64),
+    /// Consider a class predicted if it is among the top k classes ranked by prediction score.
+    TopK(NonZeroUsize),
+}
+
+impl Default for DecisionRule {
+    fn default() -> Self {
+        Self::Threshold(0.5)
+    }
+}
+
 /// The reduction strategy for classification metrics.
 #[derive(Copy, Clone, Default)]
 pub enum ClassReduction {
diff --git a/crates/burn-train/src/metric/confusion_stats.rs b/crates/burn-train/src/metric/confusion_stats.rs
index b0248c698b..3a3047a428 100644
--- a/crates/burn-train/src/metric/confusion_stats.rs
+++ b/crates/burn-train/src/metric/confusion_stats.rs
@@ -1,7 +1,27 @@
-use super::classification::ClassReduction;
+use super::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule};
 use burn_core::prelude::{Backend, Bool, Int, Tensor};
 use std::fmt::{self, Debug};
-use std::num::NonZeroUsize;
+
+/// Input for [ConfusionStats]
+#[derive(new, Debug, Clone)]
+pub struct ConfusionStatsInput<B: Backend> {
+    /// Sample x Class non-thresholded normalized predictions.
+    pub predictions: Tensor<B, 2>,
+    /// Sample x Class one-hot encoded targets.
+    pub targets: Tensor<B, 2, Bool>,
+}
+
+impl<B: Backend> From<ConfusionStatsInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
+    fn from(input: ConfusionStatsInput<B>) -> Self {
+        (input.predictions, input.targets)
+    }
+}
+
+impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for ConfusionStatsInput<B> {
+    fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
+        Self::new(value.0, value.1)
+    }
+}
 
 #[derive(Clone)]
 pub struct ConfusionStats<B: Backend> {
@@ -31,28 +51,25 @@ impl<B: Backend> Debug for ConfusionStats<B> {
 
 impl<B: Backend> ConfusionStats<B> {
     /// Expects `predictions` to be normalized.
-    pub fn new(
-        predictions: Tensor<B, 2>,
-        targets: Tensor<B, 2, Bool>,
-        threshold: Option<f64>,
-        top_k: Option<NonZeroUsize>,
-        class_reduction: ClassReduction,
-    ) -> Self {
-        let prediction_mask = match (threshold, top_k) {
-            (Some(threshold), None) => {
-                predictions.greater_elem(threshold)
-            },
-            (None, Some(top_k)) => {
-                let mask = predictions.zeros_like();
-                let indexes = predictions.argsort_descending(1).narrow(1, 0, top_k.get());
+    pub fn new(input: &ConfusionStatsInput<B>, config: &ClassificationMetricConfig) -> Self {
+        let prediction_mask = match config.decision_rule {
+            DecisionRule::Threshold(threshold) => input.predictions.clone().greater_elem(threshold),
+            DecisionRule::TopK(top_k) => {
+                let mask = input.predictions.zeros_like();
+                let indexes =
+                    input
+                        .predictions
+                        .clone()
+                        .argsort_descending(1)
+                        .narrow(1, 0, top_k.get());
                 let values = indexes.ones_like().float();
                 mask.scatter(1, indexes, values).bool()
             }
-            _ => panic!("Either threshold (for binary or multilabel) or top_k (for multiclass) must be set."),
         };
         Self {
-            confusion_classes: prediction_mask.int() + targets.int() * 2,
-            class_reduction,
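+            // Encode each (prediction, target) pair: 0 = TN, 1 = FP, 2 = FN, 3 = TP.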
+            confusion_classes: prediction_mask.int() + input.targets.clone().int() * 2,
+            class_reduction: config.class_reduction,
         }
     }
 
@@ -61,7 +77,7 @@ impl<B: Backend> ConfusionStats<B> {
         sample_class_mask: Tensor<B, 2, Bool>,
         class_reduction: ClassReduction,
     ) -> Tensor<B, 1> {
-        use ClassReduction::*;
+        use ClassReduction::{Macro, Micro};
         match class_reduction {
             Micro => sample_class_mask.float().sum(),
             Macro => sample_class_mask.float().sum_dim(0).squeeze(0),
@@ -107,245 +123,225 @@ impl<B: Backend> ConfusionStats<B> {
 
 #[cfg(test)]
 mod tests {
-    use super::{
-        ClassReduction::{self, *},
-        ConfusionStats,
-    };
-    use crate::tests::{
-        dummy_classification_input,
-        ClassificationType::{self, *},
-        THRESHOLD,
+    use super::{ConfusionStats, ConfusionStatsInput};
+    use crate::{
+        metric::classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
+        tests::{dummy_classification_input, ClassificationType, THRESHOLD},
+        TestBackend,
     };
     use burn_core::prelude::TensorData;
-    use rstest::rstest;
+    use rstest::{fixture, rstest};
     use std::num::NonZeroUsize;
 
-    #[rstest]
-    #[should_panic]
-    #[case::both_some(Some(THRESHOLD), Some(1))]
-    #[should_panic]
-    #[case::both_none(None, None)]
-    fn test_exclusive_threshold_top_k(
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
-    ) {
-        let (predictions, targets) = dummy_classification_input(&Binary);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            Micro,
-        );
+    fn top_k_config(
+        top_k: NonZeroUsize,
+        class_reduction: ClassReduction,
+    ) -> ClassificationMetricConfig {
+        ClassificationMetricConfig {
+            decision_rule: DecisionRule::TopK(top_k),
+            class_reduction,
+        }
+    }
+    #[fixture]
+    #[once]
+    fn top_k_config_k1_micro() -> ClassificationMetricConfig {
+        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Micro)
+    }
+
+    #[fixture]
+    #[once]
+    fn top_k_config_k1_macro() -> ClassificationMetricConfig {
+        top_k_config(NonZeroUsize::new(1).unwrap(), ClassReduction::Macro)
+    }
+    #[fixture]
+    #[once]
+    fn top_k_config_k2_micro() -> ClassificationMetricConfig {
+        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Micro)
+    }
+    #[fixture]
+    #[once]
+    fn top_k_config_k2_macro() -> ClassificationMetricConfig {
+        top_k_config(NonZeroUsize::new(2).unwrap(), ClassReduction::Macro)
+    }
+
+    fn threshold_config(
+        threshold: f64,
+        class_reduction: ClassReduction,
+    ) -> ClassificationMetricConfig {
+        ClassificationMetricConfig {
+            decision_rule: DecisionRule::Threshold(threshold),
+            class_reduction,
+        }
+    }
+    #[fixture]
+    #[once]
+    fn threshold_config_micro() -> ClassificationMetricConfig {
+        threshold_config(THRESHOLD, ClassReduction::Micro)
+    }
+    #[fixture]
+    #[once]
+    fn threshold_config_macro() -> ClassificationMetricConfig {
+        threshold_config(THRESHOLD, ClassReduction::Macro)
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [3].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 1].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [5].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 2, 1].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [3].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 1].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [5].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 2, 1].into())]
     fn test_true_positive(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .true_positive()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .true_positive()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [8].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 3, 3].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [4].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [1, 1, 2].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [0, 2, 1].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [8].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 3, 3].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [4].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [1, 1, 2].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [0, 2, 1].into())]
     fn test_true_negative(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .true_negative()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .true_negative()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 1, 0].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [6].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 3, 1].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [3].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 1, 1].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 1, 0].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [6].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 3, 1].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [3].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 1, 1].into())]
     fn test_false_positive(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .false_positive()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .false_positive()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [1].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [2].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [1, 0, 1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [1].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [0, 0, 1].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [4].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [2, 0, 2].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [1].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [2].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [1, 0, 1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [1].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [0, 0, 1].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [4].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [2, 0, 2].into())]
     fn test_false_negatives(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .false_negative()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .false_negative()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 1, 2].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [5].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [2, 1, 2].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [9].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [4, 2, 3].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 1, 2].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [5].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [2, 1, 2].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [9].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [4, 2, 3].into())]
     fn test_positive(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .positive()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .positive()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [3].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [3].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [10].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [3, 4, 3].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [3, 4, 3].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [6].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [1, 3, 2].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [3].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [3].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [10].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [3, 4, 3].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [3, 4, 3].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [6].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [1, 3, 2].into())]
     fn test_negative(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .negative()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .negative()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 
     #[rstest]
-    #[case::binary_micro(Binary, Micro, Some(THRESHOLD), None, [2].into())]
-    #[case::binary_macro(Binary, Macro, Some(THRESHOLD), None, [2].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(1), [5].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(1), [2, 2, 1].into())]
-    #[case::multiclass_micro(Multiclass, Micro, None, Some(2), [10].into())]
-    #[case::multiclass_macro(Multiclass, Macro, None, Some(2), [4, 4, 2].into())]
-    #[case::multilabel_micro(Multilabel, Micro, Some(THRESHOLD), None, [8].into())]
-    #[case::multilabel_macro(Multilabel, Macro, Some(THRESHOLD), None, [3, 3, 2].into())]
+    #[case::binary_micro(ClassificationType::Binary, threshold_config_micro(), [2].into())]
+    #[case::binary_macro(ClassificationType::Binary, threshold_config_macro(), [2].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k1_micro(), [5].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k1_macro(), [2, 2, 1].into())]
+    #[case::multiclass_micro(ClassificationType::Multiclass, top_k_config_k2_micro(), [10].into())]
+    #[case::multiclass_macro(ClassificationType::Multiclass, top_k_config_k2_macro(), [4, 4, 2].into())]
+    #[case::multilabel_micro(ClassificationType::Multilabel, threshold_config_micro(), [8].into())]
+    #[case::multilabel_macro(ClassificationType::Multilabel, threshold_config_macro(), [3, 3, 2].into())]
     fn test_predicted_positive(
         #[case] classification_type: ClassificationType,
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: Option<f64>,
-        #[case] top_k: Option<usize>,
+        #[case] config: ClassificationMetricConfig,
         #[case] expected: Vec<i64>,
     ) {
-        let (predictions, targets) = dummy_classification_input(&classification_type);
-        ConfusionStats::new(
-            predictions,
-            targets,
-            threshold,
-            top_k.and_then(NonZeroUsize::new),
-            class_reduction,
-        )
-        .predicted_positive()
-        .int()
-        .into_data()
-        .assert_eq(&TensorData::from(expected.as_slice()), true);
+        let input: ConfusionStatsInput<TestBackend> =
+            dummy_classification_input(&classification_type).into();
+        ConfusionStats::new(&input, &config)
+            .predicted_positive()
+            .int()
+            .into_data()
+            .assert_eq(&TensorData::from(expected.as_slice()), true);
     }
 }
diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs
index 33049f9a46..e6358e3023 100644
--- a/crates/burn-train/src/metric/mod.rs
+++ b/crates/burn-train/src/metric/mod.rs
@@ -51,7 +51,12 @@ pub(crate) mod classification;
 #[cfg(feature = "metrics")]
 pub use crate::metric::classification::ClassReduction;
 mod confusion_stats;
+pub use confusion_stats::ConfusionStatsInput;
 #[cfg(feature = "metrics")]
 mod precision;
 #[cfg(feature = "metrics")]
 pub use precision::*;
+#[cfg(feature = "metrics")]
+mod recall;
+#[cfg(feature = "metrics")]
+pub use recall::*;
diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs
index 0b9efb3d8d..067261cbdf 100644
--- a/crates/burn-train/src/metric/precision.rs
+++ b/crates/burn-train/src/metric/precision.rs
@@ -1,56 +1,22 @@
 use super::{
-    classification::ClassReduction,
-    confusion_stats::ConfusionStats,
+    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
+    confusion_stats::{ConfusionStats, ConfusionStatsInput},
     state::{FormatOptions, NumericMetricState},
     Metric, MetricEntry, MetricMetadata, Numeric,
 };
 use burn_core::{
     prelude::{Backend, Tensor},
-    tensor::{cast::ToElement, Bool},
+    tensor::cast::ToElement,
 };
 use core::marker::PhantomData;
 use std::num::NonZeroUsize;
 
-/// Input for precision metric.
-#[derive(new, Debug, Clone)]
-pub struct PrecisionInput<B: Backend> {
-    /// Sample x Class Non thresholded normalized predictions.
-    pub predictions: Tensor<B, 2>,
-    /// Sample x Class one-hot encoded target.
-    pub targets: Tensor<B, 2, Bool>,
-}
-
-impl<B: Backend> From<PrecisionInput<B>> for (Tensor<B, 2>, Tensor<B, 2, Bool>) {
-    fn from(input: PrecisionInput<B>) -> Self {
-        (input.predictions, input.targets)
-    }
-}
-
-impl<B: Backend> From<(Tensor<B, 2>, Tensor<B, 2, Bool>)> for PrecisionInput<B> {
-    fn from(value: (Tensor<B, 2>, Tensor<B, 2, Bool>)) -> Self {
-        Self::new(value.0, value.1)
-    }
-}
-
-enum PrecisionConfig {
-    Binary { threshold: f64 },
-    Multiclass { top_k: NonZeroUsize },
-    Multilabel { threshold: f64 },
-}
-
-impl Default for PrecisionConfig {
-    fn default() -> Self {
-        Self::Binary { threshold: 0.5 }
-    }
-}
-
 ///The Precision Metric
 #[derive(Default)]
 pub struct PrecisionMetric<B: Backend> {
     state: NumericMetricState,
     _b: PhantomData<B>,
-    class_reduction: ClassReduction,
-    config: PrecisionConfig,
+    config: ClassificationMetricConfig,
 }
 
 impl<B: Backend> PrecisionMetric<B> {
@@ -62,7 +28,11 @@ impl<B: Backend> PrecisionMetric<B> {
     #[allow(dead_code)]
     pub fn binary(threshold: f64) -> Self {
         Self {
-            config: PrecisionConfig::Binary { threshold },
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                // binary classification results are the same regardless of class_reduction
+                ..Default::default()
+            },
             ..Default::default()
         }
     }
@@ -73,10 +43,13 @@ impl<B: Backend> PrecisionMetric<B> {
     ///
     /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
     #[allow(dead_code)]
-    pub fn multiclass(top_k: usize) -> Self {
+    pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
         Self {
-            config: PrecisionConfig::Multiclass {
-                top_k: NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::TopK(
+                    NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
+                ),
+                class_reduction,
             },
             ..Default::default()
         }
@@ -86,25 +59,21 @@ impl<B: Backend> PrecisionMetric<B> {
     ///
     /// # Arguments
     ///
-    /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    /// * `threshold` - The threshold to transform a probability into a binary value.
     #[allow(dead_code)]
-    pub fn multilabel(threshold: f64) -> Self {
+    pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
         Self {
-            config: PrecisionConfig::Multilabel { threshold },
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                class_reduction,
+            },
             ..Default::default()
         }
     }
 
-    /// Sets the class reduction method.
-    #[allow(dead_code)]
-    pub fn with_class_reduction(mut self, class_reduction: ClassReduction) -> Self {
-        self.class_reduction = class_reduction;
-        self
-    }
-
     fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
-        use ClassReduction::*;
-        let avg_tensor = match self.class_reduction {
+        use ClassReduction::{Macro, Micro};
+        let avg_tensor = match self.config.class_reduction {
             Micro => aggregated_metric,
             Macro => {
                 if aggregated_metric.contains_nan().any().into_scalar() {
@@ -122,21 +91,12 @@ impl<B: Backend> PrecisionMetric<B> {
 
 impl<B: Backend> Metric for PrecisionMetric<B> {
     const NAME: &'static str = "Precision";
-    type Input = PrecisionInput<B>;
+    type Input = ConfusionStatsInput<B>;
 
     fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
-        let (predictions, targets) = input.clone().into();
         let [sample_size, _] = input.predictions.dims();
 
-        let (threshold, top_k) = match self.config {
-            PrecisionConfig::Binary { threshold } | PrecisionConfig::Multilabel { threshold } => {
-                (Some(threshold), None)
-            }
-            PrecisionConfig::Multiclass { top_k } => (None, Some(top_k)),
-        };
-
-        let cf_stats =
-            ConfusionStats::new(predictions, targets, threshold, top_k, self.class_reduction);
+        let cf_stats = ConfusionStats::new(input, &self.config);
         let metric =
             self.class_average(cf_stats.clone().true_positive() / cf_stats.predicted_positive());
 
@@ -169,15 +129,10 @@ mod tests {
     use rstest::rstest;
 
     #[rstest]
-    #[case::binary_micro(Micro, THRESHOLD, 0.5)]
-    #[case::binary_macro(Macro, THRESHOLD, 0.5)]
-    fn test_binary_precision(
-        #[case] class_reduction: ClassReduction,
-        #[case] threshold: f64,
-        #[case] expected: f64,
-    ) {
+    #[case::binary_macro(THRESHOLD, 0.5)]
+    fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {
         let input = dummy_classification_input(&ClassificationType::Binary).into();
-        let mut metric = PrecisionMetric::binary(threshold).with_class_reduction(class_reduction);
+        let mut metric = PrecisionMetric::binary(threshold);
         let _entry = metric.update(&input, &MetricMetadata::fake());
         TensorData::from([metric.value()])
             .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
@@ -194,7 +149,7 @@ mod tests {
         #[case] expected: f64,
     ) {
         let input = dummy_classification_input(&ClassificationType::Multiclass).into();
-        let mut metric = PrecisionMetric::multiclass(top_k).with_class_reduction(class_reduction);
+        let mut metric = PrecisionMetric::multiclass(top_k, class_reduction);
         let _entry = metric.update(&input, &MetricMetadata::fake());
         TensorData::from([metric.value()])
             .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
@@ -203,14 +158,13 @@ mod tests {
     #[rstest]
     #[case::multilabel_micro(Micro, THRESHOLD, 5.0/8.0)]
     #[case::multilabel_macro(Macro, THRESHOLD, (2.0/3.0 + 2.0/3.0 + 0.5)/3.0)]
-    fn test_precision(
+    fn test_multilabel_precision(
         #[case] class_reduction: ClassReduction,
         #[case] threshold: f64,
         #[case] expected: f64,
     ) {
         let input = dummy_classification_input(&ClassificationType::Multilabel).into();
-        let mut metric =
-            PrecisionMetric::multilabel(threshold).with_class_reduction(class_reduction);
+        let mut metric = PrecisionMetric::multilabel(threshold, class_reduction);
         let _entry = metric.update(&input, &MetricMetadata::fake());
         TensorData::from([metric.value()])
             .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs
new file mode 100644
index 0000000000..8ce4351396
--- /dev/null
+++ b/crates/burn-train/src/metric/recall.rs
@@ -0,0 +1,173 @@
+use super::{
+    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
+    confusion_stats::{ConfusionStats, ConfusionStatsInput},
+    state::{FormatOptions, NumericMetricState},
+    Metric, MetricEntry, MetricMetadata, Numeric,
+};
+use burn_core::{
+    prelude::{Backend, Tensor},
+    tensor::cast::ToElement,
+};
+use core::marker::PhantomData;
+use std::num::NonZeroUsize;
+
+/// The Recall Metric
+#[derive(Default)]
+pub struct RecallMetric<B: Backend> {
+    state: NumericMetricState,
+    _b: PhantomData<B>,
+    config: ClassificationMetricConfig,
+}
+
+impl<B: Backend> RecallMetric<B> {
+    /// Recall metric for binary classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    #[allow(dead_code)]
+    pub fn binary(threshold: f64) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                // binary classification results are the same regardless of class_reduction
+                ..Default::default()
+            },
+            ..Default::default()
+        }
+    }
+
+    /// Recall metric for multiclass classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
+    #[allow(dead_code)]
+    pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::TopK(
+                    NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
+                ),
+                class_reduction,
+            },
+            ..Default::default()
+        }
+    }
+
+    /// Recall metric for multi-label classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    #[allow(dead_code)]
+    pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                class_reduction,
+            },
+            ..Default::default()
+        }
+    }
+
+    fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
+        use ClassReduction::{Macro, Micro};
+        let avg_tensor = match self.config.class_reduction {
+            Micro => aggregated_metric,
+            Macro => {
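+                // Classes with no actual positive samples produce NaN (0/0); drop them before averaging.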
+                if aggregated_metric.contains_nan().any().into_scalar() {
+                    let nan_mask = aggregated_metric.is_nan();
+                    aggregated_metric = aggregated_metric
+                        .clone()
+                        .select(0, nan_mask.bool_not().argwhere().squeeze(1))
+                }
+                aggregated_metric.mean()
+            }
+        };
+        avg_tensor.into_scalar().to_f64()
+    }
+}
+
+impl<B: Backend> Metric for RecallMetric<B> {
+    const NAME: &'static str = "Recall";
+    type Input = ConfusionStatsInput<B>;
+
+    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
+        let [sample_size, _] = input.predictions.dims();
+
+        let cf_stats = ConfusionStats::new(input, &self.config);
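+        // Recall = true positives / all actual positives, aggregated according to class_reduction.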
+        let metric = self.class_average(cf_stats.clone().true_positive() / cf_stats.positive());
+
+        self.state.update(
+            100.0 * metric,
+            sample_size,
+            FormatOptions::new(Self::NAME).unit("%").precision(2),
+        )
+    }
+
+    fn clear(&mut self) {
+        self.state.reset()
+    }
+}
+
+impl<B: Backend> Numeric for RecallMetric<B> {
+    fn value(&self) -> f64 {
+        self.state.value()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{
+        ClassReduction::{self, *},
+        Metric, MetricMetadata, Numeric, RecallMetric,
+    };
+    use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD};
+    use burn_core::tensor::TensorData;
+    use rstest::rstest;
+
+    #[rstest]
+    #[case::binary_macro(THRESHOLD, 0.5)]
+    fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {
+        let input = dummy_classification_input(&ClassificationType::Binary).into();
+        let mut metric = RecallMetric::binary(threshold);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+
+    #[rstest]
+    #[case::multiclass_micro_k1(Micro, 1, 3.0/5.0)]
+    #[case::multiclass_micro_k2(Micro, 2, 4.0/5.0)]
+    #[case::multiclass_macro_k1(Macro, 1, (0.5 + 1.0 + 0.5)/3.0)]
+    #[case::multiclass_macro_k2(Macro, 2, (1.0 + 1.0 + 0.5)/3.0)]
+    fn test_multiclass_recall(
+        #[case] class_reduction: ClassReduction,
+        #[case] top_k: usize,
+        #[case] expected: f64,
+    ) {
+        let input = dummy_classification_input(&ClassificationType::Multiclass).into();
+        let mut metric = RecallMetric::multiclass(top_k, class_reduction);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+
+    #[rstest]
+    #[case::multilabel_micro(Micro, THRESHOLD, 5.0/9.0)]
+    #[case::multilabel_macro(Macro, THRESHOLD, (0.5 + 1.0 + 1.0/3.0)/3.0)]
+    fn test_multilabel_recall(
+        #[case] class_reduction: ClassReduction,
+        #[case] threshold: f64,
+        #[case] expected: f64,
+    ) {
+        let input = dummy_classification_input(&ClassificationType::Multilabel).into();
+        let mut metric = RecallMetric::multilabel(threshold, class_reduction);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+}