diff --git a/direct/data/datasets_config.py b/direct/data/datasets_config.py index 460dc32b9..89a413d17 100644 --- a/direct/data/datasets_config.py +++ b/direct/data/datasets_config.py @@ -38,7 +38,6 @@ class RandomAugmentationTransformsConfig(BaseConfig): random_flip_probability: float = 0.0 random_reverse_probability: float = 0.0 - @dataclass class NormalizationTransformConfig(BaseConfig): scaling_key: Optional[str] = "masked_kspace" diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 8044e640d..9d2e79513 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -640,14 +640,12 @@ def __call__(self, sample: Dict[str, Any], coil_dim: int = 0) -> Dict[str, Any]: class ReconstructionType(str, Enum): """Reconstruction method for :class:`ComputeImage` transform.""" - - rss = "rss" - complex = "complex" - complex_mod = "complex_mod" - sense = "sense" - sense_mod = "sense_mod" - ifft = "ifft" - + RSS = "rss" + COMPLEX = "complex" + COMPLEX_MOD = "complex_mod" + SENSE = "sense" + SENSE_MOD = "sense_mod" + IFFT = "ifft" class ComputeImageModule(DirectModule): """Compute Image transform.""" @@ -657,7 +655,7 @@ def __init__( kspace_key: KspaceKey, target_key: str, backward_operator: Callable, - type_reconstruction: ReconstructionType = ReconstructionType.rss, + type_reconstruction: ReconstructionType = ReconstructionType.RSS, ) -> None: """Inits :class:`ComputeImageModule`. @@ -670,8 +668,9 @@ def __init__( backward_operator: callable The backward operator, e.g. some form of inverse FFT (centered or uncentered). type_reconstruction: ReconstructionType - Type of reconstruction. Can be "complex", "complex_mod", "sense", "sense_mod", "rss" or "ifft". - Default: ReconstructionType.rss. + Type of reconstruction. Can be ReconstructionType.COMPLEX, ReconstructionType.COMPLEX_MOD, + ReconstructionType.SENSE,, ReconstructionType.SENSE_MOD, ReconstructionType.IFFT. + Default: ReconstructionType.RSS. """ super().__init__() self.backward_operator = backward_operator @@ -697,14 +696,14 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: dim = self.spatial_dims["2D"] if kspace_data.ndim == 5 else self.spatial_dims["3D"] # Get complex-valued data solution image = self.backward_operator(kspace_data, dim=dim) - if self.type_reconstruction == ReconstructionType.ifft: + if self.type_reconstruction == ReconstructionType.IFFT: sample[self.target_key] = image elif self.type_reconstruction in [ - ReconstructionType.complex, - ReconstructionType.complex_mod, + ReconstructionType.COMPLEX, + ReconstructionType.COMPLEX_MOD, ]: sample[self.target_key] = image.sum(self.coil_dim) - elif self.type_reconstruction == ReconstructionType.rss: + elif self.type_reconstruction == ReconstructionType.RSS: sample[self.target_key] = T.root_sum_of_squares(image, dim=self.coil_dim) else: if "sensitivity_map" not in sample: @@ -716,8 +715,8 @@ def forward(self, sample: Dict[str, Any]) -> Dict[str, Any]: self.coil_dim ) if self.type_reconstruction in [ - ReconstructionType.complex_mod, - ReconstructionType.sense_mod, + ReconstructionType.COMPLEX_MOD, + ReconstructionType.SENSE_MOD, ]: sample[self.target_key] = T.modulus(sample[self.target_key], self.complex_dim) return sample @@ -1546,7 +1545,7 @@ def build_post_mri_transforms( sensitivity_maps_espirit_max_iters: Optional[int] = 30, delete_acs_mask: bool = True, delete_kspace: bool = True, - image_recon_type: ReconstructionType = ReconstructionType.rss, + image_recon_type: ReconstructionType = ReconstructionType.RSS, scaling_key: TransformKey = TransformKey.MASKED_KSPACE, scale_percentile: Optional[float] = 0.99, ) -> object: @@ -1584,7 +1583,7 @@ def build_post_mri_transforms( delete_kspace : bool If True will delete key `kspace` (fully sampled k-space). Default: True. image_recon_type : ReconstructionType - Type to reconstruct target image. Default: ReconstructionType.rss. + Type to reconstruct target image. Default: ReconstructionType.RSS. scaling_key : TransformKey Key in sample to scale scalable items in sample. Default: TransformKey.MASKED_KSPACE. scale_percentile : float, optional @@ -1666,7 +1665,7 @@ def build_mri_transforms( sensitivity_maps_espirit_max_iters: Optional[int] = 30, delete_acs_mask: bool = True, delete_kspace: bool = True, - image_recon_type: ReconstructionType = ReconstructionType.rss, + image_recon_type: ReconstructionType = ReconstructionType.RSS, pad_coils: Optional[int] = None, scaling_key: TransformKey = TransformKey.MASKED_KSPACE, scale_percentile: Optional[float] = 0.99, @@ -1737,7 +1736,7 @@ def build_mri_transforms( delete_kspace : bool If True will delete key `kspace` (fully sampled k-space). Default: True. image_recon_type : ReconstructionType - Type to reconstruct target image. Default: ReconstructionType.rss. + Type to reconstruct target image. Default: ReconstructionType.RSS. pad_coils : int Number of coils to pad data to. scaling_key : TransformKey diff --git a/direct/nn/get_nn_model_config.py b/direct/nn/get_nn_model_config.py index 9fc05ed14..c314e864c 100644 --- a/direct/nn/get_nn_model_config.py +++ b/direct/nn/get_nn_model_config.py @@ -26,7 +26,7 @@ def _get_relu_activation(activation: ActivationType = ActivationType.RELU, **kwa """ if activation == ActivationType.PRELU: return nn.PReLU(**kwargs) - if activation == ActivationType.LEAKYRELU: + if activation == ActivationType.LEAKY_RELU: return nn.LeakyReLU(**kwargs) return nn.ReLU(**kwargs) diff --git a/direct/nn/types.py b/direct/nn/types.py index 0e09b7b77..a450180a4 100644 --- a/direct/nn/types.py +++ b/direct/nn/types.py @@ -8,7 +8,7 @@ class ActivationType(DirectEnum): RELU = "relu" PRELU = "prelu" - LEAKYRELU = "leaky_relu" + LEAKY_RELU = "leaky_relu" class ModelName(DirectEnum): @@ -20,9 +20,9 @@ class ModelName(DirectEnum): class InitType(DirectEnum): - INPUTIMAGE = "input_image" + INPUT_IMAGE = "input_image" SENSE = "sense" - ZEROFILLED = "zero_filled" + ZERO_FILLED = "zero_filled" ZEROS = "zeros" diff --git a/direct/nn/varsplitnet/config.py b/direct/nn/varsplitnet/config.py index 73b707afe..2c061dbb2 100644 --- a/direct/nn/varsplitnet/config.py +++ b/direct/nn/varsplitnet/config.py @@ -16,31 +16,31 @@ class MRIVarSplitNetConfig(ModelConfig): kspace_no_parameter_sharing: bool = True image_model_architecture: str = ModelName.UNET kspace_model_architecture: Optional[str] = None - image_resnet_hidden_channels: int = 128 - image_resnet_num_blocks: int = 15 - image_resnet_batchnorm: bool = True - image_resnet_scale: float = 0.1 - image_unet_num_filters: int = 32 - image_unet_num_pool_layers: int = 4 - image_unet_dropout: float = 0.0 - image_didn_hidden_channels: int = 16 - image_didn_num_dubs: int = 6 - image_didn_num_convs_recon: int = 9 - kspace_resnet_hidden_channels: int = 64 - kspace_resnet_num_blocks: int = 1 - kspace_resnet_batchnorm: bool = True - kspace_resnet_scale: float = 0.1 - kspace_unet_num_filters: int = 16 - kspace_unet_num_pool_layers: int = 4 - kspace_unet_dropout: float = 0.0 - kspace_didn_hidden_channels: int = 8 - kspace_didn_num_dubs: int = 6 - kspace_didn_num_convs_recon: int = 9 - image_conv_hidden_channels: int = 64 - image_conv_n_convs: int = 15 - image_conv_activation: str = ActivationType.RELU - image_conv_batchnorm: bool = False - kspace_conv_hidden_channels: int = 64 - kspace_conv_n_convs: int = 15 - kspace_conv_activation: str = ActivationType.PRELU - kspace_conv_batchnorm: bool = False + image_resnet_hidden_channels: Optional[int] = 128 + image_resnet_num_blocks: Optional[int] = 15 + image_resnet_batchnorm: Optional[bool] = True + image_resnet_scale: Optional[float] = 0.1 + image_unet_num_filters: Optional[int] = 32 + image_unet_num_pool_layers: Optional[int] = 4 + image_unet_dropout: Optional[float] = 0.0 + image_didn_hidden_channels: Optional[int] = 16 + image_didn_num_dubs: Optional[int] = 6 + image_didn_num_convs_recon: Optional[int] = 9 + kspace_resnet_hidden_channels: Optional[int] = 64 + kspace_resnet_num_blocks: Optional[int] = 1 + kspace_resnet_batchnorm: Optional[bool] = True + kspace_resnet_scale: Optional[float] = 0.1 + kspace_unet_num_filters: Optional[int] = 16 + kspace_unet_num_pool_layers: Optional[int] = 4 + kspace_unet_dropout: Optional[float] = 0.0 + kspace_didn_hidden_channels: Optional[int] = 8 + kspace_didn_num_dubs: Optional[int] = 6 + kspace_didn_num_convs_recon: Optional[int] = 9 + image_conv_hidden_channels: Optional[int] = 64 + image_conv_n_convs: Optional[int] = 15 + image_conv_activation: Optional[str] = ActivationType.RELU + image_conv_batchnorm: Optional[bool] = False + kspace_conv_hidden_channels: Optional[int] = 64 + kspace_conv_n_convs: Optional[int] = 15 + kspace_conv_activation: Optional[str] = ActivationType.PRELU + kspace_conv_batchnorm: Optional[bool] = False \ No newline at end of file diff --git a/tests/tests_data/test_mri_transforms.py b/tests/tests_data/test_mri_transforms.py index 35b0c70bf..069c6375f 100644 --- a/tests/tests_data/test_mri_transforms.py +++ b/tests/tests_data/test_mri_transforms.py @@ -380,11 +380,11 @@ def test_random_rotation(shape, degree): @pytest.mark.parametrize( "type_recon, complex_output", [ - [ReconstructionType.complex, True], - [ReconstructionType.complex_mod, False], - [ReconstructionType.sense, True], - [ReconstructionType.sense_mod, False], - [ReconstructionType.rss, False], + [ReconstructionType.COMPLEX, True], + [ReconstructionType.COMPLEX_MOD, False], + [ReconstructionType.SENSE, True], + [ReconstructionType.SENSE_MOD, False], + [ReconstructionType.RSS, False], ], ) def test_ComputeImage(shape, type_recon, complex_output): @@ -518,7 +518,7 @@ def test_EstimateSensitivityMap3D( else: transform = EstimateSensitivityMap(**args) if shape[0] == 1 or sense_map_in_sample: - with pytest.warns(None): + with warnings.catch_warnings(record=True): sample = transform(sample) else: sample = transform(sample) diff --git a/tests/tests_nn/test_conjgradnet.py b/tests/tests_nn/test_conjgradnet.py index 77618f86e..4a703fe42 100644 --- a/tests/tests_nn/test_conjgradnet.py +++ b/tests/tests_nn/test_conjgradnet.py @@ -34,7 +34,7 @@ def create_input(shape): @pytest.mark.parametrize( "cg_param_update_type", [CGUpdateType.FR, CGUpdateType.PRP, CGUpdateType.DY, CGUpdateType.BAN] ) -@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZEROFILLED, InitType.ZEROS, "invalid"]) +@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZERO_FILLED, InitType.ZEROS, "invalid"]) @pytest.mark.parametrize("no_parameter_sharing", [True, False]) @pytest.mark.parametrize("cg_iters", [5, 20]) @pytest.mark.parametrize("cg_tol", [1e-2, 1e-8]) diff --git a/tests/tests_nn/test_recurrentvarnet.py b/tests/tests_nn/test_recurrentvarnet.py index 4b4574306..eaa6e10df 100644 --- a/tests/tests_nn/test_recurrentvarnet.py +++ b/tests/tests_nn/test_recurrentvarnet.py @@ -43,7 +43,7 @@ def create_input(shape): "learned_initializer, initializer_initialization, initializer_channels, initializer_dilations", [ [True, InitType.SENSE, (4, 4, 8, 8), (1, 1, 1, 2)], - [True, InitType.ZEROFILLED, (2, 4, 2, 4), (1, 2, 1, 3)], + [True, InitType.ZERO_FILLED, (2, 4, 2, 4), (1, 2, 1, 3)], [False, None, None, None], ], ) diff --git a/tests/tests_nn/test_varsplitnet.py b/tests/tests_nn/test_varsplitnet.py index b6c170fa2..c6646d951 100644 --- a/tests/tests_nn/test_varsplitnet.py +++ b/tests/tests_nn/test_varsplitnet.py @@ -20,7 +20,7 @@ def create_input(shape): @pytest.mark.parametrize("shape", [[4, 3, 32, 32], [4, 5, 40, 20]]) @pytest.mark.parametrize("num_steps_reg", [2, 3]) @pytest.mark.parametrize("num_steps_dc", [1, 4]) -@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZEROFILLED]) +@pytest.mark.parametrize("image_init", [InitType.SENSE, InitType.ZERO_FILLED]) @pytest.mark.parametrize("no_parameter_sharing", [True, False]) @pytest.mark.parametrize( "image_model_architecture, image_model_kwargs", @@ -32,7 +32,7 @@ def create_input(shape): { "image_conv_hidden_channels": 8, "image_conv_n_convs": 3, - "image_conv_activation": ActivationType.LEAKYRELU, + "image_conv_activation": ActivationType.LEAKY_RELU, }, ], ],