-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
change to Dataset as in brain segmentation bundle, add support for amp, add validate in train_autoencoder.json
Signed-off-by: Can-Zhao <[email protected]>
- Loading branch information
Showing
3 changed files
with
186 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
models/brats_mri_generative_diffusion/scripts/prepare_datalist.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import argparse | ||
import glob | ||
import json | ||
import os | ||
|
||
import monai | ||
from sklearn.model_selection import train_test_split | ||
|
||
|
||
def produce_sample_dict(line: str):
    """
    Collect the segmentation label and the modality images of one case directory.

    Args:
        line: path of a directory whose entries are the files of a single case.

    Returns:
        A dict with the key "label" (path of the first "_seg.nii" entry) and the key
        "image" (list of paths ordered as t1ce, t1, t2, flair).

    Raises:
        IndexError: if the directory contains no "_seg.nii" entry.
    """
    # insertion order mirrors the priority used when a name contains several tags
    tagged = {"_seg.nii": [], "_t1ce.nii": [], "_t1.nii": [], "_t2.nii": [], "_flair.nii": []}
    # os.listdir returns entries in arbitrary order; sort them to get a reproducible result
    for name in sorted(os.listdir(line)):
        for tag, paths in tagged.items():
            # match on the file name only, a tag inside a parent directory name must not count
            if tag in name:
                paths.append(os.path.join(line, name))
                break

    if not tagged["_seg.nii"]:
        raise IndexError(f"no '_seg.nii' file found in {line}")

    images = tagged["_t1ce.nii"] + tagged["_t1.nii"] + tagged["_t2.nii"] + tagged["_flair.nii"]
    return {"label": tagged["_seg.nii"][0], "image": images}
|
||
|
||
def produce_datalist(dataset_dir: str, train_size: int = 200):
    """
    This function is used to split the dataset.
    It will produce "train_size" number of samples for training, and the other samples
    are divided equally into val and test sets.

    Args:
        dataset_dir: directory searched with the pattern "<dataset_dir>/*/*"; every matching
            directory is treated as one case.
        train_size: number of samples placed in the training split.

    Returns:
        A dict with the keys "training", "validation" and "testing", each mapping to a list of
        sample dicts as returned by `produce_sample_dict`.
    """

    # plain files matched by the pattern (e.g. stray metadata files) are skipped, they are not cases
    case_dirs = sorted(path for path in glob.glob(os.path.join(dataset_dir, "*", "*")) if os.path.isdir(path))
    datalist = [produce_sample_dict(case_dir) for case_dir in case_dirs]
    # no random_state is passed here, so a reproducible split relies on the caller seeding the global RNG
    train_list, other_list = train_test_split(datalist, train_size=train_size)
    val_list, test_list = train_test_split(other_list, train_size=0.5)

    return {"training": train_list, "validation": val_list, "testing": test_list}
|
||
|
||
def main(args):
    """
    Split the dataset and write the resulting data list to a json file.

    Args:
        args: parsed command line arguments; the attributes `path`, `train_size` and
            `output` are read.
    """
    dataset_root = os.path.abspath(args.path)
    training_dir = os.path.join(dataset_root, "training")

    # seed everything before splitting so that the produced splits are deterministic
    monai.utils.set_determinism(seed=123)
    splits = produce_datalist(dataset_dir=training_dir, train_size=args.train_size)

    with open(args.output, "w") as json_file:
        json.dump(splits, json_file)
|
||
|
||
if __name__ == "__main__":
    # command line interface: dataset root, output json file and number of training samples
    cli = argparse.ArgumentParser(description="")
    cli.add_argument(
        "--path",
        default="/workspace/data/medical/brats2018challenge",
        type=str,
        help="root path of brats 2018 dataset.",
    )
    cli.add_argument(
        "--output",
        default="configs/datalist.json",
        type=str,
        help="relative path of output datalist json file.",
    )
    cli.add_argument(
        "--train_size",
        default=200,
        type=int,
        help="number of training samples.",
    )

    main(cli.parse_args())