Skip to content

Commit

Permalink
Minor changes in documentation docstrings (#276)
Browse files Browse the repository at this point in the history
georgeyiasemis authored Apr 9, 2024

Verified

This commit was signed with the committer’s verified signature.
MrHadiSatrio Hadi Satrio
1 parent df99d4e commit 0a25a6a
Showing 9 changed files with 138 additions and 48 deletions.
7 changes: 5 additions & 2 deletions direct/functionals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) DIRECT Contributors

"""direct.nn.functionals module."""
"""direct.nn.functionals module.
This module contains functionals for the direct package as well as the loss
functions needed for training models."""

__all__ = [
"HFENL1Loss",
@@ -24,7 +27,7 @@
"fastmri_ssim",
"hfen_l1",
"hfen_l2",
"snr",
"snr_metric",
]

from direct.functionals.challenges import *
12 changes: 6 additions & 6 deletions direct/functionals/hfen.py
Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ def _compute_padding(kernel_size: int | list[int] = 5) -> int | tuple[int, ...]:


class HFENLoss(nn.Module):
"""High Frequency Error Norm (HFEN) Loss as defined in _[1].
r"""High Frequency Error Norm (HFEN) Loss as defined in _[1].
Calculates:
@@ -118,7 +118,7 @@ def __init__(
kernel_size: int | list[int] = 5,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENLoss`.
Parameters
@@ -188,7 +188,7 @@ def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:


class HFENL1Loss(HFENLoss):
"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
Calculates:
@@ -207,7 +207,7 @@ def __init__(
kernel_size: int | list[int] = 15,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENL1Loss`.
Parameters
@@ -225,7 +225,7 @@ def __init__(


class HFENL2Loss(HFENLoss):
"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion.
Calculates:
@@ -244,7 +244,7 @@ def __init__(
kernel_size: int | list[int] = 15,
sigma: float | list[float] = 2.5,
norm: bool = False,
):
) -> None:
"""Inits :class:`HFENL2Loss`.
Parameters
9 changes: 5 additions & 4 deletions direct/functionals/nmae.py
Original file line number Diff line number Diff line change
@@ -8,16 +8,17 @@


class NMAELoss(nn.Module):
"""Computes the Normalized Mean Absolute Error (NMAE), i.e.:
r"""Computes the Normalized Mean Absolute Error (NMAE), i.e.:
.. math::
\frac{||u - v||_1}{||u||_1},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NMAE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NMAELoss`
Parameters
----------
@@ -29,7 +30,7 @@ def __init__(self, reduction="mean"):
self.mae_loss = nn.L1Loss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NMAE`.
"""Forward method of :class:`NMAELoss`.
Parameters
----------
18 changes: 10 additions & 8 deletions direct/functionals/nmse.py
Original file line number Diff line number Diff line change
@@ -8,16 +8,17 @@


class NMSELoss(nn.Module):
"""Computes the Normalized Mean Squared Error (NMSE), i.e.:
r"""Computes the Normalized Mean Squared Error (NMSE), i.e.:
.. math::
\frac{||u - v||_2^2}{||u||_2^2},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NMSE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NMSELoss`
Parameters
----------
@@ -29,7 +30,7 @@ def __init__(self, reduction="mean"):
self.mse_loss = nn.MSELoss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NMSE`.
"""Forward method of :class:`NMSELoss`.
Parameters
----------
@@ -44,16 +45,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):


class NRMSELoss(nn.Module):
"""Computes the Normalized Root Mean Squared Error (NRMSE), i.e.:
r"""Computes the Normalized Root Mean Squared Error (NRMSE), i.e.:
.. math::
\frac{||u - v||_2}{||u||_2},
where :math:`u` and :math:`v` denote the target and the input.
"""

def __init__(self, reduction="mean"):
"""Inits :class:`NRMSE`
def __init__(self, reduction="mean") -> None:
"""Inits :class:`NRMSELos`
Parameters
----------
@@ -65,7 +67,7 @@ def __init__(self, reduction="mean"):
self.mse_loss = nn.MSELoss(reduction=reduction)

def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward method of :class:`NRMSE`.
"""Forward method of :class:`NRMSELoss`.
Parameters
----------
39 changes: 34 additions & 5 deletions direct/functionals/psnr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Peak signal-to-noise ratio (pSNR) metric for the direct package."""

import torch
import torch.nn as nn

__all__ = ("batch_psnr", "PSNRLoss")


def batch_psnr(input_data, target_data, reduction="mean"):
def batch_psnr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
"""This function is a torch implementation of skimage.metrics.compare_psnr.
Parameters
@@ -37,10 +39,37 @@ def batch_psnr(input_data, target_data, reduction="mean"):


class PSNRLoss(nn.Module):
__constants__ = ["reduction"]
"""Peak signal-to-noise ratio loss function PyTorch implementation.
def __init__(self, reduction="mean"):
Parameters
----------
reduction : str
Batch reduction. Default: "mean".
"""

def __init__(self, reduction: str = "mean") -> None:
"""Inits :class:`PSNRLoss`.
Parameters
----------
reduction : str
Batch reduction. Default: "mean".
"""
super().__init__()
self.reduction = reduction

def forward(self, input_data, target_data):
def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.Tensor:
"""Performs forward pass of :class:`PSNRLoss`.
Parameters
----------
input_data : torch.Tensor
Input 2D data.
target_data : torch.Tensor
Target 2D data.
Returns
-------
torch.Tensor
"""
return batch_psnr(input_data, target_data, reduction=self.reduction)
28 changes: 18 additions & 10 deletions direct/functionals/snr.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
# Copyright (c) DIRECT Contributors

"""direct.nn.functionals.snr module."""
"""Signal-to-noise ratio (SNR) metric for the direct package."""

import torch
from torch import nn

__all__ = ("snr", "SNRLoss")
__all__ = ("snr_metric", "SNRLoss")


def snr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
def snr_metric(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
"""This function is a torch implementation of SNR metric for batches.
..math::
SNR = 10 \cdot \log_{10}\left(\frac{\sum \text{square_error}}{\sum \text{square_error_noise}}\right)
.. math::
SNR = 10 \\cdot \\log_{10}\\left(\\frac{\\text{square_error}}{\\text{square_error_noise}}\\right)
where:
- :math:`\\text{square_error}` is the sum of squared values of the clean (target) data.
- :math:`\\text{square_error_noise}` is the sum of squared differences between the input data and
the clean (target) data.
Where:
- \text{square_error} is the sum of squared values of the clean (target) data.
- \text{square_error_noise} is the sum of squared differences between the input data and the clean (target) data.
If reduction is "mean", the function returns the mean SNR value.
If reduction is "sum", the function returns the sum of SNR values.
@@ -33,6 +38,7 @@ def snr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "m
-------
torch.Tensor
"""

batch_size = target_data.size(0)
input_view = input_data.view(batch_size, -1)
target_view = target_data.view(batch_size, -1)
@@ -59,7 +65,7 @@ def __init__(self, reduction: str = "mean") -> None:
Parameters
----------
reduction : str
Batch reduction. Default: str.
Batch reduction. Default: "mean".
"""
super().__init__()
self.reduction = reduction
@@ -70,10 +76,12 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.
Parameters
----------
input_data : torch.Tensor
Input 2D data.
target_data : torch.Tensor
Target 2D data.
Returns
-------
torch.Tensor
"""
return snr(input_data, target_data, reduction=self.reduction)
return snr_metric(input_data, target_data, reduction=self.reduction)
59 changes: 51 additions & 8 deletions direct/functionals/ssim.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) DIRECT Contributors

"""direct.nn.functionals.ssim module."""
"""This module contains SSIM loss functions for the direct package."""


# Taken from: https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py
@@ -19,13 +19,27 @@


class SSIMLoss(nn.Module):
"""SSIM loss module.
"""SSIM loss module as implemented in [1]_.
Parameters
----------
win_size: int
Window size for SSIM calculation. Default: 7.
k1: float
k1 parameter for SSIM calculation. Default: 0.1.
k2: float
k2 parameter for SSIM calculation. Default: 0.03.
References
----------
.. [1] https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py
From: https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py
"""

def __init__(self, win_size=7, k1=0.01, k2=0.03):
"""
def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None:
"""Inits :class:`SSIMLoss`.
Parameters
----------
win_size: int
@@ -43,6 +57,21 @@ def __init__(self, win_size=7, k1=0.01, k2=0.03):
self.cov_norm = NP / (NP - 1)

def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`SSIMloss`.
Parameters
----------
input_data : torch.Tensor
2D Input data.
target_data : torch.Tensor
2D Target data.
data_range : torch.Tensor
Data range.
Returns
-------
torch.Tensor
"""
data_range = data_range[:, None, None, None]
C1 = (self.k1 * data_range) ** 2
C2 = (self.k2 * data_range) ** 2
@@ -67,10 +96,21 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang


class SSIM3DLoss(nn.Module):
"""SSIM loss module for 3D data."""
"""SSIM loss module for 3D data.
Parameters
----------
win_size: int
Window size for SSIM calculation. Default: 7.
k1: float
k1 parameter for SSIM calculation. Default: 0.1.
k2: float
k2 parameter for SSIM calculation. Default: 0.03.
"""

def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None:
"""Inits :class:`SSIM3DLoss`.
def __init__(self, win_size=7, k1=0.01, k2=0.03):
"""
Parameters
----------
win_size: int
@@ -90,8 +130,11 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang
Parameters
----------
input_data : torch.Tensor
3D Input data.
target_data : torch.Tensor
3D Target data.
data_range : torch.Tensor
Data range.
Returns
-------
11 changes: 8 additions & 3 deletions direct/nn/vsharp/vsharp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# Copyright (c) DIRECT Contributors

"""This module provides the implementation of the variable Splitting Half-quadratic ADMM algorithm for Reconstruction
of inverse-Problems (vSHARPP) model as presented in [1]_.
"""This module provides the implementation of vSHARP model.
Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction
of inverse-Problems (vSHARPP) model as presented in [1]_.
References
----------
.. [1] George Yiasemis et. al. vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction
of inverse-Problems (2023). https://arxiv.org/abs/2309.09954.
of inverse-Problems (2023). https://arxiv.org/abs/2309.09954.
"""


3 changes: 1 addition & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -158,7 +158,6 @@
"use_issues_button": True,
"use_edit_page_button": True,
"use_download_button": False,
"single_page": False,
"use_fullscreen_button": False,
"home_page_in_toc": True,
}
@@ -175,7 +174,7 @@
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["static"]
html_static_path = []

# -- Options for HTMLHelp output ---------------------------------------

0 comments on commit 0a25a6a

Please sign in to comment.