From 0a25a6acde88bb69eb6e6ff71f51102187700524 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Wed, 10 Apr 2024 01:40:02 +0200 Subject: [PATCH] Minor changes in documentation docstrings (#276) --- direct/functionals/__init__.py | 7 ++-- direct/functionals/hfen.py | 12 +++---- direct/functionals/nmae.py | 9 +++--- direct/functionals/nmse.py | 18 ++++++----- direct/functionals/psnr.py | 39 +++++++++++++++++++--- direct/functionals/snr.py | 28 ++++++++++------ direct/functionals/ssim.py | 59 +++++++++++++++++++++++++++++----- direct/nn/vsharp/vsharp.py | 11 +++++-- docs/conf.py | 3 +- 9 files changed, 138 insertions(+), 48 deletions(-) diff --git a/direct/functionals/__init__.py b/direct/functionals/__init__.py index f87d9b16b..0736e05b4 100644 --- a/direct/functionals/__init__.py +++ b/direct/functionals/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) DIRECT Contributors -"""direct.nn.functionals module.""" +"""direct.nn.functionals module. + +This module contains functionals for the direct package as well as the loss +functions needed for training models.""" __all__ = [ "HFENL1Loss", @@ -24,7 +27,7 @@ "fastmri_ssim", "hfen_l1", "hfen_l2", - "snr", + "snr_metric", ] from direct.functionals.challenges import * diff --git a/direct/functionals/hfen.py b/direct/functionals/hfen.py index d96099fd5..feb62fa03 100644 --- a/direct/functionals/hfen.py +++ b/direct/functionals/hfen.py @@ -89,7 +89,7 @@ def _compute_padding(kernel_size: int | list[int] = 5) -> int | tuple[int, ...]: class HFENLoss(nn.Module): - """High Frequency Error Norm (HFEN) Loss as defined in _[1]. + r"""High Frequency Error Norm (HFEN) Loss as defined in _[1]. Calculates: @@ -118,7 +118,7 @@ def __init__( kernel_size: int | list[int] = 5, sigma: float | list[float] = 2.5, norm: bool = False, - ): + ) -> None: """Inits :class:`HFENLoss`. Parameters @@ -188,7 +188,7 @@ def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class HFENL1Loss(HFENLoss): - """High Frequency Error Norm (HFEN) Loss using L1Loss criterion. + r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion. Calculates: @@ -207,7 +207,7 @@ def __init__( kernel_size: int | list[int] = 15, sigma: float | list[float] = 2.5, norm: bool = False, - ): + ) -> None: """Inits :class:`HFENL1Loss`. Parameters @@ -225,7 +225,7 @@ def __init__( class HFENL2Loss(HFENLoss): - """High Frequency Error Norm (HFEN) Loss using L1Loss criterion. + r"""High Frequency Error Norm (HFEN) Loss using L1Loss criterion. Calculates: @@ -244,7 +244,7 @@ def __init__( kernel_size: int | list[int] = 15, sigma: float | list[float] = 2.5, norm: bool = False, - ): + ) -> None: """Inits :class:`HFENL2Loss`. Parameters diff --git a/direct/functionals/nmae.py b/direct/functionals/nmae.py index 7e8a3ce96..1f746e6b9 100644 --- a/direct/functionals/nmae.py +++ b/direct/functionals/nmae.py @@ -8,16 +8,17 @@ class NMAELoss(nn.Module): - """Computes the Normalized Mean Absolute Error (NMAE), i.e.: + r"""Computes the Normalized Mean Absolute Error (NMAE), i.e.: .. math:: + \frac{||u - v||_1}{||u||_1}, where :math:`u` and :math:`v` denote the target and the input. """ - def __init__(self, reduction="mean"): - """Inits :class:`NMAE` + def __init__(self, reduction="mean") -> None: + """Inits :class:`NMAELoss` Parameters ---------- @@ -29,7 +30,7 @@ def __init__(self, reduction="mean"): self.mae_loss = nn.L1Loss(reduction=reduction) def forward(self, input: torch.Tensor, target: torch.Tensor): - """Forward method of :class:`NMAE`. + """Forward method of :class:`NMAELoss`. Parameters ---------- diff --git a/direct/functionals/nmse.py b/direct/functionals/nmse.py index d36acb380..36037c845 100644 --- a/direct/functionals/nmse.py +++ b/direct/functionals/nmse.py @@ -8,16 +8,17 @@ class NMSELoss(nn.Module): - """Computes the Normalized Mean Squared Error (NMSE), i.e.: + r"""Computes the Normalized Mean Squared Error (NMSE), i.e.: .. math:: + \frac{||u - v||_2^2}{||u||_2^2}, where :math:`u` and :math:`v` denote the target and the input. """ - def __init__(self, reduction="mean"): - """Inits :class:`NMSE` + def __init__(self, reduction="mean") -> None: + """Inits :class:`NMSELoss` Parameters ---------- @@ -29,7 +30,7 @@ def __init__(self, reduction="mean"): self.mse_loss = nn.MSELoss(reduction=reduction) def forward(self, input: torch.Tensor, target: torch.Tensor): - """Forward method of :class:`NMSE`. + """Forward method of :class:`NMSELoss`. Parameters ---------- @@ -44,16 +45,17 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): class NRMSELoss(nn.Module): - """Computes the Normalized Root Mean Squared Error (NRMSE), i.e.: + r"""Computes the Normalized Root Mean Squared Error (NRMSE), i.e.: .. math:: + \frac{||u - v||_2}{||u||_2}, where :math:`u` and :math:`v` denote the target and the input. """ - def __init__(self, reduction="mean"): - """Inits :class:`NRMSE` + def __init__(self, reduction="mean") -> None: + """Inits :class:`NRMSELos` Parameters ---------- @@ -65,7 +67,7 @@ def __init__(self, reduction="mean"): self.mse_loss = nn.MSELoss(reduction=reduction) def forward(self, input: torch.Tensor, target: torch.Tensor): - """Forward method of :class:`NRMSE`. + """Forward method of :class:`NRMSELoss`. Parameters ---------- diff --git a/direct/functionals/psnr.py b/direct/functionals/psnr.py index 93f199e64..64a9a631b 100644 --- a/direct/functionals/psnr.py +++ b/direct/functionals/psnr.py @@ -1,12 +1,14 @@ # coding=utf-8 -# Copyright (c) DIRECT Contributors + +"""Peak signal-to-noise ratio (pSNR) metric for the direct package.""" + import torch import torch.nn as nn __all__ = ("batch_psnr", "PSNRLoss") -def batch_psnr(input_data, target_data, reduction="mean"): +def batch_psnr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor: """This function is a torch implementation of skimage.metrics.compare_psnr. Parameters @@ -37,10 +39,37 @@ def batch_psnr(input_data, target_data, reduction="mean"): class PSNRLoss(nn.Module): - __constants__ = ["reduction"] + """Peak signal-to-noise ratio loss function PyTorch implementation. - def __init__(self, reduction="mean"): + Parameters + ---------- + reduction : str + Batch reduction. Default: "mean". + """ + + def __init__(self, reduction: str = "mean") -> None: + """Inits :class:`PSNRLoss`. + + Parameters + ---------- + reduction : str + Batch reduction. Default: "mean". + """ + super().__init__() self.reduction = reduction - def forward(self, input_data, target_data): + def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch.Tensor: + """Performs forward pass of :class:`PSNRLoss`. + + Parameters + ---------- + input_data : torch.Tensor + Input 2D data. + target_data : torch.Tensor + Target 2D data. + + Returns + ------- + torch.Tensor + """ return batch_psnr(input_data, target_data, reduction=self.reduction) diff --git a/direct/functionals/snr.py b/direct/functionals/snr.py index fc142b9c4..280532f06 100644 --- a/direct/functionals/snr.py +++ b/direct/functionals/snr.py @@ -1,22 +1,27 @@ # Copyright (c) DIRECT Contributors -"""direct.nn.functionals.snr module.""" +"""Signal-to-noise ratio (SNR) metric for the direct package.""" import torch from torch import nn -__all__ = ("snr", "SNRLoss") +__all__ = ("snr_metric", "SNRLoss") -def snr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor: +def snr_metric(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "mean") -> torch.Tensor: """This function is a torch implementation of SNR metric for batches. - ..math:: - SNR = 10 \cdot \log_{10}\left(\frac{\sum \text{square_error}}{\sum \text{square_error_noise}}\right) + .. math:: + + SNR = 10 \\cdot \\log_{10}\\left(\\frac{\\text{square_error}}{\\text{square_error_noise}}\\right) + + + where: + + - :math:`\\text{square_error}` is the sum of squared values of the clean (target) data. + - :math:`\\text{square_error_noise}` is the sum of squared differences between the input data and + the clean (target) data. - Where: - - \text{square_error} is the sum of squared values of the clean (target) data. - - \text{square_error_noise} is the sum of squared differences between the input data and the clean (target) data. If reduction is "mean", the function returns the mean SNR value. If reduction is "sum", the function returns the sum of SNR values. @@ -33,6 +38,7 @@ def snr(input_data: torch.Tensor, target_data: torch.Tensor, reduction: str = "m ------- torch.Tensor """ + batch_size = target_data.size(0) input_view = input_data.view(batch_size, -1) target_view = target_data.view(batch_size, -1) @@ -59,7 +65,7 @@ def __init__(self, reduction: str = "mean") -> None: Parameters ---------- reduction : str - Batch reduction. Default: str. + Batch reduction. Default: "mean". """ super().__init__() self.reduction = reduction @@ -70,10 +76,12 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor) -> torch. Parameters ---------- input_data : torch.Tensor + Input 2D data. target_data : torch.Tensor + Target 2D data. Returns ------- torch.Tensor """ - return snr(input_data, target_data, reduction=self.reduction) + return snr_metric(input_data, target_data, reduction=self.reduction) diff --git a/direct/functionals/ssim.py b/direct/functionals/ssim.py index 1de051285..9cb752a7a 100644 --- a/direct/functionals/ssim.py +++ b/direct/functionals/ssim.py @@ -1,6 +1,6 @@ # Copyright (c) DIRECT Contributors -"""direct.nn.functionals.ssim module.""" +"""This module contains SSIM loss functions for the direct package.""" # Taken from: https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py @@ -19,13 +19,27 @@ class SSIMLoss(nn.Module): - """SSIM loss module. + """SSIM loss module as implemented in [1]_. + + Parameters + ---------- + win_size: int + Window size for SSIM calculation. Default: 7. + k1: float + k1 parameter for SSIM calculation. Default: 0.1. + k2: float + k2 parameter for SSIM calculation. Default: 0.03. + + References + ---------- + + .. [1] https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py - From: https://github.com/facebookresearch/fastMRI/blob/master/fastmri/losses.py """ - def __init__(self, win_size=7, k1=0.01, k2=0.03): - """ + def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None: + """Inits :class:`SSIMLoss`. + Parameters ---------- win_size: int @@ -43,6 +57,21 @@ def __init__(self, win_size=7, k1=0.01, k2=0.03): self.cov_norm = NP / (NP - 1) def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_range: torch.Tensor) -> torch.Tensor: + """Forward pass of :class:`SSIMloss`. + + Parameters + ---------- + input_data : torch.Tensor + 2D Input data. + target_data : torch.Tensor + 2D Target data. + data_range : torch.Tensor + Data range. + + Returns + ------- + torch.Tensor + """ data_range = data_range[:, None, None, None] C1 = (self.k1 * data_range) ** 2 C2 = (self.k2 * data_range) ** 2 @@ -67,10 +96,21 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang class SSIM3DLoss(nn.Module): - """SSIM loss module for 3D data.""" + """SSIM loss module for 3D data. + + Parameters + ---------- + win_size: int + Window size for SSIM calculation. Default: 7. + k1: float + k1 parameter for SSIM calculation. Default: 0.1. + k2: float + k2 parameter for SSIM calculation. Default: 0.03. + """ + + def __init__(self, win_size=7, k1=0.01, k2=0.03) -> None: + """Inits :class:`SSIM3DLoss`. - def __init__(self, win_size=7, k1=0.01, k2=0.03): - """ Parameters ---------- win_size: int @@ -90,8 +130,11 @@ def forward(self, input_data: torch.Tensor, target_data: torch.Tensor, data_rang Parameters ---------- input_data : torch.Tensor + 3D Input data. target_data : torch.Tensor + 3D Target data. data_range : torch.Tensor + Data range. Returns ------- diff --git a/direct/nn/vsharp/vsharp.py b/direct/nn/vsharp/vsharp.py index 6108bff13..fb417f99a 100644 --- a/direct/nn/vsharp/vsharp.py +++ b/direct/nn/vsharp/vsharp.py @@ -1,12 +1,17 @@ # Copyright (c) DIRECT Contributors -"""This module provides the implementation of the variable Splitting Half-quadratic ADMM algorithm for Reconstruction - of inverse-Problems (vSHARPP) model as presented in [1]_. +"""This module provides the implementation of vSHARP model. + +Most specifically, vSHARP is the variable Splitting Half-quadratic ADMM algorithm for Reconstruction +of inverse-Problems (vSHARPP) model as presented in [1]_. + References ---------- + .. [1] George Yiasemis et. al. vSHARP: variable Splitting Half-quadratic ADMM algorithm for Reconstruction -of inverse-Problems (2023). https://arxiv.org/abs/2309.09954. + of inverse-Problems (2023). https://arxiv.org/abs/2309.09954. + """ diff --git a/docs/conf.py b/docs/conf.py index feb1f000d..ead2a72a3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -158,7 +158,6 @@ "use_issues_button": True, "use_edit_page_button": True, "use_download_button": False, - "single_page": False, "use_fullscreen_button": False, "home_page_in_toc": True, } @@ -175,7 +174,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["static"] +html_static_path = [] # -- Options for HTMLHelp output ---------------------------------------