Skip to content

Commit

Permalink
Add missing expected tensor shapes in monotonic regularization loss d…
Browse files Browse the repository at this point in the history
…ocstrings (#183)
  • Loading branch information
nathanpainchaud authored Nov 8, 2023
1 parent af24988 commit 0c1ab44
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions vital/metrics/train/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def monotonic_regularization_loss(input: Tensor, target: Tensor, delta: float) -
(link to the paper: https://arxiv.org/pdf/2004.05485.pdf)
Args:
input: Input values to regularize so that they have a monotonic relationship with the `target` values.
target: Values used to determine the target monotonic ordering of the values.
input: (N, [1]), Input values to regularize so that they have a monotonic relationship with the `target` values.
target: (N, [1]), Values used to determine the target monotonic ordering of the values.
delta: Hyperparameter that decides the spread of the posterior distribution.
Returns:
Expand Down
5 changes: 3 additions & 2 deletions vital/metrics/train/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
"""Actual metric calculation.
Args:
input: Input values to regularize so that they have a monotonic relationship with the `target` values.
target: Values used to determine the target monotonic ordering of the values.
input: (N, [1]), Input values to regularize so that they have a monotonic relationship with the `target`
values.
target: (N, [1]), Values used to determine the target monotonic ordering of the values.
Returns:
(1,), Calculated monotonic regularization loss.
Expand Down

0 comments on commit 0c1ab44

Please sign in to comment.