Skip to content

Commit

Permalink
Loss Enum & quality fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Mar 22, 2024
1 parent c49990b commit 0760600
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 16 deletions.
25 changes: 25 additions & 0 deletions direct/functionals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@

"""direct.nn.functionals module."""

__all__ = [
"HFENL1Loss",
"HFENL2Loss",
"HFENLoss",
"NMAELoss",
"NMSELoss",
"NRMSELoss",
"PSNRLoss",
"SNRLoss",
"SSIM3DLoss",
"SSIMLoss",
"SobelGradL1Loss",
"SobelGradL2Loss",
"batch_psnr",
"calgary_campinas_psnr",
"calgary_campinas_ssim",
"calgary_campinas_vif",
"fastmri_nmse",
"fastmri_psnr",
"fastmri_ssim",
"hfen_l1",
"hfen_l2",
"snr",
]

from direct.functionals.challenges import *
from direct.functionals.grad import *
from direct.functionals.hfen import *
Expand Down
33 changes: 17 additions & 16 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import direct.functionals as D
from direct.config import BaseConfig
from direct.engine import DoIterationOutput, Engine
from direct.nn.types import LossFunType
from direct.types import TensorOrNone
from direct.utils import (
communication,
Expand Down Expand Up @@ -584,33 +585,33 @@ def hfen_l2_norm_loss(
loss_dict = {}
for curr_loss in self.cfg.training.loss.losses: # type: ignore
loss_fn = curr_loss.function
if loss_fn in ["l1_loss", "kspace_l1_loss"]:
if loss_fn in [LossFunType.L1_LOSS, LossFunType.KSPACE_L1_LOSS]:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l1_loss)
elif loss_fn in ["l2_loss", "kspace_l2_loss"]:
elif loss_fn in [LossFunType.L2_LOSS, LossFunType.KSPACE_L2_LOSS]:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, l2_loss)
elif loss_fn == "ssim_loss":
elif loss_fn == LossFunType.SSIM_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, ssim_loss)
elif loss_fn == "grad_l1_loss":
elif loss_fn == LossFunType.GRAD_L1_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l1_loss)
elif loss_fn == "grad_l2_loss":
elif loss_fn == LossFunType.GRAD_L2_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, grad_l2_loss)
elif loss_fn in ["nmse_loss", "kspace_nmse_loss"]:
elif loss_fn in [LossFunType.NMSE_LOSS, LossFunType.KSPACE_NMSE_LOSS]:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmse_loss)
elif loss_fn in ["nrmse_loss", "kspace_nrmse_loss"]:
elif loss_fn in [LossFunType.NRMSE_LOSS, LossFunType.KSPACE_NRMSE_LOSS]:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nrmse_loss)
elif loss_fn in ["nmae_loss", "kspace_nmae_loss"]:
elif loss_fn in [LossFunType.NMAE_LOSS, LossFunType.KSPACE_NMAE_LOSS]:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, nmae_loss)
elif loss_fn in ["snr_loss", "psnr_loss"]:
loss_dict[loss_fn] = multiply_function(
curr_loss.multiplier, (snr_loss if loss_fn == "snr_loss" else psnr_loss)
)
elif loss_fn == "hfen_l1_loss":
elif loss_fn == LossFunType.SNR_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, snr_loss)
elif loss_fn == LossFunType.PSNR_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, snr_loss)
elif loss_fn == LossFunType.HFEN_L1_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_loss)
elif loss_fn == "hfen_l2_loss":
elif loss_fn == LossFunType.HFEN_L2_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_loss)
elif loss_fn == "hfen_l1_norm_loss":
elif loss_fn == LossFunType.HFEN_L1_NORM_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l1_norm_loss)
elif loss_fn == "hfen_l2_norm_loss":
elif loss_fn == LossFunType.HFEN_L2_NORM_LOSS:
loss_dict[loss_fn] = multiply_function(curr_loss.multiplier, hfen_l2_norm_loss)
else:
raise ValueError(f"{loss_fn} not permissible.")
Expand Down
22 changes: 22 additions & 0 deletions direct/nn/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,25 @@ class InitType(DirectEnum):
sense = "sense"
zero_filled = "zero_filled"
input_image = "input_image"


class LossFunType(DirectEnum):
L1_LOSS = "l1_loss"
KSPACE_L1_LOSS = "kspace_l1_loss"
L2_LOSS = "l2_loss"
KSPACE_L2_LOSS = "kspace_l2_loss"
SSIM_LOSS = "ssim_loss"
GRAD_L1_LOSS = "grad_l1_loss"
GRAD_L2_LOSS = "grad_l2_loss"
NMSE_LOSS = "nmse_loss"
KSPACE_NMSE_LOSS = "kspace_nmse_loss"
NRMSE_LOSS = "nrmse_loss"
KSPACE_NRMSE_LOSS = "kspace_nrmse_loss"
NMAE_LOSS = "nmae_loss"
KSPACE_NMAE_LOSS = "kspace_nmae_loss"
SNR_LOSS = "snr_loss"
PSNR_LOSS = "psnr_loss"
HFEN_L1_LOSS = "hfen_l1_loss"
HFEN_L2_LOSS = "hfen_l2_loss"
HFEN_L1_NORM_LOSS = "hfen_l1_norm_loss"
HFEN_L2_NORM_LOSS = "hfen_l2_norm_loss"

0 comments on commit 0760600

Please sign in to comment.