Skip to content

Commit

Permalink
Fix math in docstrings compiling issues
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 9, 2024
1 parent fa57f79 commit 2a12f82
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 31 deletions.
12 changes: 6 additions & 6 deletions direct/functionals/hfen.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _compute_padding(kernel_size: int | list[int] = 5) -> int | tuple[int, ...]:


class HFENLoss(nn.Module):
"""High Frequency Error Norm (HFEN) Loss as defined in _[1].
r"""High Frequency Error Norm (HFEN) Loss as defined in _[1].
Calculates:
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
kernel_size: int | list[int] = 5,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENLoss`.
Parameters
Expand Down Expand Up @@ -188,7 +188,7 @@ def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


class HFENL1Loss(HFENLoss):
"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
Calculates:
Expand All @@ -207,7 +207,7 @@ def __init__(
kernel_size: int | list[int] = 15,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENL1Loss`.
Parameters
Expand All @@ -225,7 +225,7 @@ def __init__(


class HFENL2Loss(HFENLoss):
"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
Calculates:
Expand All @@ -244,7 +244,7 @@ def __init__(
kernel_size: int | list[int] = 15,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENL2Loss`.
Parameters
Expand Down
9 changes: 5 additions & 4 deletions direct/functionals/nmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@


class NMAELoss(nn.Module):
"""Computes the Normalized Mean Absolute Error (NMAE), i.e.:
r"""Computes the Normalized Mean Absolute Error (NMAE), i.e.:
.. math::
\frac{||u - v||_1}{||u||_1},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NMAE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NMAELoss`
Parameters
----------
Expand All @@ -29,7 +30,7 @@ def __init__(self, reduction="mean"):
self.mae_loss = nn.L1Loss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NMAE`.
"""Forward method of :class:`NMAELoss`.
Parameters
----------
Expand Down
18 changes: 10 additions & 8 deletions direct/functionals/nmse.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@


class NMSELoss(nn.Module):
"""Computes the Normalized Mean Squared Error (NMSE), i.e.:
r"""Computes the Normalized Mean Squared Error (NMSE), i.e.:
.. math::
\frac{||u - v||_2^2}{||u||_2^2},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NMSE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NMSELoss`
Parameters
----------
Expand All @@ -29,7 +30,7 @@ def __init__(self, reduction="mean"):
self.mse_loss = nn.MSELoss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NMSE`.
"""Forward method of :class:`NMSELoss`.
Parameters
----------
Expand All @@ -44,16 +45,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):


class NRMSELoss(nn.Module):
"""Computes the Normalized Root Mean Squared Error (NRMSE), i.e.:
r"""Computes the Normalized Root Mean Squared Error (NRMSE), i.e.:
.. math::
\frac{||u - v||_2}{||u||_2},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NRMSE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NRMSELos`
Parameters
----------
Expand All @@ -65,7 +67,7 @@ def __init__(self, reduction="mean"):
self.mse_loss = nn.MSELoss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NRMSE`.
"""Forward method of :class:`NRMSELoss`.
Parameters
----------
Expand Down
39 changes: 34 additions & 5 deletions direct/functionals/psnr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Peak signal-to-noise ratio (pSNR) metric for the direct package."""

import torch
import torch.nn as nn

__all__ = ("batch_psnr", "PSNRLoss")


def batch_psnr(input_data, target_data, reduction="mean"):
def batch_psnr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
"""This function is a torch implementation of skimage.metrics.compare_psnr.
Parameters
Expand Down Expand Up @@ -37,10 +39,37 @@ def batch_psnr(input_data, target_data, reduction="mean"):


class PSNRLoss(nn.Module):
__constants__ = ["reduction"]
"""Peak signal-to-noise ratio loss function PyTorch implementation.
def __init__(self, reduction="mean"):
Parameters
----------
reduction : str
Batch reduction. Default: "mean".
"""

def __init__(self, reduction: str = "mean") -> None:
"""Inits :class:`PSNRLoss`.
Parameters
----------
reduction : str
Batch reduction. Default: "mean".
"""
super().__init__()
self.reduction = reduction

def forward(self, input_data, target_data):
def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.Tensor:
"""Performs forward pass of :class:`PSNRLoss`.
Parameters
----------
input_data : torch.Tensor
Input 2D data.
target_data : torch.Tensor
Target 2D data.
Returns
-------
torch.Tensor
"""
return batch_psnr(input_data, target_data, reduction=self.reduction)
59 changes: 51 additions & 8 deletions direct/functionals/ssim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) DIRECT Contributors

"""direct.nn.functionals.ssim module."""
"""This module contains SSIM loss functions for the direct package."""


# Taken from: https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py
Expand All @@ -19,13 +19,27 @@


class SSIMLoss(nn.Module):
"""SSIM loss module.
"""SSIM loss module as implemented in [1]_.
Parameters
----------
win_size: int
Window size for SSIM calculation. Default: 7.
k1: float
k1 parameter for SSIM calculation. Default: 0.1.
k2: float
k2 parameter for SSIM calculation. Default: 0.03.
References
----------
.. [1] https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py
From: https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py
"""

def __init__(self, win_size=7, k1=0.01, k2=0.03):
"""
def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None:
"""Inits :class:`SSIMLoss`.
Parameters
----------
win_size: int
Expand All @@ -43,6 +57,21 @@ def __init__(self, win_size=7, k1=0.01, k2=0.03):
self.cov_norm = NP / (NP - 1)

def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`SSIMloss`.
Parameters
----------
input_data : torch.Tensor
2D Input data.
target_data : torch.Tensor
2D Target data.
data_range : torch.Tensor
Data range.
Returns
-------
torch.Tensor
"""
data_range = data_range[:, None, None, None]
C1 = (self.k1 * data_range) ** 2
C2 = (self.k2 * data_range) ** 2
Expand All @@ -67,10 +96,21 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang


class SSIM3DLoss(nn.Module):
"""SSIM loss module for 3D data."""
"""SSIM loss module for 3D data.
Parameters
----------
win_size: int
Window size for SSIM calculation. Default: 7.
k1: float
k1 parameter for SSIM calculation. Default: 0.1.
k2: float
k2 parameter for SSIM calculation. Default: 0.03.
"""

def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None:
"""Inits :class:`SSIM3DLoss`.
def __init__(self, win_size=7, k1=0.01, k2=0.03):
"""
Parameters
----------
win_size: int
Expand All @@ -90,8 +130,11 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang
Parameters
----------
input_data : torch.Tensor
3D Input data.
target_data : torch.Tensor
3D Target data.
data_range : torch.Tensor
Data range.
Returns
-------
Expand Down

0 comments on commit 2a12f82

Please sign in to comment.