Skip to content

Commit

Permalink
Add 3d unet engine
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed May 29, 2024
1 parent 37ae137 commit 4182db7
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions direct/nn/unet/unet_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,92 @@ def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]:
output_kspace = None

return output_image, output_kspace


class Unet3dEngine(MRIModelEngine):
"""Unet3d Model Engine.
Parameters
----------
cfg: BaseConfig
Configuration file.
model: nn.Module
Model.
device: str
Device. Can be "cuda:{idx}" or "cpu".
forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The forward operator. Default: None.
backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The backward operator. Default: None.
mixed_precision: bool
Use mixed precision. Default: False.
**models: nn.Module
Additional models.
"""

def __init__(
self,
cfg: BaseConfig,
model: nn.Module,
device: str,
forward_operator: Optional[Callable] = None,
backward_operator: Optional[Callable] = None,
mixed_precision: bool = False,
**models: nn.Module,
):
"""Inits :class:`Unet3dEngine`.
Parameters
----------
cfg: BaseConfig
Configuration file.
model: nn.Module
Model.
device: str
Device. Can be "cuda:{idx}" or "cpu".
forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The forward operator. Default: None.
backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The backward operator. Default: None.
mixed_precision: bool
Use mixed precision. Default: False.
**models: nn.Module
Additional models.
"""
super().__init__(
cfg,
model,
device,
forward_operator=forward_operator,
backward_operator=backward_operator,
mixed_precision=mixed_precision,
**models,
)

self._spatial_dims = (3, 4)

def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, None]:
"""Forward function for :class:`Unet3dEngine`.
Parameters
----------
data : dict[str, Any]
Input data dictionary containing the following keys: "masked_kspace" and "sensitivity_map"
if image initialization is "sense".
Returns
-------
tuple[torch.Tensor, None]
Prediction of image and None for k-space.
"""

sensitity_map = (
data["sensitivity_map"] if self.cfg.model.image_initialization == "sense" else None # type: ignore
)

output_image = self.model(masked_kspace=data["masked_kspace"], sensitivity_map=sensitity_map)
output_image = T.modulus(output_image)

output_kspace = None

return output_image, output_kspace

0 comments on commit 4182db7

Please sign in to comment.