From 036af261ee1fc52da073f42b8aec790f3b390777 Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 18 Apr 2024 04:45:43 +0200 Subject: [PATCH] Code quality fixes in mri transforms --- direct/data/mri_transforms.py | 89 +++++++++++++++++------------------ 1 file changed, 42 insertions(+), 47 deletions(-) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 9ebffc54..8c4b20d6 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -1763,7 +1763,7 @@ def build_supervised_mri_transforms( DirectTransform An MRI transformation object. """ - # TODO: Use seed + # pylint: disable=too-many-arguments mri_transforms: list[Callable] = [ToTensor()] if crop: @@ -2059,53 +2059,48 @@ def build_mri_transforms( if transforms_type == TranformsType.SUPERVISED: return Compose(mri_transforms) - elif transforms_type == TranformsType.SSL_SSDU: - mask_splitter_kwargs = { - "ratio": mask_split_ratio, - "acs_region": mask_split_acs_region, - "keep_acs": mask_split_keep_acs, - "use_seed": use_seed, - "kspace_key": KspaceKey.MASKED_KSPACE, - } - mri_transforms += [ - ( - GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std) - if mask_split_type == MaskSplitterType.GAUSSIAN - else ( - UniformMaskSplitter(**mask_splitter_kwargs) - if mask_split_type == MaskSplitterType.UNIFORM - else HalfMaskSplitterModule( - **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"}, - direction=mask_split_half_direction, - ) - ) - ), - DeleteKeys([TransformKey.ACS_MASK]), - ] - - mri_transforms += [ - RenameKeys( - [ - SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE, - SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE, - ], - ["input_kspace", "kspace"], - ), - DeleteKeys(["masked_kspace", "sampling_mask"]), - ] # Rename keys for SSL engine - mri_transforms += [ - ComputeImage( - kspace_key=KspaceKey.KSPACE, - target_key=TransformKey.TARGET, - backward_operator=backward_operator, - type_reconstruction=image_recon_type, + mask_splitter_kwargs = { + "ratio": mask_split_ratio, + "acs_region": mask_split_acs_region, + "keep_acs": mask_split_keep_acs, + "use_seed": use_seed, + "kspace_key": KspaceKey.MASKED_KSPACE, + } + mri_transforms += [ + ( + GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std) + if mask_split_type == MaskSplitterType.GAUSSIAN + else ( + UniformMaskSplitter(**mask_splitter_kwargs) + if mask_split_type == MaskSplitterType.UNIFORM + else HalfMaskSplitterModule( + **{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"}, + direction=mask_split_half_direction, + ) ) - ] + ), + DeleteKeys([TransformKey.ACS_MASK]), + ] - return Compose(mri_transforms) - else: - raise NotImplementedError( - f"Currently only TransformsType.SUPERVISED or TranformsType.SSL_SSDU is supported as input for " - f"`transforms_type`. Received: {transforms_type}." + mri_transforms += [ + RenameKeys( + [ + SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE, + SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE, + ], + ["input_kspace", "kspace"], + ), + DeleteKeys(["masked_kspace", "sampling_mask"]), + ] # Rename keys for SSL engine + + mri_transforms += [ + ComputeImage( + kspace_key=KspaceKey.KSPACE, + target_key=TransformKey.TARGET, + backward_operator=backward_operator, + type_reconstruction=image_recon_type, ) + ] + + return Compose(mri_transforms)