Add vSHARP JSSL engine
georgeyiasemis committed Apr 17, 2024
1 parent 7eaba03 commit 17a12bf
Showing 1 changed file with 213 additions and 1 deletion.
214 changes: 213 additions & 1 deletion direct/nn/vsharp/vsharp_engine.py
@@ -1,6 +1,6 @@
# Copyright (c) DIRECT Contributors

"""Engine for vSHARP 2D model."""
"""Engine for vSHARP 2D and 3D models."""

from __future__ import annotations

@@ -14,6 +14,7 @@
from direct.data import transforms as T
from direct.engine import DoIterationOutput
from direct.nn.mri_models import MRIModelEngine
from direct.nn.ssl.mri_models import JSSLMRIModelEngine, SSLMRIModelEngine
from direct.types import TensorOrNone
from direct.utils import detach_dict, dict_to_device

@@ -270,3 +271,214 @@ def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Te
)

return output_images, output_kspace


class VSharpNetJSSLEngine(JSSLMRIModelEngine):
    """JSSL vSHARP Model Engine.

    Parameters
----------
cfg: BaseConfig
Configuration file.
model: nn.Module
Model.
device: str
Device. Can be "cuda:{idx}" or "cpu".
forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The forward operator. Default: None.
backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The backward operator. Default: None.
mixed_precision: bool
Use mixed precision. Default: False.
**models: nn.Module
Additional models.
"""

def __init__(
self,
cfg: BaseConfig,
model: nn.Module,
device: str,
forward_operator: Optional[Callable] = None,
backward_operator: Optional[Callable] = None,
mixed_precision: bool = False,
**models: nn.Module,
):
"""Inits :class:`VSharpNetJSSLEnging`.
Parameters
----------
cfg: BaseConfig
Configuration file.
model: nn.Module
Model.
device: str
Device. Can be "cuda:{idx}" or "cpu".
forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The forward operator. Default: None.
backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional
The backward operator. Default: None.
mixed_precision: bool
Use mixed precision. Default: False.
**models: nn.Module
Additional models.
"""
super().__init__(
cfg,
model,
device,
forward_operator=forward_operator,
backward_operator=backward_operator,
mixed_precision=mixed_precision,
**models,
)

def forward_function(self, data: dict[str, Any]) -> None:
"""Forward function for :class:`VSharpNetJSSLEnging`."""
raise NotImplementedError(
"Forward function for JSSL vSHARP is not implemented. `VSharpNetJSSLEnging` "
"implements the `_do_iteration` method itself so the forward function should not be "
"called."
)

def _do_iteration(
self,
data: dict[str, Any],
loss_fns: Optional[dict[str, Callable]] = None,
regularizer_fns: Optional[dict[str, Callable]] = None,
) -> DoIterationOutput:
"""This function implements the `_do_iteration` for the JSSL vSHARP model.
Returns
-------
DoIterationOutput
Output of the iteration.
It assumes different behavior for training and inference. During SSL training, it expects the input data
to contain keys "input_kspace" and "input_sampling_mask", otherwise, it expects the input data to contain
keys "masked_kspace" and "sampling_mask".
Parameters
----------
data : dict[str, Any]
Input data dictionary. The dictionary should contain the following keys:
- "is_ssl_training": Boolean indicating if the sample is for SSL training.
- "input_kspace" if SSL training, otherwise "masked_kspace".
- "input_sampling_mask" if SSL training, otherwise "sampling_mask".
- "target_sampling_mask": Sampling mask for the target k-space if SSL training.
- "sensitivity_map": Sensitivity map.
- "target": Target image.
- "padding": Padding, optionally.
loss_fns : Optional[dict[str, Callable]], optional
Loss functions, optional.
regularizer_fns : Optional[dict[str, Callable]], optional
Regularizer functions, optional.
"""

# loss_fns can be None, e.g. during validation
if loss_fns is None:
loss_fns = {}

if regularizer_fns is None:
regularizer_fns = {}

# Get a boolean indicating if the sample is for SSL training
# This will expect the input data to contain the keys "input_kspace" and "input_sampling_mask" if SSL training
is_ssl_training = data["is_ssl_training"][0]
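        # Taking the first element assumes all samples in the batch share the same SSL flag.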

        # Get the k-space and mask, which differ between SSL and supervised training.
        # They also differ between training and inference for SSL.
if is_ssl_training and self.model.training:
kspace, mask = data["input_kspace"], data["input_sampling_mask"]
else:
kspace, mask = data["masked_kspace"], data["sampling_mask"]

# Initialize loss and regularizer dictionaries
loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}
regularizer_dict = {
k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
}

data = dict_to_device(data, self.device)

with autocast(enabled=self.mixed_precision):
data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"])

output_images = self.model(
masked_kspace=kspace,
sampling_mask=mask,
sensitivity_map=data["sensitivity_map"],
)

if self.model.training:
if len(output_images) > 1:
# Initialize auxiliary loss weights with a logarithmic scale if multiple auxiliary steps
auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0])
else:
# Initialize auxiliary loss weights with a single value of 1.0 if single step
auxiliary_loss_weights = torch.ones(1).to(output_images[0])
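                # Illustration (not in the original code): with 3 auxiliary steps,
                # torch.logspace(-1, 0, steps=3) gives ~[0.1000, 0.3162, 1.0000], so later,
                # more refined predictions contribute more to the total loss.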

for i in range(len(output_images)):
# Data consistency
output_kspace = T.apply_padding(
kspace + self._forward_operator(output_images[i], data["sensitivity_map"], ~mask),
padding=data["padding"],
)
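                    # Since `kspace` is zero outside the mask and the prediction passes through the
                    # complement mask, this is effectively M * y + (1 - M) * F(S x): measured samples
                    # are kept and unmeasured ones are filled from the prediction.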
if is_ssl_training:
# Project predicted k-space onto target k-space if SSL
output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False)

# Compute k-space loss per auxiliary step
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i]
)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, None, output_kspace, auxiliary_loss_weights[i]
)

# SENSE reconstruction if SSL else modulus if supervised
output_images[i] = T.modulus(
T.reduce_operator(
self.backward_operator(output_kspace, dim=self._spatial_dims),
data["sensitivity_map"],
self._coil_dim,
)
if is_ssl_training
else output_images[i]
)
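                    # For SSL, re-deriving the image from the data-consistent k-space keeps it
                    # consistent with the projected k-space used in the k-space loss above.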

# Compute image loss per auxiliary step
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, output_images[i], None, auxiliary_loss_weights[i]
)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_images[i], None, auxiliary_loss_weights[i]
)

loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore
self._scaler.scale(loss).backward()
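                # A single backward pass accumulates gradients for all auxiliary steps; the
                # scaler only rescales the loss when mixed precision is enabled.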

output_image = output_images[-1]
else:
output_kspace = T.apply_padding(
kspace + self._forward_operator(output_images[-1], data["sensitivity_map"], ~mask),
padding=data["padding"],
)
output_image = T.modulus(
T.reduce_operator(
self.backward_operator(output_kspace, dim=self._spatial_dims),
data["sensitivity_map"],
self._coil_dim,
)
)
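                # At inference, the final image is always reconstructed from the data-consistent
                # k-space of the last prediction.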

loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging.
regularizer_dict = detach_dict(regularizer_dict)

return DoIterationOutput(
output_image=output_image,
sensitivity_map=data["sensitivity_map"],
data_dict={**loss_dict, **regularizer_dict},
)
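
A minimal usage sketch, not part of the commit: the engine is constructed like the other DIRECT MRI model engines. The `cfg` and `model` objects below are placeholders, and the operator choices are assumptions for illustration.

# Hypothetical sketch: `cfg` (a BaseConfig) and `model` (a vSHARP nn.Module) are
# built elsewhere from the experiment configuration.
from direct.data import transforms as T
from direct.nn.vsharp.vsharp_engine import VSharpNetJSSLEngine

engine = VSharpNetJSSLEngine(
    cfg,
    model,
    device="cuda:0",
    forward_operator=T.fft2,    # centered FFT from direct.data.transforms
    backward_operator=T.ifft2,  # its inverse
    mixed_precision=True,
)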
