Skip to content

Commit

Permalink
Add SSL and JSSL functionality (#277)
Browse files Browse the repository at this point in the history
* Add SSL methods
* Add our JSSL model 
* Improve code quality and improve tests
  • Loading branch information
georgeyiasemis authored Apr 26, 2024
1 parent 0a25a6a commit 66dd5cd
Show file tree
Hide file tree
Showing 19 changed files with 2,843 additions and 185 deletions.
2 changes: 1 addition & 1 deletion direct/config/defaults.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass, field
Expand Down Expand Up @@ -102,6 +101,7 @@ class InferenceConfig(BaseConfig):
@dataclass
class ModelConfig(BaseConfig):
model_name: str = MISSING
engine_name: Optional[str] = None


@dataclass
Expand Down
78 changes: 74 additions & 4 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Classes holding the typed configurations for the datasets."""
Expand All @@ -10,6 +9,14 @@

from direct.common.subsample_config import MaskingConfig
from direct.config.defaults import BaseConfig
from direct.data.mri_transforms import (
HalfSplitType,
MaskSplitterType,
RandomFlipType,
ReconstructionType,
SensitivityMapType,
TransformsType,
)


@dataclass
Expand All @@ -22,7 +29,7 @@ class CropTransformConfig(BaseConfig):
@dataclass
class SensitivityMapEstimationTransformConfig(BaseConfig):
estimate_sensitivity_maps: bool = True
sensitivity_maps_type: str = "rss_estimate"
sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE
sensitivity_maps_espirit_threshold: Optional[float] = 0.05
sensitivity_maps_espirit_kernel_size: Optional[int] = 6
sensitivity_maps_espirit_crop: Optional[float] = 0.95
Expand All @@ -34,7 +41,7 @@ class SensitivityMapEstimationTransformConfig(BaseConfig):
class RandomAugmentationTransformsConfig(BaseConfig):
random_rotation_degrees: Tuple[int, ...] = (-90, 90)
random_rotation_probability: float = 0.0
random_flip_type: Optional[str] = "random"
random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM
random_flip_probability: float = 0.0
random_reverse_probability: float = 0.0

Expand All @@ -47,6 +54,61 @@ class NormalizationTransformConfig(BaseConfig):

@dataclass
class TransformsConfig(BaseConfig):
"""Configuration for the transforms.
Attributes
----------
masking : MaskingConfig
Configuration for the masking.
cropping : CropTransformConfig
Configuration for the cropping.
random_augmentations : RandomAugmentationTransformsConfig
Configuration for the random augmentations.
padding_eps : float
Padding epsilon. Default is 0.001.
estimate_body_coil_image : bool
Estimate body coil image. Default is False.
sensitivity_map_estimation : SensitivityMapEstimationTransformConfig
Configuration for the sensitivity map estimation.
normalization : NormalizationTransformConfig
Configuration for the normalization.
delete_acs_mask : bool
Delete ACS mask after its use. Default is True.
delete_kspace : bool
Delete k-space after its use. This should be set to False if the k-space is needed for the loss computation.
Default is True.
image_recon_type : ReconstructionType
Image reconstruction type. Default is ReconstructionType.RSS.
pad_coils : int, optional
Pad coils. Default is None.
use_seed : bool
Use seed for the transforms. Typically this should be set to True for reproducibility (e.g. inference),
and False for training. Default is True.
transforms_type : TransformsType
Type of transforms. By default the transforms are set for supervised learning (`TransformsType.SUPERVISED`).
To use SSL transforms, set transforms_type to `SSL_SSDU`. This will require additional parameters to be set:
mask_split_ratio, mask_split_acs_region, mask_split_keep_acs, mask_split_type, mask_split_gaussian_std.
Default is `TransformsType.SUPERVISED`.
mask_split_ratio : Tuple[float, ...]
Ratio of the mask to split into input and target mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0.4,).
mask_split_acs_region : Tuple[int, int]
Region of the ACS k-space to keep in the input mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0, 0).
mask_split_keep_acs : bool, optional
Keep ACS in both masks, input and target. Ignored if transforms_type is not `SSL_SSDU`. Default is False.
mask_split_type : MaskSplitterType
Type of mask splitting if transforms_type is `SSL_SSDU`. Ignored if transforms_type is not SSL_SSDU.
Default is `MaskSplitterType.GAUSSIAN`.
mask_split_gaussian_std : float
Standard deviation of the Gaussian mask splitter. Ignored if mask_split_type is not `MaskSplitterType.GAUSSIAN`.
Ignored if transforms_type is not `SSL_SSDU`. Default is 3.0.
mask_split_half_direction : HalfSplitType
Direction to split the mask if mask_split_type is `MaskSplitterType.HALF`.
Ignored if MaskSplitterType is not `HALF` or transforms_type is not `SSL_SSDU`.
Default is `HalfSplitType.VERTICAL`.
"""

masking: Optional[MaskingConfig] = MaskingConfig()
cropping: CropTransformConfig = CropTransformConfig()
random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig()
Expand All @@ -56,9 +118,17 @@ class TransformsConfig(BaseConfig):
normalization: NormalizationTransformConfig = NormalizationTransformConfig()
delete_acs_mask: bool = True
delete_kspace: bool = True
image_recon_type: str = "rss"
image_recon_type: ReconstructionType = ReconstructionType.RSS
pad_coils: Optional[int] = None
use_seed: bool = True
transforms_type: TransformsType = TransformsType.SUPERVISED
# Next attributes are for the mask splitter in case of transforms_type is set to SSL_SSDU
mask_split_ratio: Tuple[float, ...] = (0.4,)
mask_split_acs_region: Tuple[int, int] = (0, 0)
mask_split_keep_acs: Optional[bool] = False
mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN
mask_split_gaussian_std: float = 3.0
mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL


@dataclass
Expand Down
Loading

0 comments on commit 66dd5cd

Please sign in to comment.