diff --git a/monai/create_msd_data.py b/monai/create_msd_data.py index 2c2fbf3..a723be7 100644 --- a/monai/create_msd_data.py +++ b/monai/create_msd_data.py @@ -233,6 +233,12 @@ def main(): # sort the subjects train_subjects, val_subjects, test_subjects = sorted(train_subjects), sorted(val_subjects), sorted(test_subjects) + # add a column specifying whether the subject is in train, val or test split + df['split'] = 'none' + df.loc[df['subjectID'].isin(train_subjects), 'split'] = 'train' + df.loc[df['subjectID'].isin(val_subjects), 'split'] = 'validation' + df.loc[df['subjectID'].isin(test_subjects), 'split'] = 'test' + # get boilerplate json params = get_boilerplate_json(dataset_name, dataset_commits)