Commit
3D engines and models
georgeyiasemis committed May 29, 2024
1 parent 1e69351 commit 6023b88
Showing 10 changed files with 968 additions and 1 deletion.
10 changes: 10 additions & 0 deletions direct/nn/unet/config.py
@@ -34,6 +34,16 @@ class Unet2dConfig(ModelConfig):
image_initialization: InitType = InitType.ZERO_FILLED


@dataclass
class Unet3dConfig(ModelConfig):
num_filters: int = 16
num_pool_layers: int = 4
dropout_probability: float = 0.0
skip_connection: bool = False
normalized: bool = False
image_initialization: InitType = InitType.ZERO_FILLED
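A minimal sketch of constructing this config, with illustrative override values; the fields mirror the keyword arguments of :class:`Unet3d` in direct/nn/unet/unet_3d.py, so a config can be forwarded directly to the model:

```python
# Illustrative only: override a few defaults of Unet3dConfig.
from direct.nn.unet.config import Unet3dConfig

config = Unet3dConfig(num_filters=32, num_pool_layers=3, dropout_probability=0.1)
assert config.skip_connection is False  # untouched fields keep their defaults
```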


@dataclass
class UnetModel3dConfig(ModelConfig):
in_channels: int = 2
140 changes: 140 additions & 0 deletions direct/nn/unet/unet_3d.py
@@ -5,11 +5,15 @@
from __future__ import annotations

import math
from typing import Callable, Optional

import torch
from torch import nn
from torch.nn import functional as F

from direct.data import transforms as T
from direct.nn.types import InitType


class ConvBlock3D(nn.Module):
"""3D U-Net convolutional block."""
@@ -403,6 +407,142 @@ def forward(self, input_data: torch.Tensor) -> torch.Tensor:
return output


class Unet3d(nn.Module):
"""PyTorch implementation of a U-Net model for MRI Reconstruction in 3D."""

def __init__(
self,
forward_operator: Callable,
backward_operator: Callable,
num_filters: int,
num_pool_layers: int,
dropout_probability: float,
skip_connection: bool = False,
normalized: bool = False,
image_initialization: InitType = InitType.ZERO_FILLED,
**kwargs,
):
"""Inits :class:`Unet3d`.
Parameters
----------
forward_operator: Callable
Forward Operator.
backward_operator: Callable
Backward Operator.
num_filters: int
Number of first layer filters.
num_pool_layers: int
Number of pooling layers.
dropout_probability: float
Dropout probability.
skip_connection: bool
If True, skip connection is used for the output. Default: False.
normalized: bool
If True, Normalized Unet is used. Default: False.
image_initialization: InitType
Type of image initialization. Default: InitType.ZERO_FILLED.
kwargs: dict
"""
super().__init__()
extra_keys = kwargs.keys()
for extra_key in extra_keys:
if extra_key not in [
"sensitivity_map_model",
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")
self.unet: nn.Module
if normalized:
self.unet = NormUnetModel3d(
in_channels=2,
out_channels=2,
num_filters=num_filters,
num_pool_layers=num_pool_layers,
dropout_probability=dropout_probability,
)
else:
self.unet = UnetModel3d(
in_channels=2,
out_channels=2,
num_filters=num_filters,
num_pool_layers=num_pool_layers,
dropout_probability=dropout_probability,
)
self.forward_operator = forward_operator
self.backward_operator = backward_operator
self.skip_connection = skip_connection
self.image_initialization = image_initialization
self._coil_dim = 1
self._spatial_dims = (3, 4)

def compute_sense_init(self, kspace: torch.Tensor, sensitivity_map: torch.Tensor) -> torch.Tensor:
r"""Computes sense initialization :math:`x_{\text{SENSE}}`:
.. math::
x_{\text{SENSE}} = \sum_{k=1}^{n_c} {S^{k}}^* \times y^k
where :math:`y^k` denotes the data from coil :math:`k`.
Parameters
----------
kspace: torch.Tensor
k-space of shape (N, coil, slice/time, height, width, complex=2).
sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).
Returns
-------
input_image: torch.Tensor
Sense initialization :math:`x_{\text{SENSE}}`.
"""
input_image = T.complex_multiplication(
T.conjugate(sensitivity_map),
self.backward_operator(kspace, dim=self._spatial_dims),
)
input_image = input_image.sum(self._coil_dim)
return input_image

def forward(
self,
masked_kspace: torch.Tensor,
sensitivity_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Computes forward pass of Unet2d.
Parameters
----------
masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice/time, height, width, complex=2).
sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2). Default: None.
Returns
-------
output: torch.Tensor
Output image of shape (N, slice/time, height, width, complex=2).
"""
if self.image_initialization == InitType.SENSE:
if sensitivity_map is None:
raise ValueError("Expected sensitivity_map not to be None with InitType.SENSE image_initialization.")
input_image = self.compute_sense_init(
kspace=masked_kspace,
sensitivity_map=sensitivity_map,
)
elif self.image_initialization == InitType.ZERO_FILLED:
input_image = self.backward_operator(masked_kspace, dim=self._spatial_dims).sum(self._coil_dim)
else:
raise ValueError(
f"Unknown image_initialization. Expected InitType.ZERO_FILLED or InitType.SENSE. "
f"Got {self.image_initialization}."
)

output = self.unet(input_image.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
if self.skip_connection:
output += input_image
return output
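A hedged usage sketch of :class:`Unet3d` on dummy data. The toy FFT operators below only mimic the `(tensor, dim=...)` call signature the model expects on real tensors with a trailing complex=2 dimension; in DIRECT itself the centered FFT transforms would be passed instead.

```python
import torch

from direct.nn.types import InitType
from direct.nn.unet.unet_3d import Unet3d


def toy_fft(x: torch.Tensor, dim=(3, 4)) -> torch.Tensor:
    # Real view (..., 2) -> complex FFT over `dim` -> back to a real view.
    return torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x.contiguous()), dim=dim, norm="ortho"))


def toy_ifft(x: torch.Tensor, dim=(3, 4)) -> torch.Tensor:
    return torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x.contiguous()), dim=dim, norm="ortho"))


model = Unet3d(
    forward_operator=toy_fft,
    backward_operator=toy_ifft,
    num_filters=8,
    num_pool_layers=2,
    dropout_probability=0.0,
    image_initialization=InitType.ZERO_FILLED,
)
masked_kspace = torch.randn(1, 4, 8, 32, 32, 2)  # (N, coil, slice, height, width, complex=2)
output = model(masked_kspace)
print(output.shape)  # torch.Size([1, 8, 32, 32, 2]) -- coil dim reduced by the zero-filled init
```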


def pad_to_pow_of_2(inp: torch.Tensor, k: int) -> tuple[torch.Tensor, list[int]]:
"""Pads the input tensor along the spatial dimensions (depth, height, width) to the nearest power of 2.
8 changes: 8 additions & 0 deletions direct/nn/varnet/config.py
@@ -11,3 +11,11 @@ class EndToEndVarNetConfig(ModelConfig):
regularizer_num_filters: int = 18
regularizer_num_pull_layers: int = 4
regularizer_dropout: float = 0.0


@dataclass
class EndToEndVarNet3DConfig(ModelConfig):
num_layers: int = 8
regularizer_num_filters: int = 18
regularizer_num_pull_layers: int = 4
regularizer_dropout: float = 0.0
171 changes: 171 additions & 0 deletions direct/nn/varnet/varnet.py
@@ -8,6 +8,7 @@

from direct.data.transforms import expand_operator, reduce_operator
from direct.nn.unet import UnetModel2d
from direct.nn.unet.unet_3d import UnetModel3d


class EndToEndVarNet(nn.Module):
@@ -178,3 +179,173 @@ def forward(
dim=self._complex_dim,
)
return current_kspace - self.learning_rate * kspace_error + regularization_term


class EndToEndVarNet3D(nn.Module):
"""End-to-End Variational Network based on [1]_ extended to 3D.
References
----------
.. [1] Sriram, Anuroop, et al. “End-to-End Variational Networks for Accelerated MRI Reconstruction.”
ArXiv:2004.06688 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.06688.
"""

def __init__(
self,
forward_operator: Callable,
backward_operator: Callable,
num_layers: int,
regularizer_num_filters: int = 18,
regularizer_num_pull_layers: int = 4,
regularizer_dropout: float = 0.0,
in_channels: int = 2,
**kwargs,
):
"""Inits :class:`EndToEndVarNet`.
Parameters
----------
forward_operator: Callable
Forward Operator.
backward_operator: Callable
Backward Operator.
num_layers: int
Number of cascades.
regularizer_num_filters: int
Regularizer model number of filters.
regularizer_num_pull_layers: int
Regularizer model number of pulling layers.
regularizer_dropout: float
Regularizer model dropout probability.
"""
super().__init__()
extra_keys = kwargs.keys()
for extra_key in extra_keys:
if extra_key not in [
"model_name",
]:
raise ValueError(f"{type(self).__name__} got key `{extra_key}` which is not supported.")

self.layers_list = nn.ModuleList()

for _ in range(num_layers):
self.layers_list.append(
EndToEndVarNet3DBlock(
forward_operator=forward_operator,
backward_operator=backward_operator,
regularizer_model=UnetModel3d(
in_channels=in_channels,
out_channels=in_channels,
num_filters=regularizer_num_filters,
num_pool_layers=regularizer_num_pull_layers,
dropout_probability=regularizer_dropout,
),
)
)

def forward(
self, masked_kspace: torch.Tensor, sampling_mask: torch.Tensor, sensitivity_map: torch.Tensor
) -> torch.Tensor:
"""Performs the forward pass of :class:`EndToEndVarNet`.
Parameters
----------
masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice/time, height, width, complex=2).
sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1).
sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).
Returns
-------
kspace_prediction: torch.Tensor
K-space prediction of shape (N, coil, slice/time, height, width, complex=2).
"""

kspace_prediction = masked_kspace.clone()
for layer in self.layers_list:
kspace_prediction = layer(kspace_prediction, masked_kspace, sampling_mask, sensitivity_map)
return kspace_prediction
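A minimal sketch of running the 3D cascade on dummy inputs, using toy FFT operators of the same `(tensor, dim=...)` form as in the :class:`Unet3d` sketch above (all values illustrative):

```python
import torch

from direct.nn.varnet.varnet import EndToEndVarNet3D


def toy_fft(x, dim=(3, 4)):
    return torch.view_as_real(torch.fft.fftn(torch.view_as_complex(x.contiguous()), dim=dim, norm="ortho"))


def toy_ifft(x, dim=(3, 4)):
    return torch.view_as_real(torch.fft.ifftn(torch.view_as_complex(x.contiguous()), dim=dim, norm="ortho"))


model = EndToEndVarNet3D(
    forward_operator=toy_fft,
    backward_operator=toy_ifft,
    num_layers=2,
    regularizer_num_filters=8,
    regularizer_num_pull_layers=2,
)
masked_kspace = torch.randn(1, 4, 8, 32, 32, 2)            # (N, coil, slice, height, width, 2)
sampling_mask = torch.randint(0, 2, (1, 1, 1, 32, 32, 1))  # 1 = sampled location
sensitivity_map = torch.randn(1, 4, 8, 32, 32, 2)
prediction = model(masked_kspace, sampling_mask, sensitivity_map)
print(prediction.shape)  # same shape as masked_kspace
```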


class EndToEndVarNet3DBlock(nn.Module):
"""End-to-End Variational Network 3D block."""

def __init__(
self,
forward_operator: Callable,
backward_operator: Callable,
regularizer_model: nn.Module,
):
"""Inits :class:`EndToEndVarNet3DBlock`.
Parameters
----------
forward_operator: Callable
Forward Operator.
backward_operator: Callable
Backward Operator.
regularizer_model: nn.Module
Regularizer model.
"""
super().__init__()
self.regularizer_model = regularizer_model
self.forward_operator = forward_operator
self.backward_operator = backward_operator
self.learning_rate = nn.Parameter(torch.tensor([1.0]))
self._coil_dim = 1
self._complex_dim = -1
self._spatial_dims = (3, 4)

def forward(
self,
current_kspace: torch.Tensor,
masked_kspace: torch.Tensor,
sampling_mask: torch.Tensor,
sensitivity_map: torch.Tensor,
) -> torch.Tensor:
"""Performs the forward pass of :class:`EndToEndVarNetBlock`.
Parameters
----------
current_kspace: torch.Tensor
Current k-space prediction of shape (N, coil, slice/time, height, width, complex=2).
masked_kspace: torch.Tensor
Masked k-space of shape (N, coil, slice/time, height, width, complex=2).
sampling_mask: torch.Tensor
Sampling mask of shape (N, 1, 1 or slice/time, height, width, 1).
sensitivity_map: torch.Tensor
Sensitivity map of shape (N, coil, slice/time, height, width, complex=2).
Returns
-------
torch.Tensor
Next k-space prediction of shape (N, coil, slice/time, height, width, complex=2).
"""
kspace_error = torch.where(
sampling_mask == 0,
torch.tensor([0.0], dtype=masked_kspace.dtype).to(masked_kspace.device),
current_kspace - masked_kspace,
)
regularization_term = torch.cat(
[
reduce_operator(
self.backward_operator(kspace, dim=self._spatial_dims), sensitivity_map, dim=self._coil_dim
)
for kspace in torch.split(current_kspace, 2, self._complex_dim)
],
dim=self._complex_dim,
).permute(0, 4, 1, 2, 3)
regularization_term = self.regularizer_model(regularization_term).permute(0, 2, 3, 4, 1)
regularization_term = torch.cat(
[
self.forward_operator(
expand_operator(image, sensitivity_map, dim=self._coil_dim), dim=self._spatial_dims
)
for image in torch.split(regularization_term, 2, self._complex_dim)
],
dim=self._complex_dim,
)
return current_kspace - self.learning_rate * kspace_error + regularization_term
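Each block thus applies the update :math:`k_{t+1} = k_t - \eta \, M \odot (k_t - y) + \mathcal{R}(k_t)`, where :math:`M` is the sampling mask, :math:`y` the measured k-space, :math:`\eta` the learnable step size (`learning_rate`), and :math:`\mathcal{R}` the regularizer applied through the sensitivity-weighted image domain. A toy illustration (with made-up numbers) of the masked data-consistency term computed by the `torch.where` above:

```python
import torch

current = torch.tensor([1.0, 2.0, 3.0])   # current k-space estimate
measured = torch.tensor([1.5, 0.0, 2.0])  # acquired (masked) k-space
mask = torch.tensor([1, 0, 1])            # 1 = sampled location

# Zero error where the mask is 0: only sampled locations contribute.
error = torch.where(mask == 0, torch.tensor(0.0), current - measured)
print(error)  # tensor([-0.5000, 0.0000, 1.0000])
```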
