From 0760600b6a7e3f8b840e4c1049c1b45af5780b6c Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Fri, 22 Mar 2024 17:14:18 +0100 Subject: [PATCH] Loss Enum & quality fixes --- direct/functionals/__init__.py | 25 +++++++++++++++++++++++++ direct/nn/mri_models.py | 33 +++++++++++++++++---------------- direct/nn/types.py | 22 ++++++++++++++++++++++ 3 files changed, 64 insertions(+), 16 deletions(-) diff --git a/direct/functionals/__init__.py b/direct/functionals/__init__.py index 948d2319..f87d9b16 100644 --- a/direct/functionals/__init__.py +++ b/direct/functionals/__init__.py @@ -2,6 +2,31 @@ """direct.nn.functionals module.""" +__all__ = [ + "HFENL1Loss", + "HFENL2Loss", + "HFENLoss", + "NMAELoss", + "NMSELoss", + "NRMSELoss", + "PSNRLoss", + "SNRLoss", + "SSIM3DLoss", + "SSIMLoss", + "SobelGradL1Loss", + "SobelGradL2Loss", + "batch_psnr", + "calgary_campinas_psnr", + "calgary_campinas_ssim", + "calgary_campinas_vif", + "fastmri_nmse", + "fastmri_psnr", + "fastmri_ssim", + "hfen_l1", + "hfen_l2", + "snr", +] + from direct.functionals.challenges import * from direct.functionals.grad import * from direct.functionals.hfen import * diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index c506d64a..83d9dc94 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -22,6 +22,7 @@ import direct.functionals as D from direct.config import BaseConfig from direct.engine import DoIterationOutput, Engine +from direct.nn.types import LossFunType from direct.types import TensorOrNone from direct.utils import ( communication, @@ -584,33 +585,33 @@ def hfen_l2_norm_loss( loss_dict = {} for curr_loss in self.cfg.training.loss.losses: # type: ignore loss_fn = curr_loss.function - if loss_fn in ["l1_loss", "kspace_l1_loss"]: + if loss_fn in [LossFunType.L1_LOSS, LossFunType.KSPACE_L1_LOSS]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss) - elif loss_fn in ["l2_loss", "kspace_l2_loss"]: + elif loss_fn in [LossFunType.L2_LOSS, LossFunType.KSPACE_L2_LOSS]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss) - elif loss_fn == "ssim_loss": + elif loss_fn == LossFunType.SSIM_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss) - elif loss_fn == "grad_l1_loss": + elif loss_fn == LossFunType.GRAD_L1_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l1_loss) - elif loss_fn == "grad_l2_loss": + elif loss_fn == LossFunType.GRAD_L2_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l2_loss) - elif loss_fn in ["nmse_loss", "kspace_nmse_loss"]: + elif loss_fn in [LossFunType.NMSE_LOSS, LossFunType.KSPACE_NMSE_LOSS]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmse_loss) - elif loss_fn in ["nrmse_loss", "kspace_nrmse_loss"]: + elif loss_fn in [LossFunType.NRMSE_LOSS, LossFunType.KSPACE_NRMSE_LOSS]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nrmse_loss) - elif loss_fn in ["nmae_loss", "kspace_nmae_loss"]: + elif loss_fn in [LossFunType.NMAE_LOSS, LossFunType.KSPACE_NMAE_LOSS]: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmae_loss) - elif loss_fn in ["snr_loss", "psnr_loss"]: - loss_dict[loss_fn] = multiply_function( - curr_loss.multiplier, (snr_loss if loss_fn == "snr_loss" else psnr_loss) - ) - elif loss_fn == "hfen_l1_loss": + elif loss_fn == LossFunType.SNR_LOSS: + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, snr_loss) + elif loss_fn == LossFunType.PSNR_LOSS: + loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, snr_loss) + elif loss_fn == LossFunType.HFEN_L1_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_loss) - elif loss_fn == "hfen_l2_loss": + elif loss_fn == LossFunType.HFEN_L2_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_loss) - elif loss_fn == "hfen_l1_norm_loss": + elif loss_fn == LossFunType.HFEN_L1_NORM_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_norm_loss) - elif loss_fn == "hfen_l2_norm_loss": + elif loss_fn == LossFunType.HFEN_L2_NORM_LOSS: loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_norm_loss) else: raise ValueError(f"{loss_fn} not permissible.") diff --git a/direct/nn/types.py b/direct/nn/types.py index 8eaf0d90..21976a73 100644 --- a/direct/nn/types.py +++ b/direct/nn/types.py @@ -22,3 +22,25 @@ class InitType(DirectEnum): sense = "sense" zero_filled = "zero_filled" input_image = "input_image" + + +class LossFunType(DirectEnum): + L1_LOSS = "l1_loss" + KSPACE_L1_LOSS = "kspace_l1_loss" + L2_LOSS = "l2_loss" + KSPACE_L2_LOSS = "kspace_l2_loss" + SSIM_LOSS = "ssim_loss" + GRAD_L1_LOSS = "grad_l1_loss" + GRAD_L2_LOSS = "grad_l2_loss" + NMSE_LOSS = "nmse_loss" + KSPACE_NMSE_LOSS = "kspace_nmse_loss" + NRMSE_LOSS = "nrmse_loss" + KSPACE_NRMSE_LOSS = "kspace_nrmse_loss" + NMAE_LOSS = "nmae_loss" + KSPACE_NMAE_LOSS = "kspace_nmae_loss" + SNR_LOSS = "snr_loss" + PSNR_LOSS = "psnr_loss" + HFEN_L1_LOSS = "hfen_l1_loss" + HFEN_L2_LOSS = "hfen_l2_loss" + HFEN_L1_NORM_LOSS = "hfen_l1_norm_loss" + HFEN_L2_NORM_LOSS = "hfen_l2_norm_loss"