Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SSL and JSSL functionality #277

Merged
merged 31 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
b6f566a
Add base ssl functionality
georgeyiasemis Apr 11, 2024
540b3a2
Add SSL transforms
georgeyiasemis Apr 11, 2024
f0c9f1d
Updated ssl tests
georgeyiasemis Apr 11, 2024
8d46dbb
Minor fix
georgeyiasemis Apr 11, 2024
d7a79ea
Initial ssl models
georgeyiasemis Apr 11, 2024
2c7840c
Small changes for SSL and test
georgeyiasemis Apr 11, 2024
cc7ea05
black
georgeyiasemis Apr 11, 2024
db41ad9
black
georgeyiasemis Apr 11, 2024
b6cabb9
Import annotations
georgeyiasemis Apr 11, 2024
0f76fc9
Add docstrings
georgeyiasemis Apr 11, 2024
fbe66a2
Minor docstrings changes
georgeyiasemis Apr 16, 2024
7eaba03
Add JSSL initial engines
georgeyiasemis Apr 17, 2024
92c0ef9
SSL fixes
georgeyiasemis Apr 17, 2024
57ca113
Unet ssl/jssl tests
georgeyiasemis Apr 18, 2024
851f289
is_ssl_training -> is_ssl
georgeyiasemis Apr 18, 2024
c8c5532
Varnet ssl and jssl engines
georgeyiasemis Apr 18, 2024
d31199e
Minor fix
georgeyiasemis Apr 18, 2024
6404ac7
Minor fixes
georgeyiasemis Apr 18, 2024
daa28c5
Code quality fixes
georgeyiasemis Apr 18, 2024
83ab7fb
Code quality fixes
georgeyiasemis Apr 18, 2024
49577ad
Code quality fixes
georgeyiasemis Apr 18, 2024
036af26
Code quality fixes in mri transforms
georgeyiasemis Apr 18, 2024
5b2629a
Remove useless option - new pylint
georgeyiasemis Apr 18, 2024
dc0ae07
Codacy quality fixes
georgeyiasemis Apr 18, 2024
67f0222
Where to put disable msg?
georgeyiasemis Apr 18, 2024
68e30bb
Add docstrings
georgeyiasemis Apr 19, 2024
8a52b24
Add reference
georgeyiasemis Apr 19, 2024
0d81470
Enum typing doesn't require checks
georgeyiasemis Apr 19, 2024
50aca9a
Omegaconf doesn't accept future annotations
georgeyiasemis Apr 19, 2024
2572474
Test fix
georgeyiasemis Apr 19, 2024
924bc5f
Minor fix
georgeyiasemis Apr 19, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion direct/config/defaults.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

from dataclasses import dataclass, field
Expand Down Expand Up @@ -102,6 +101,7 @@ class InferenceConfig(BaseConfig):
@dataclass
class ModelConfig(BaseConfig):
model_name: str = MISSING
engine_name: Optional[str] = None


@dataclass
Expand Down
78 changes: 74 additions & 4 deletions direct/data/datasets_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

"""Classes holding the typed configurations for the datasets."""
Expand All @@ -10,6 +9,14 @@

from direct.common.subsample_config import MaskingConfig
from direct.config.defaults import BaseConfig
from direct.data.mri_transforms import (
HalfSplitType,
MaskSplitterType,
RandomFlipType,
ReconstructionType,
SensitivityMapType,
TransformsType,
)


@dataclass
Expand All @@ -22,7 +29,7 @@ class CropTransformConfig(BaseConfig):
@dataclass
class SensitivityMapEstimationTransformConfig(BaseConfig):
estimate_sensitivity_maps: bool = True
sensitivity_maps_type: str = "rss_estimate"
sensitivity_maps_type: SensitivityMapType = SensitivityMapType.RSS_ESTIMATE
sensitivity_maps_espirit_threshold: Optional[float] = 0.05
sensitivity_maps_espirit_kernel_size: Optional[int] = 6
sensitivity_maps_espirit_crop: Optional[float] = 0.95
Expand All @@ -34,7 +41,7 @@ class SensitivityMapEstimationTransformConfig(BaseConfig):
class RandomAugmentationTransformsConfig(BaseConfig):
random_rotation_degrees: Tuple[int, ...] = (-90, 90)
random_rotation_probability: float = 0.0
random_flip_type: Optional[str] = "random"
random_flip_type: Optional[RandomFlipType] = RandomFlipType.RANDOM
random_flip_probability: float = 0.0
random_reverse_probability: float = 0.0

Expand All @@ -47,6 +54,61 @@ class NormalizationTransformConfig(BaseConfig):

@dataclass
class TransformsConfig(BaseConfig):
"""Configuration for the transforms.

Attributes
----------
masking : MaskingConfig
Configuration for the masking.
cropping : CropTransformConfig
Configuration for the cropping.
random_augmentations : RandomAugmentationTransformsConfig
Configuration for the random augmentations.
padding_eps : float
Padding epsilon. Default is 0.001.
estimate_body_coil_image : bool
Estimate body coil image. Default is False.
sensitivity_map_estimation : SensitivityMapEstimationTransformConfig
Configuration for the sensitivity map estimation.
normalization : NormalizationTransformConfig
Configuration for the normalization.
delete_acs_mask : bool
Delete ACS mask after its use. Default is True.
delete_kspace : bool
Delete k-space after its use. This should be set to False if the k-space is needed for the loss computation.
Default is True.
image_recon_type : ReconstructionType
Image reconstruction type. Default is ReconstructionType.RSS.
pad_coils : int, optional
Pad coils. Default is None.
use_seed : bool
Use seed for the transforms. Typically this should be set to True for reproducibility (e.g. inference),
and False for training. Default is True.
transforms_type : TransformsType
Type of transforms. By default the transforms are set for supervised learning (`TransformsType.SUPERVISED`).
To use SSL transforms, set transforms_type to `SSL_SSDU`. This will require additional parameters to be set:
mask_split_ratio, mask_split_acs_region, mask_split_keep_acs, mask_split_type, mask_split_gaussian_std.
Default is `TransformsType.SUPERVISED`.
mask_split_ratio : Tuple[float, ...]
Ratio of the mask to split into input and target mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0.4,).
mask_split_acs_region : Tuple[int, int]
Region of the ACS k-space to keep in the input mask. Ignored if transforms_type is not `SSL_SSDU`.
Default is (0, 0).
mask_split_keep_acs : bool, optional
Keep ACS in both masks, input and target. Ignored if transforms_type is not `SSL_SSDU`. Default is False.
mask_split_type : MaskSplitterType
Type of mask splitting if transforms_type is `SSL_SSDU`. Ignored if transforms_type is not SSL_SSDU.
Default is `MaskSplitterType.GAUSSIAN`.
mask_split_gaussian_std : float
Standard deviation of the Gaussian mask splitter. Ignored if mask_split_type is not `MaskSplitterType.GAUSSIAN`.
Ignored if transforms_type is not `SSL_SSDU`. Default is 3.0.
mask_split_half_direction : HalfSplitType
Direction to split the mask if mask_split_type is `MaskSplitterType.HALF`.
Ignored if MaskSplitterType is not `HALF` or transforms_type is not `SSL_SSDU`.
Default is `HalfSplitType.VERTICAL`.
"""

masking: Optional[MaskingConfig] = MaskingConfig()
cropping: CropTransformConfig = CropTransformConfig()
random_augmentations: RandomAugmentationTransformsConfig = RandomAugmentationTransformsConfig()
Expand All @@ -56,9 +118,17 @@ class TransformsConfig(BaseConfig):
normalization: NormalizationTransformConfig = NormalizationTransformConfig()
delete_acs_mask: bool = True
delete_kspace: bool = True
image_recon_type: str = "rss"
image_recon_type: ReconstructionType = ReconstructionType.RSS
pad_coils: Optional[int] = None
use_seed: bool = True
transforms_type: TransformsType = TransformsType.SUPERVISED
# Next attributes are for the mask splitter in case of transforms_type is set to SSL_SSDU
mask_split_ratio: Tuple[float, ...] = (0.4,)
mask_split_acs_region: Tuple[int, int] = (0, 0)
mask_split_keep_acs: Optional[bool] = False
mask_split_type: MaskSplitterType = MaskSplitterType.GAUSSIAN
mask_split_gaussian_std: float = 3.0
mask_split_half_direction: HalfSplitType = HalfSplitType.VERTICAL


@dataclass
Expand Down
Loading
Loading