Skip to content

Commit

Permalink
Code quality fixes in mri transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Apr 18, 2024
1 parent 49577ad commit 036af26
Showing 1 changed file with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions direct/data/mri_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,7 +1763,7 @@ def build_supervised_mri_transforms(
DirectTransform
An MRI transformation object.
"""
# TODO: Use seed
# pylint: disable=too-many-arguments

mri_transforms: list[Callable] = [ToTensor()]
if crop:
Expand Down Expand Up @@ -2059,53 +2059,48 @@ def build_mri_transforms(

if transforms_type == TranformsType.SUPERVISED:
return Compose(mri_transforms)
elif transforms_type == TranformsType.SSL_SSDU:
mask_splitter_kwargs = {
"ratio": mask_split_ratio,
"acs_region": mask_split_acs_region,
"keep_acs": mask_split_keep_acs,
"use_seed": use_seed,
"kspace_key": KspaceKey.MASKED_KSPACE,
}
mri_transforms += [
(
GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std)
if mask_split_type == MaskSplitterType.GAUSSIAN
else (
UniformMaskSplitter(**mask_splitter_kwargs)
if mask_split_type == MaskSplitterType.UNIFORM
else HalfMaskSplitterModule(
**{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"},
direction=mask_split_half_direction,
)
)
),
DeleteKeys([TransformKey.ACS_MASK]),
]

mri_transforms += [
RenameKeys(
[
SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE,
SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE,
],
["input_kspace", "kspace"],
),
DeleteKeys(["masked_kspace", "sampling_mask"]),
] # Rename keys for SSL engine

mri_transforms += [
ComputeImage(
kspace_key=KspaceKey.KSPACE,
target_key=TransformKey.TARGET,
backward_operator=backward_operator,
type_reconstruction=image_recon_type,
mask_splitter_kwargs = {
"ratio": mask_split_ratio,
"acs_region": mask_split_acs_region,
"keep_acs": mask_split_keep_acs,
"use_seed": use_seed,
"kspace_key": KspaceKey.MASKED_KSPACE,
}
mri_transforms += [
(
GaussianMaskSplitter(**mask_splitter_kwargs, std_scale=mask_split_gaussian_std)
if mask_split_type == MaskSplitterType.GAUSSIAN
else (
UniformMaskSplitter(**mask_splitter_kwargs)
if mask_split_type == MaskSplitterType.UNIFORM
else HalfMaskSplitterModule(
**{k: v for k, v in mask_splitter_kwargs.items() if k != "ratio"},
direction=mask_split_half_direction,
)
)
]
),
DeleteKeys([TransformKey.ACS_MASK]),
]

return Compose(mri_transforms)
else:
raise NotImplementedError(
f"Currently only TransformsType.SUPERVISED or TranformsType.SSL_SSDU is supported as input for "
f"`transforms_type`. Received: {transforms_type}."
mri_transforms += [
RenameKeys(
[
SSLTransformMaskPrefixes.INPUT_ + TransformKey.MASKED_KSPACE,
SSLTransformMaskPrefixes.TARGET_ + TransformKey.MASKED_KSPACE,
],
["input_kspace", "kspace"],
),
DeleteKeys(["masked_kspace", "sampling_mask"]),
] # Rename keys for SSL engine

mri_transforms += [
ComputeImage(
kspace_key=KspaceKey.KSPACE,
target_key=TransformKey.TARGET,
backward_operator=backward_operator,
type_reconstruction=image_recon_type,
)
]

return Compose(mri_transforms)

0 comments on commit 036af26

Please sign in to comment.