Skip to content

Commit

Permalink
Fix possible nested lists
Browse files Browse the repository at this point in the history
  • Loading branch information
EmmaRenauld committed Feb 14, 2024
1 parent 37cdbb8 commit e02b44f
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions dwi_ml/data/hdf5/hdf5_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def _verify_subjects_list(self):
"testing set!".format(ignored_subj))
return unique_subjs


def _check_files_presence(self):
"""
Verifying now the list of files. Prevents stopping after a long
Expand All @@ -320,20 +321,29 @@ def _check_files_presence(self):
"""
logging.debug("Verifying files presence")

def flatten_list(a_list):
new_list = []
for element in a_list:
if isinstance(element, list):
new_list.extend(flatten_list(element))
else:
new_list.append(element)
return new_list

# concatenating files from all groups files:
# sum: concatenates list of sub-lists
config_file_list = sum(nested_lookup('files', self.groups_config), [])
config_file_list += nested_lookup(
'connectivity_matrix', self.groups_config)
config_file_list += nested_lookup('std_mask', self.groups_config)
config_file_list = [
nested_lookup('files', self.groups_config),
nested_lookup('connectivity_matrix', self.groups_config),
nested_lookup('std_mask', self.groups_config)]
config_file_list = flatten_list(config_file_list)

for subj_id in self.all_subjs:
subj_input_dir = Path(self.root_folder).joinpath(subj_id)

# Find subject's files from group_config
config_file_list = format_filelist(config_file_list,
self.enforce_files_presence,
folder=subj_input_dir)
_ = format_filelist(config_file_list,
self.enforce_files_presence,
folder=subj_input_dir)

def create_database(self):
"""
Expand Down Expand Up @@ -465,7 +475,9 @@ def _process_one_volume_group(self, group: str, subj_id: str,
if isinstance(std_masks, str):
std_masks = [std_masks]

std_masks = format_filelist(std_masks, folder=subj_input_dir)
std_masks = format_filelist(std_masks,
self.enforce_files_presence,
folder=subj_input_dir)
for mask in std_masks:
logging.info(" - Loading standardization mask {}"
.format(os.path.basename(mask)))
Expand Down

0 comments on commit e02b44f

Please sign in to comment.