From 4182db7a8959f77abe19af4e4e2e4a32f31e056a Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Wed, 29 May 2024 22:06:51 +0200 Subject: [PATCH] Add 3d unet engine --- direct/nn/unet/unet_engine.py | 89 +++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/direct/nn/unet/unet_engine.py b/direct/nn/unet/unet_engine.py index 3e9840e9..3dcaca52 100644 --- a/direct/nn/unet/unet_engine.py +++ b/direct/nn/unet/unet_engine.py @@ -299,3 +299,92 @@ def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]: output_kspace = None return output_image, output_kspace + + +class Unet3dEngine(MRIModelEngine): + """Unet3d Model Engine. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`Unet3dEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + self._spatial_dims = (3, 4) + + def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]: + """Forward function for :class:`Unet3dEngine`. + + Parameters + ---------- + data : dict[str, Any] + Input data dictionary containing the following keys: "masked_kspace" and "sensitivity_map" + if image initialization is "sense". + + Returns + ------- + tuple[torch.Tensor, None] + Prediction of image and None for k-space. + """ + + sensitity_map = ( + data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None # type: ignore + ) + + output_image = self.model(masked_kspace=data["masked_kspace"], sensitivity_map=sensitity_map) + output_image = T.modulus(output_image) + + output_kspace = None + + return output_image, output_kspace