From 83ab7fba3c0e5917a95d72d980971a2719210b1d Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Thu, 18 Apr 2024 04:38:12 +0200 Subject: [PATCH] Code quality fixes --- direct/data/mri_transforms.py | 5 ++--- direct/nn/ssl/mri_models.py | 2 +- direct/ssl/ssl.py | 17 +++++++++++------ 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/direct/data/mri_transforms.py b/direct/data/mri_transforms.py index 006dc420..9ebffc54 100644 --- a/direct/data/mri_transforms.py +++ b/direct/data/mri_transforms.py @@ -2055,12 +2055,11 @@ def build_mri_transforms( use_seed=use_seed, ).transforms - mri_transforms += [AddBooleanKeysModule(["is_ssl"], [not transforms_type == TranformsType.SUPERVISED])] + mri_transforms += [AddBooleanKeysModule(["is_ssl"], [transforms_type != TranformsType.SUPERVISED])] if transforms_type == TranformsType.SUPERVISED: return Compose(mri_transforms) - - if transforms_type == TranformsType.SSL_SSDU: + elif transforms_type == TranformsType.SSL_SSDU: mask_splitter_kwargs = { "ratio": mask_split_ratio, "acs_region": mask_split_acs_region, diff --git a/direct/nn/ssl/mri_models.py b/direct/nn/ssl/mri_models.py index fbfebd13..4b2381c5 100644 --- a/direct/nn/ssl/mri_models.py +++ b/direct/nn/ssl/mri_models.py @@ -105,7 +105,7 @@ def log_first_training_example_and_model(self, data: dict[str, Any]) -> None: """ storage = get_event_storage() - self.logger.info(f"First case: slice_no: {data['slice_no'][0]}, filename: {data['filename'][0]}.") + self.logger.info("First case: slice_no: %s, filename: %s.", data["slice_no"][0], data["filename"][0]) if "input_sampling_mask" in data: first_input_sampling_mask = data["input_sampling_mask"][0][0] diff --git a/direct/ssl/ssl.py b/direct/ssl/ssl.py index 7302faa8..993f1796 100644 --- a/direct/ssl/ssl.py +++ b/direct/ssl/ssl.py @@ -202,9 +202,11 @@ def _gaussian_split( center_x = nrow // 2 center_y = ncol // 2 - if self.keep_acs and acs_mask is None: - raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.") - mask = mask.clone() if not self.keep_acs else mask.clone() & (~acs_mask) + if self.keep_acs: + if acs_mask is None: + raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.") + else: + mask = mask & (~acs_mask) with temp_seed(self.rng, seed): if seed is None: @@ -273,9 +275,12 @@ def _uniform_split( center_x = nrow // 2 center_y = ncol // 2 - if self.keep_acs and acs_mask is None: - raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.") - mask = mask.clone() if not self.keep_acs else mask.clone() & (~acs_mask) + if self.keep_acs: + if acs_mask is None: + raise ValueError("`keep_acs` is set to True but not received an input for `acs_mask`.") + else: + mask = mask & (~acs_mask) + temp_mask = mask.cpu().clone() if not self.keep_acs: