From 0c1ab44b84cc97960303b69075bf2eb2b80bfa0c Mon Sep 17 00:00:00 2001 From: Nathan Painchaud <23144457+nathanpainchaud@users.noreply.github.com> Date: Wed, 8 Nov 2023 23:03:02 +0100 Subject: [PATCH] Add missing expected tensor shapes in monotonic regularization loss docstrings (#183) --- vital/metrics/train/functional.py | 4 ++-- vital/metrics/train/metric.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vital/metrics/train/functional.py b/vital/metrics/train/functional.py index 1c270479..0ed722ba 100644 --- a/vital/metrics/train/functional.py +++ b/vital/metrics/train/functional.py @@ -121,8 +121,8 @@ def monotonic_regularization_loss(input: Tensor, target: Tensor, delta: float) - (link to the paper: https://arxiv.org/pdf/2004.05485.pdf) Args: - input: Input values to regularize so that they have a monotonic relationship with the `target` values. - target: Values used to determine the target monotonic ordering of the values. + input: (N, [1]), Input values to regularize so that they have a monotonic relationship with the `target` values. + target: (N, [1]), Values used to determine the target monotonic ordering of the values. delta: Hyperparameter that decides the spread of the posterior distribution. Returns: diff --git a/vital/metrics/train/metric.py b/vital/metrics/train/metric.py index 568c5b3d..c3f12188 100644 --- a/vital/metrics/train/metric.py +++ b/vital/metrics/train/metric.py @@ -80,8 +80,9 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor: """Actual metric calculation. Args: - input: Input values to regularize so that they have a monotonic relationship with the `target` values. - target: Values used to determine the target monotonic ordering of the values. + input: (N, [1]), Input values to regularize so that they have a monotonic relationship with the `target` + values. + target: (N, [1]), Values used to determine the target monotonic ordering of the values. Returns: (1,), Calculated monotonic regularization loss.