From 1ff2de9658341cd3b40c37373d56a2c7e167b7c6 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Fri, 17 May 2024 11:49:21 +0200 Subject: [PATCH] JSSL/SSL vsharp engine --- direct/nn/ssl/mri_models.py | 2 +- direct/nn/vsharp/vsharp_engine.py | 443 +++++++++++++++++++++++++- tests/tests_ssl/test_vsharp_engine.py | 194 +++++++++++ 3 files changed, 635 insertions(+), 4 deletions(-) create mode 100644 tests/tests_ssl/test_vsharp_engine.py diff --git a/direct/nn/ssl/mri_models.py b/direct/nn/ssl/mri_models.py index 0fe7caec..b56d2990 100644 --- a/direct/nn/ssl/mri_models.py +++ b/direct/nn/ssl/mri_models.py @@ -19,7 +19,7 @@ from direct.utils import detach_dict, dict_to_device, normalize_image from direct.utils.events import get_event_storage -__all__ = ["SSLMRIModelEngine"] +__all__ = ["SSLMRIModelEngine", "JSSLMRIModelEngine"] class SSLMRIModelEngine(MRIModelEngine): diff --git a/direct/nn/vsharp/vsharp_engine.py b/direct/nn/vsharp/vsharp_engine.py index 0ec3f8d4..ad745f8b 100644 --- a/direct/nn/vsharp/vsharp_engine.py +++ b/direct/nn/vsharp/vsharp_engine.py @@ -1,6 +1,18 @@ # Copyright (c) DIRECT Contributors -"""Engine for vSHARP 2D model.""" +"""Engines for vSHARP 2D and 3D models [1]. + +Includes supervised, self-supervised and joint supervised and self-supervised learning [2] engines. + +References +---------- +.. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and + Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). + https://doi.org/10.48550/arXiv.2311.15856. +.. [2] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and + Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). + https://doi.org/10.48550/arXiv.2311.15856. +""" from __future__ import annotations @@ -14,12 +26,13 @@ from direct.data import transforms as T from direct.engine import DoIterationOutput from direct.nn.mri_models import MRIModelEngine +from direct.nn.ssl.mri_models import JSSLMRIModelEngine, SSLMRIModelEngine from direct.types import TensorOrNone from direct.utils import detach_dict, dict_to_device class VSharpNet3DEngine(MRIModelEngine): - """VSharpNet Engine.""" + """VSharpNet 3D Model Engine.""" def __init__( self, @@ -31,7 +44,7 @@ def __init__( mixed_precision: bool = False, **models: nn.Module, ): - """Inits :class:`VSharpNetEngine`. + """Inits :class:`VSharpNet3DEngine`. Parameters ---------- @@ -270,3 +283,427 @@ def forward_function(self, data: dict[str, Any]) -> tuple[torch.Tensor, torch.Te ) return output_images, output_kspace + + +class VSharpNetSSLEngine(SSLMRIModelEngine): + """Self-supervised Learning vSHARP Model 2D Engine. + + Used for the main experiments for SSL in the JSSL paper [1]. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + + + References + ---------- + .. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and + Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). + https://doi.org/10.48550/arXiv.2311.15856. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VSharpNetSSLEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: dict[str, Any]) -> None: + """Forward function for :class:`VSharpNetSSLEngine`.""" + raise NotImplementedError( + "Forward function for SSL vSHARP engine is not implemented. `VSharpNetSSLEngine` " + "implements the `_do_iteration` method itself so the forward function should not be " + "called." + ) + + def _do_iteration( + self, + data: dict[str, Any], + loss_fns: Optional[dict[str, Callable]] = None, + regularizer_fns: Optional[dict[str, Callable]] = None, + ) -> DoIterationOutput: + """This function implements the `_do_iteration` for the SSL vSHARP model. + + Returns + ------- + DoIterationOutput + Output of the iteration. + + + It assumes different behavior for training and inference. During training, it expects the input data + to contain keys "input_kspace" and "input_sampling_mask", otherwise, it expects the input data to contain + keys "masked_kspace" and "sampling_mask". + + Parameters + ---------- + data : dict[str, Any] + Input data dictionary. The dictionary should contain the following keys: + - "input_kspace" if training, otherwise "masked_kspace". + - "input_sampling_mask" if training, otherwise "sampling_mask". + - "target_sampling_mask": Sampling mask for the target k-space if training. + - "sensitivity_map": Sensitivity map. + - "target": Target image. + - "padding": Padding, optionally. + loss_fns : Optional[dict[str, Callable]], optional + Loss functions, optional. + regularizer_fns : Optional[dict[str, Callable]], optional + Regularizer functions, optional. + + """ + + # loss_fns can be None, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + # Move data to device + data = dict_to_device(data, self.device) + + # Get the k-space and mask which differ during training and inference for SSL + if self.model.training: + kspace, mask = data["input_kspace"], data["input_sampling_mask"] + else: + kspace, mask = data["masked_kspace"], data["sampling_mask"] + + # Initialize loss and regularizer dictionaries + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + with autocast(enabled=self.mixed_precision): + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_images = self.model( + masked_kspace=kspace, + sampling_mask=mask, + sensitivity_map=data["sensitivity_map"], + ) + + if self.model.training: + if len(output_images) > 1: + # Initialize auxiliary loss weights with a logarithmic scale if multiple auxiliary steps + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + else: + # Initialize auxiliary loss weights with a single value of 1.0 if single step + auxiliary_loss_weights = torch.ones(1).to(output_images[0]) + + for i in range(len(output_images)): + # Data consistency + output_kspace = T.apply_padding( + kspace + self._forward_operator(output_images[i], data["sensitivity_map"], ~mask), + padding=data.get("padding", None), + ) + # Project predicted k-space onto target k-space if SSL + output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False) + + # Compute k-space loss per auxiliary step + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + + # SENSE reconstruction + output_images[i] = T.modulus( + T.reduce_operator( + self.backward_operator(output_kspace, dim=self._spatial_dims), + data["sensitivity_map"], + self._coil_dim, + ) + ) + + # Compute image loss per auxiliary step + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore + self._scaler.scale(loss).backward() + + output_image = output_images[-1] + else: + output_kspace = T.apply_padding( + kspace + self._forward_operator(output_images[-1], data["sensitivity_map"], ~mask), + padding=data.get("padding", None), + ) + # SENSE reconstruction using data consistent k-space + output_image = T.modulus( + T.reduce_operator( + self.backward_operator(output_kspace, dim=self._spatial_dims), + data["sensitivity_map"], + self._coil_dim, + ) + ) + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + regularizer_dict = detach_dict(regularizer_dict) + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) + + +class VSharpNetJSSLEngine(JSSLMRIModelEngine): + """Joint Supervised and Self-supervised Learning vSHARP Model 2D Engine. + + Used for the main experiments in the JSSL paper [1]. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + + + References + ---------- + .. [1] Yiasemis, G., Moriakov, N., Sánchez, C.I., Sonke, J.-J., Teuwen, J.: JSSL: Joint Supervised and + Self-supervised Learning for MRI Reconstruction, http://arxiv.org/abs/2311.15856, (2023). + https://doi.org/10.48550/arXiv.2311.15856. + """ + + def __init__( + self, + cfg: BaseConfig, + model: nn.Module, + device: str, + forward_operator: Optional[Callable] = None, + backward_operator: Optional[Callable] = None, + mixed_precision: bool = False, + **models: nn.Module, + ): + """Inits :class:`VSharpNetJSSLEngine`. + + Parameters + ---------- + cfg: BaseConfig + Configuration file. + model: nn.Module + Model. + device: str + Device. Can be "cuda:{idx}" or "cpu". + forward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The forward operator. Default: None. + backward_operator: Callable[[tuple[Any, ...]], torch.Tensor], optional + The backward operator. Default: None. + mixed_precision: bool + Use mixed precision. Default: False. + **models: nn.Module + Additional models. + """ + super().__init__( + cfg, + model, + device, + forward_operator=forward_operator, + backward_operator=backward_operator, + mixed_precision=mixed_precision, + **models, + ) + + def forward_function(self, data: dict[str, Any]) -> None: + """Forward function for :class:`VSharpNetJSSLEngine`.""" + raise NotImplementedError( + "Forward function for JSSL vSHARP is not implemented. `VSharpNetJSSLEngine` " + "implements the `_do_iteration` method itself so the forward function should not be " + "called." + ) + + def _do_iteration( + self, + data: dict[str, Any], + loss_fns: Optional[dict[str, Callable]] = None, + regularizer_fns: Optional[dict[str, Callable]] = None, + ) -> DoIterationOutput: + """This function implements the `_do_iteration` for the JSSL vSHARP model. + + Returns + ------- + DoIterationOutput + Output of the iteration. + + + It assumes different behavior for SSL training and inference. During SSL training, it expects the input data + to contain keys "input_kspace" and "input_sampling_mask", otherwise, it expects the input data to contain + keys "masked_kspace" and "sampling_mask". + + Parameters + ---------- + data : dict[str, Any] + Input data dictionary. The dictionary should contain the following keys: + - "is_ssl": Boolean indicating if the sample is for SSL training. + - "input_kspace" if SSL training, otherwise "masked_kspace". + - "input_sampling_mask" if SSL training, otherwise "sampling_mask". + - "target_sampling_mask": Sampling mask for the target k-space if SSL training. + - "sensitivity_map": Sensitivity map. + - "target": Target image. + - "padding": Padding, optionally. + loss_fns : Optional[dict[str, Callable]], optional + Loss functions, optional. + regularizer_fns : Optional[dict[str, Callable]], optional + Regularizer functions, optional. + + """ + + # loss_fns can be None, e.g. during validation + if loss_fns is None: + loss_fns = {} + + if regularizer_fns is None: + regularizer_fns = {} + + # Move data to device + data = dict_to_device(data, self.device) + + # Get a boolean indicating if the sample is for SSL training + # This will expect the input data to contain the keys "input_kspace" and "input_sampling_mask" if SSL training + is_ssl = data["is_ssl"][0] + + # Get the k-space and mask which differ if SSL training or supervised training + # The also differ during training and inference for SSL + if is_ssl and self.model.training: + kspace, mask = data["input_kspace"], data["input_sampling_mask"] + else: + kspace, mask = data["masked_kspace"], data["sampling_mask"] + + # Initialize loss and regularizer dictionaries + loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()} + regularizer_dict = { + k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys() + } + + with autocast(enabled=self.mixed_precision): + data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"]) + + output_images = self.model( + masked_kspace=kspace, + sampling_mask=mask, + sensitivity_map=data["sensitivity_map"], + ) + + if self.model.training: + if len(output_images) > 1: + # Initialize auxiliary loss weights with a logarithmic scale if multiple auxiliary steps + auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0]) + else: + # Initialize auxiliary loss weights with a single value of 1.0 if single step + auxiliary_loss_weights = torch.ones(1).to(output_images[0]) + + for i in range(len(output_images)): + # Data consistency + output_kspace = T.apply_padding( + kspace + self._forward_operator(output_images[i], data["sensitivity_map"], ~mask), + padding=data.get("padding", None), + ) + if is_ssl: + # Project predicted k-space onto target k-space if SSL + output_kspace = T.apply_mask(output_kspace, data["target_sampling_mask"], return_mask=False) + + # Compute k-space loss per auxiliary step + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, None, output_kspace, auxiliary_loss_weights[i] + ) + + # SENSE reconstruction if SSL else modulus if supervised + output_images[i] = T.modulus( + T.reduce_operator( + self.backward_operator(output_kspace, dim=self._spatial_dims), + data["sensitivity_map"], + self._coil_dim, + ) + if is_ssl + else output_images[i] + ) + + # Compute image loss per auxiliary step + loss_dict = self.compute_loss_on_data( + loss_dict, loss_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + regularizer_dict = self.compute_loss_on_data( + regularizer_dict, regularizer_fns, data, output_images[i], None, auxiliary_loss_weights[i] + ) + + loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore + self._scaler.scale(loss).backward() + + output_image = output_images[-1] + else: + output_image = T.modulus(output_images[-1]) + + loss_dict = detach_dict(loss_dict) # Detach dict, only used for logging. + regularizer_dict = detach_dict(regularizer_dict) + + return DoIterationOutput( + output_image=output_image, + sensitivity_map=data["sensitivity_map"], + data_dict={**loss_dict, **regularizer_dict}, + ) diff --git a/tests/tests_ssl/test_vsharp_engine.py b/tests/tests_ssl/test_vsharp_engine.py new file mode 100644 index 00000000..3cfb9a92 --- /dev/null +++ b/tests/tests_ssl/test_vsharp_engine.py @@ -0,0 +1,194 @@ +# Copyright (c) DIRECT Contributors + +"""Tests for SSL/JSSL engines in direct.nn.vsharp.vsharp_engine module.""" + +import functools + +import numpy as np +import pytest +import torch + +from direct.config.defaults import DefaultConfig, FunctionConfig, LossConfig, TrainingConfig, ValidationConfig +from direct.data.transforms import fft2, ifft2 +from direct.nn.vsharp.config import VSharpNet3DConfig, VSharpNetConfig +from direct.nn.vsharp.vsharp import VSharpNet, VSharpNet3D +from direct.nn.vsharp.vsharp_engine import VSharpNetJSSLEngine, VSharpNetSSLEngine + + +def create_sample(**kwargs): + sample = dict() + for k, v in locals()["kwargs"].items(): + sample[k] = v + return sample + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "num_steps, num_steps_dc_gd, num_filters, num_pool_layers", + [[4, 2, 10, 2]], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vsharpnet_ssl_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = VSharpNetConfig( + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_unet_num_filters=num_filters, + image_unet_num_pool_layers=num_pool_layers, + auxiliary_steps=-1, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = VSharpNet( + forward_operator, + backward_operator, + num_steps=model_config.num_steps, + num_steps_dc_gd=model_config.num_steps_dc_gd, + image_unet_num_filters=model_config.image_unet_num_filters, + image_unet_num_pool_layers=model_config.image_unet_num_pool_layers, + auxiliary_steps=model_config.auxiliary_steps, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = VSharpNetSSLEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + # Simulate training + # Test _do_iteration function with a single data batch + data = create_sample( + input_sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + target_sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + input_kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + sensitivity_map=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + # Simulate validation + engine.model.eval() + # Test _do_iteration function with a single data batch + data = create_sample( + sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + masked_kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + sensitivity_map=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + +@pytest.mark.parametrize( + "shape", + [(4, 3, 10, 16, 2), (5, 1, 10, 12, 2)], +) +@pytest.mark.parametrize( + "loss_fns", + [["l1_loss", "kspace_nmse_loss", "kspace_nmae_loss"]], +) +@pytest.mark.parametrize( + "num_steps, num_steps_dc_gd, num_filters, num_pool_layers", + [[4, 2, 10, 2]], +) +@pytest.mark.parametrize( + "normalized", + [True, False], +) +def test_vsharpnet_jssl_engine(shape, loss_fns, num_steps, num_steps_dc_gd, num_filters, num_pool_layers, normalized): + # Operators + forward_operator = functools.partial(fft2, centered=True) + backward_operator = functools.partial(ifft2, centered=True) + # Configs + loss_config = LossConfig(losses=[FunctionConfig(loss) for loss in loss_fns]) + training_config = TrainingConfig(loss=loss_config) + validation_config = ValidationConfig(crop=None) + model_config = VSharpNetConfig( + num_steps=num_steps, + num_steps_dc_gd=num_steps_dc_gd, + image_unet_num_filters=num_filters, + image_unet_num_pool_layers=num_pool_layers, + auxiliary_steps=-1, + ) + config = DefaultConfig(training=training_config, validation=validation_config, model=model_config) + # Models + model = VSharpNet( + forward_operator, + backward_operator, + num_steps=model_config.num_steps, + num_steps_dc_gd=model_config.num_steps_dc_gd, + image_unet_num_filters=model_config.image_unet_num_filters, + image_unet_num_pool_layers=model_config.image_unet_num_pool_layers, + auxiliary_steps=model_config.auxiliary_steps, + ) + sensitivity_model = torch.nn.Conv2d(2, 2, kernel_size=1) + # Define engine + engine = VSharpNetJSSLEngine(config, model, "cpu", fft2, ifft2, sensitivity_model=sensitivity_model) + engine.ndim = 2 + + # Simulate training (SSL) + # Test _do_iteration function with a single data batch + data = create_sample( + input_sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + target_sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + input_kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + sensitivity_map=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + is_ssl=torch.ones(shape[0]).bool(), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + # Simulate training (SL) + # Test _do_iteration function with a single data batch + data = create_sample( + sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + masked_kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + sensitivity_map=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + is_ssl=torch.zeros(shape[0]).bool(), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1]) + + # Simulate validation + engine.model.eval() + # Test _do_iteration function with a single data batch + data = create_sample( + sampling_mask=torch.from_numpy(np.random.rand(1, 1, shape[2], shape[3], 1)).round().bool(), + masked_kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + kspace=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + sensitivity_map=torch.from_numpy(np.random.randn(shape[0], shape[1], shape[2], shape[3], 2)).float(), + target=torch.from_numpy(np.random.randn(shape[0], shape[2], shape[3])).float(), + scaling_factor=torch.ones(shape[0]), + is_ssl=torch.zeros(shape[0]).bool(), + ) + loss_fns = engine.build_loss() + out = engine._do_iteration(data, loss_fns) + assert out.output_image.shape == (shape[0],) + tuple(shape[2:-1])