diff --git a/models/brats_mri_generative_diffusion/configs/inference.json b/models/brats_mri_generative_diffusion/configs/inference.json
index 22767e98..d4952061 100644
--- a/models/brats_mri_generative_diffusion/configs/inference.json
+++ b/models/brats_mri_generative_diffusion/configs/inference.json
@@ -12,12 +12,12 @@
     "output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')",
     "spatial_dims": 3,
     "image_channels": 1,
-    "latent_channels": 8,
+    "latent_channels": 4,
     "latent_shape": [
-        8,
-        36,
-        44,
-        28
+        "@latent_channels",
+        48,
+        48,
+        32
     ],
     "autoencoder_def": {
         "_target_": "generative.networks.nets.AutoencoderKL",
@@ -39,7 +39,9 @@
             false
         ],
         "with_encoder_nonlocal_attn": false,
-        "with_decoder_nonlocal_attn": false
+        "with_decoder_nonlocal_attn": false,
+        "use_checkpointing": true,
+        "use_convtranspose": false
     },
     "network_def": {
         "_target_": "generative.networks.nets.DiffusionModelUNet",
@@ -47,7 +49,7 @@
         "in_channels": "@latent_channels",
         "out_channels": "@latent_channels",
         "num_channels": [
-            256,
+            128,
             256,
             512
         ],
@@ -58,10 +60,11 @@
         ],
         "num_head_channels": [
             0,
-            64,
-            64
+            32,
+            32
         ],
-        "num_res_blocks": 2
+        "num_res_blocks": 2,
+        "use_flash_attention": true
     },
     "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
     "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
diff --git a/models/brats_mri_generative_diffusion/configs/inference_autoencoder.json b/models/brats_mri_generative_diffusion/configs/inference_autoencoder.json
index eb66dee2..0bbffffc 100644
--- a/models/brats_mri_generative_diffusion/configs/inference_autoencoder.json
+++ b/models/brats_mri_generative_diffusion/configs/inference_autoencoder.json
@@ -2,11 +2,14 @@
     "imports": [
         "$import torch",
         "$from datetime import datetime",
-        "$from pathlib import Path"
+        "$from pathlib import Path",
+        "$import generative"
     ],
     "bundle_root": ".",
     "model_dir": "$@bundle_root + '/models'",
-    "dataset_dir": "/workspace/data/medical",
+    "data_list_file_path": "$@bundle_root + '/configs/datalist.json'",
+    "dataset_dir": "/datasets/brats18",
+    "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='testing', base_dir=@dataset_dir)",
     "output_dir": "$@bundle_root + '/output'",
     "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)",
     "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
@@ -20,11 +23,11 @@
     ],
     "spatial_dims": 3,
     "image_channels": 1,
-    "latent_channels": 8,
+    "latent_channels": 4,
     "infer_patch_size": [
-        144,
-        176,
-        112
+        192,
+        192,
+        128
     ],
     "autoencoder_def": {
         "_target_": "generative.networks.nets.AutoencoderKL",
@@ -46,7 +49,9 @@
             false
         ],
         "with_encoder_nonlocal_attn": false,
-        "with_decoder_nonlocal_attn": false
+        "with_decoder_nonlocal_attn": false,
+        "use_checkpointing": true,
+        "use_convtranspose": false
     },
     "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
     "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
@@ -108,13 +113,8 @@
         "transforms": "$@preprocessing_transforms + @crop_transforms + @final_transforms"
     },
     "dataset": {
-        "_target_": "monai.apps.DecathlonDataset",
-        "root_dir": "@dataset_dir",
-        "task": "Task01_BrainTumour",
-        "section": "validation",
-        "cache_rate": 0.0,
-        "num_workers": 8,
-        "download": false,
+        "_target_": "Dataset",
+        "data": "@test_datalist",
         "transform": "@preprocessing"
     },
     "dataloader": {
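Note: the new `latent_shape` and the enlarged `infer_patch_size` are consistent with each other. A minimal sketch of the bookkeeping, assuming an overall 4x spatial downsampling in the AutoencoderKL (two stride-2 stages; the factor is an assumption inferred from the config values, not read from the network definition):

```python
# Sketch only: checks that the updated configs agree with each other, assuming
# the autoencoder reduces each spatial dimension by a factor of 4.
infer_patch_size = [192, 192, 128]   # from inference_autoencoder.json
latent_channels = 4                  # from inference.json
downsample_factor = 4                # assumption: two stride-2 downsamplings

latent_shape = [latent_channels] + [s // downsample_factor for s in infer_patch_size]
assert latent_shape == [4, 48, 48, 32]  # matches "latent_shape" in inference.json
```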
diff --git a/models/brats_mri_generative_diffusion/configs/train_autoencoder.json b/models/brats_mri_generative_diffusion/configs/train_autoencoder.json
index ddaa8c41..ab955068 100644
--- a/models/brats_mri_generative_diffusion/configs/train_autoencoder.json
+++ b/models/brats_mri_generative_diffusion/configs/train_autoencoder.json
@@ -2,22 +2,35 @@
     "imports": [
         "$import functools",
         "$import glob",
-        "$import scripts"
+        "$import scripts",
+        "$import generative"
     ],
     "bundle_root": ".",
     "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
     "ckpt_dir": "$@bundle_root + '/models'",
     "tf_dir": "$@bundle_root + '/eval'",
-    "dataset_dir": "/workspace/data/medical",
+    "data_list_file_path": "$@bundle_root + '/configs/datalist.json'",
+    "dataset_dir": "/datasets/brats18",
+    "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)",
+    "val_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)",
     "pretrained": false,
     "perceptual_loss_model_weights_path": null,
-    "train_batch_size": 2,
-    "lr": 1e-05,
+    "train_batch_size": 4,
+    "val_batch_size": 3,
+    "epochs": 3000,
+    "val_interval": 10,
+    "lr": 5e-05,
+    "amp": true,
     "train_patch_size": [
         112,
-        128,
+        112,
         80
     ],
+    "val_patch_size": [
+        192,
+        192,
+        128
+    ],
     "channel": 0,
     "spacing": [
         1.1,
@@ -26,7 +39,7 @@
     ],
     "spatial_dims": 3,
     "image_channels": 1,
-    "latent_channels": 8,
+    "latent_channels": 4,
     "discriminator_def": {
         "_target_": "generative.networks.nets.PatchDiscriminator",
         "spatial_dims": "@spatial_dims",
@@ -56,7 +69,9 @@
             false
         ],
         "with_encoder_nonlocal_attn": false,
-        "with_decoder_nonlocal_attn": false
+        "with_decoder_nonlocal_attn": false,
+        "use_checkpointing": true,
+        "use_convtranspose": false
     },
     "perceptual_loss_def": {
         "_target_": "generative.losses.PerceptualLoss",
@@ -114,9 +129,12 @@
             "keys": "image",
             "pixdim": "@spacing",
             "mode": "bilinear"
-        }
-    ],
-    "final_transforms": [
+        },
+        {
+            "_target_": "CenterSpatialCropd",
+            "keys": "image",
+            "roi_size": "@val_patch_size"
+        },
         {
             "_target_": "ScaleIntensityRangePercentilesd",
             "keys": "image",
@@ -137,17 +155,13 @@
         ],
         "preprocessing": {
             "_target_": "Compose",
-            "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
+            "transforms": "$@preprocessing_transforms + @train#crop_transforms"
         },
         "dataset": {
-            "_target_": "monai.apps.DecathlonDataset",
-            "root_dir": "@dataset_dir",
-            "task": "Task01_BrainTumour",
-            "section": "training",
-            "cache_rate": 1.0,
-            "num_workers": 8,
-            "download": false,
-            "transform": "@train#preprocessing"
+            "_target_": "CacheDataset",
+            "data": "@train_datalist",
+            "transform": "@train#preprocessing",
+            "cache_rate": 1.0
         },
         "dataloader": {
             "_target_": "DataLoader",
@@ -158,32 +172,33 @@
         },
         "handlers": [
             {
-                "_target_": "CheckpointSaver",
-                "save_dir": "@ckpt_dir",
-                "save_dict": {
-                    "model": "@gnetwork"
-                },
-                "save_interval": 0,
-                "save_final": true,
+                "_target_": "ValidationHandler",
+                "validator": "@validate#evaluator",
                 "epoch_level": true,
-                "final_filename": "model_autoencoder.pt"
+                "interval": "@val_interval"
             },
             {
                 "_target_": "StatsHandler",
                 "tag_name": "train_loss",
-                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
+                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]+monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
             },
             {
                 "_target_": "TensorBoardStatsHandler",
                 "log_dir": "@tf_dir",
-                "tag_name": "train_loss",
+                "tag_name": "train_generator_loss",
                 "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
+            },
+            {
+                "_target_": "TensorBoardStatsHandler",
+                "log_dir": "@tf_dir",
+                "tag_name": "train_discriminator_loss",
+                "output_transform": "$lambda x: monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
             }
         ],
         "trainer": {
             "_target_": "scripts.ldm_trainer.VaeGanTrainer",
             "device": "@device",
-            "max_epochs": 1500,
+            "max_epochs": "@epochs",
             "train_data_loader": "@train#dataloader",
             "g_network": "@gnetwork",
             "g_optimizer": "@goptimizer",
@@ -195,7 +210,76 @@
             "g_update_latents": true,
             "latent_shape": "@latent_channels",
             "key_train_metric": "$None",
-            "train_handlers": "@train#handlers"
+            "train_handlers": "@train#handlers",
+            "amp": "@amp"
+        }
+    },
+    "validate": {
+        "preprocessing": {
+            "_target_": "Compose",
+            "transforms": "$@preprocessing_transforms"
+        },
+        "dataset": {
+            "_target_": "CacheDataset",
+            "data": "@val_datalist",
+            "transform": "@validate#preprocessing",
+            "cache_rate": 1.0
+        },
+        "dataloader": {
+            "_target_": "DataLoader",
+            "dataset": "@validate#dataset",
+            "batch_size": "@val_batch_size",
+            "shuffle": false,
+            "num_workers": 4
+        },
+        "postprocessing": {
+            "_target_": "Compose",
+            "transforms": [
+                {
+                    "_target_": "Lambdad",
+                    "keys": "pred",
+                    "func": "$lambda x: x[0]"
+                }
+            ]
+        },
+        "handlers": [
+            {
+                "_target_": "StatsHandler",
+                "iteration_log": false
+            },
+            {
+                "_target_": "TensorBoardStatsHandler",
+                "log_dir": "@tf_dir",
+                "iteration_log": false
+            },
+            {
+                "_target_": "CheckpointSaver",
+                "save_dir": "@ckpt_dir",
+                "save_dict": {
+                    "model": "@gnetwork"
+                },
+                "save_interval": 0,
+                "save_final": true,
+                "epoch_level": true,
+                "final_filename": "model_autoencoder.pt"
+            }
+        ],
+        "key_metric": {
+            "val_mean_l2": {
+                "_target_": "MeanSquaredError",
+                "output_transform": "$monai.handlers.from_engine(['pred', 'image'])"
+            }
+        },
+        "evaluator": {
+            "_target_": "SupervisedEvaluator",
+            "device": "@device",
+            "val_data_loader": "@validate#dataloader",
+            "network": "@gnetwork",
+            "postprocessing": "@validate#postprocessing",
+            "key_val_metric": "$@validate#key_metric",
+            "metric_cmp_fn": "$lambda current_metric,prev_best: current_metric < prev_best",
+            "val_handlers": "@validate#handlers",
+            "amp": "@amp"
+        }
     },
     "initialize": [
"@noise_scheduler", "scale_factor": "@scale_factor" }, - "crop_transforms": [ - { - "_target_": "CenterSpatialCropd", - "keys": "image", - "roi_size": "@train_patch_size" - } - ], "preprocessing": { "_target_": "Compose", - "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms" + "transforms": "$@preprocessing_transforms" }, "dataset": { - "_target_": "monai.apps.DecathlonDataset", - "root_dir": "@dataset_dir", - "task": "Task01_BrainTumour", - "section": "training", - "cache_rate": 1.0, - "num_workers": 8, - "download": false, - "transform": "@train#preprocessing" + "_target_": "CacheDataset", + "data": "@train_datalist", + "transform": "@train#preprocessing", + "cache_rate": 1.0 }, "dataloader": { "_target_": "DataLoader", @@ -142,7 +126,8 @@ "latent_shape": "@latent_shape", "inferer": "@train#inferer", "key_train_metric": "$None", - "train_handlers": "@train#handlers" + "train_handlers": "@train#handlers", + "amp": "@amp" } }, "initialize": [ diff --git a/models/brats_mri_generative_diffusion/docs/README.md b/models/brats_mri_generative_diffusion/docs/README.md index 1c01d861..81500cbb 100644 --- a/models/brats_mri_generative_diffusion/docs/README.md +++ b/models/brats_mri_generative_diffusion/docs/README.md @@ -1,11 +1,11 @@ # Model Overview A pre-trained model for volumetric (3D) Brats MRI 3D Latent Diffusion Generative Model. -This model is trained on BraTS 2016 and 2017 data from [Medical Decathlon](http://medicaldecathlon.com/), using the Latent diffusion model [1]. +This model is trained based on BraTS 2018 data from [Multimodal Brain Tumor Segmentation Challenge (BraTS) 2018](https://www.med.upenn.edu/sbia/brats2018.html), using the Latent diffusion model [1]. ![model workflow](https://developer.download.nvidia.com/assets/Clara/Images/monai_brain_image_gen_ldm3d_network.png) -This model is a generator for creating images like the Flair MRIs based on BraTS 2016 and 2017 data. It was trained as a 3d latent diffusion model and accepts Gaussian random noise as inputs to produce an image output. The `train_autoencoder.json` file describes the training process of the variational autoencoder with GAN loss. The `train_diffusion.json` file describes the training process of the 3D latent diffusion model. +This model is a generator for creating images like the T1CE MRIs based on BraTS 2018 data. It was trained as a 3d latent diffusion model and accepts Gaussian random noise as inputs to produce an image output. The `train_autoencoder.json` file describes the training process of the variational autoencoder with GAN loss. The `train_diffusion.json` file describes the training process of the 3D latent diffusion model. In this bundle, the autoencoder uses perceptual loss, which is based on ResNet50 with pre-trained weights (the network is frozen and will not be trained in the bundle). In default, the `pretrained` parameter is specified as `False` in `train_autoencoder.json`. To ensure correct training, changing the default settings is necessary. There are two ways to utilize pretrained weights: 1. if set `pretrained` to `True`, ImageNet pretrained weights from [torchvision](https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#ResNet50_Weights) will be used. However, the weights are for non-commercial use only. @@ -20,60 +20,72 @@ An example result from inference is shown below: **This is a demonstration network meant to just show the training process for this sort of network with MONAI. 
To achieve better performance, users need to use larger dataset like [Brats 2021](https://www.synapse.org/#!Synapse:syn25829067/wiki/610865) and have GPU with memory larger than 32G to enable larger networks and attention layers.**

 ## Data
-The training data is BraTS 2016 and 2017 from the Medical Segmentation Decathalon. Users can find more details on the dataset (`Task01_BrainTumour`) at http://medicaldecathlon.com/.
+The training data is from the [Multimodal Brain Tumor Segmentation Challenge (BraTS) 2018](https://www.med.upenn.edu/sbia/brats2018.html).

 - Target: Image Generation
 - Task: Synthesis
 - Modality: MRI
-- Size: 388 3D volumes (1 channel used)
+- Size: 285 3D volumes (1 channel used)
+
+The provided labelled data was partitioned, based on our own split, into training (200 studies), validation (42 studies), and testing (43 studies) datasets.
+
+### Preprocessing
+The data list/split can be created with the script `scripts/prepare_datalist.py`.
+
+```
+python scripts/prepare_datalist.py --path your-brats18-dataset-path
+```

 ## Training Configuration
+First, install the required packages:
+```
+pip install git+https://github.com/Project-MONAI/GenerativeModels.git
+pip install lpips
+pip install xformers
+```
+
 If you have a GPU with less than 32G of memory, you may need to decrease the batch size when training. To do so, modify the `train_batch_size` parameter in the [configs/train_autoencoder.json](../configs/train_autoencoder.json) and [configs/train_diffusion.json](../configs/train_diffusion.json) configuration files.

 ### Training Configuration of Autoencoder
 The autoencoder was trained using the following configuration:

 - GPU: at least 32GB GPU memory
-- Actual Model Input: 112 x 128 x 80
+- Actual Model Input: 112 x 112 x 80
 - AMP: False
 - Optimizer: Adam
-- Learning Rate: 1e-5
+- Learning Rate: 5e-5
 - Loss: L1 loss, perceptual loss, KL divergence loss, adversarial loss, GAN BCE loss

 #### Input
-1 channel 3D MRI Flair patches
+1 channel 3D MRI T1CE patches

 #### Output
 - 1 channel 3D MRI reconstructed patches
-- 8 channel mean of latent features
-- 8 channel standard deviation of latent features
+- 4 channel mean of latent features
+- 4 channel standard deviation of latent features

 ### Training Configuration of Diffusion Model
 The latent diffusion model was trained using the following configuration:

 - GPU: at least 32GB GPU memory
-- Actual Model Input: 36 x 44 x 28
+- Actual Model Input: 48 x 48 x 32
 - AMP: False
 - Optimizer: Adam
-- Learning Rate: 1e-5
+- Learning Rate: 5e-5
 - Loss: MSE loss

 #### Training Input
-- 8 channel noisy latent features
+- 4 channel noisy latent features
 - a long int that indicates the time step

 #### Training Output
-8 channel predicted added noise
+4 channel predicted added noise

 #### Inference Input
-8 channel noise
+4 channel noise

 #### Inference Output
-8 channel denoised latent features
-
-### Memory Consumption Warning
-
-If you face memory issues with data loading, you can lower the caching rate `cache_rate` in the configurations within range [0, 1] to minimize the System RAM requirements.
+4 channel denoised latent features

 ## Performance
@@ -96,7 +108,7 @@ For more details usage instructions, visit the [MONAI Bundle Configuration Page]
 python -m monai.bundle run --config_file configs/train_autoencoder.json
 ```

-Please note that if the default dataset path is not modified with the actual path (it should be the path that contains `Task01_BrainTumour`) in the bundle config files, you can also override it by using `--dataset_dir`:
+Please note that if the default dataset path in the bundle config files has not been replaced with the actual path, you can override it by using `--dataset_dir`:

 ```
 python -m monai.bundle run --config_file configs/train_autoencoder.json --dataset_dir <actual dataset path>
@@ -106,7 +118,7 @@

 To train with multiple GPUs, use the following command, which requires scaling up the learning rate according to the number of GPUs.

 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/multi_gpu_train_autoencoder.json']" --lr 8e-5
+torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/multi_gpu_train_autoencoder.json']" --lr 2e-4
 ```

 #### Check the Autoencoder Training result
@@ -134,7 +146,7 @@
 python -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json']"
 ```

 To train with multiple GPUs, use the following command, which requires scaling up the learning rate according to the number of GPUs.

 ```
-torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json','configs/multi_gpu_train_autoencoder.json','configs/multi_gpu_train_diffusion.json']" --lr 8e-5
+torchrun --standalone --nnodes=1 --nproc_per_node=8 -m monai.bundle run --config_file "['configs/train_autoencoder.json','configs/train_diffusion.json','configs/multi_gpu_train_autoencoder.json','configs/multi_gpu_train_diffusion.json']" --lr 2e-4
 ```

 #### Execute inference
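To tie the README's new preprocessing step back to the configs: the training configs consume the generated JSON through `monai.data.load_decathlon_datalist`. A small sketch of what the `train_datalist` expression resolves to (the paths are the config defaults and are assumed to exist locally):

```python
from monai.data import load_decathlon_datalist

# Same call the configs make via "$monai.data.load_decathlon_datalist(...)";
# expects datalist.json from scripts/prepare_datalist.py and the BraTS18 root.
train_files = load_decathlon_datalist(
    "configs/datalist.json",
    data_list_key="training",
    base_dir="/datasets/brats18",
)
print(len(train_files))           # 200 studies with the default split
print(train_files[0]["image"])    # per-study list of t1ce/t1/t2/flair paths
```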
diff --git a/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py b/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py
index c1a21bfa..a7bbd6c8 100644
--- a/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py
+++ b/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py
@@ -81,6 +81,7 @@ class VaeGanTrainer(Trainer):
             `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
             CheckpointHandler, StatsHandler, etc.
+        amp: whether to enable auto-mixed-precision training, default is False.
         decollate: whether to decollate the batch-first data to a list of data after model computation,
             recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
             default to `True`.
@@ -118,6 +119,7 @@ def __init__(
         additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
         train_handlers: Sequence | None = None,
+        amp: bool = False,
         decollate: bool = True,
         optim_set_to_none: bool = False,
         to_kwargs: dict | None = None,
@@ -139,6 +141,7 @@ def __init__(
             additional_metrics=additional_metrics,
             metric_cmp_fn=metric_cmp_fn,
             handlers=train_handlers,
+            amp=amp,
             postprocessing=postprocessing,
             decollate=decollate,
             to_kwargs=to_kwargs,
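The new `amp` argument is simply forwarded to the parent MONAI workflow, which runs the forward pass under PyTorch autocast when enabled. A self-contained sketch of that general pattern (the toy layer is illustrative, not the trainer's actual internals):

```python
import torch

# Generic AMP pattern comparable to what amp=True enables inside MONAI engines;
# falls back to plain float32 when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(8, 8).to(device)
x = torch.randn(2, 8, device=device)

with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    y = net(x)
print(y.dtype)  # torch.float16 under CUDA autocast, torch.float32 on CPU
```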
diff --git a/models/brats_mri_generative_diffusion/scripts/losses.py b/models/brats_mri_generative_diffusion/scripts/losses.py
index 43536067..ec95ce77 100644
--- a/models/brats_mri_generative_diffusion/scripts/losses.py
+++ b/models/brats_mri_generative_diffusion/scripts/losses.py
@@ -15,7 +15,7 @@
 adv_loss = PatchAdversarialLoss(criterion="least_squares")
 adv_weight = 0.1
-perceptual_weight = 0.1
+perceptual_weight = 0.3
 # kl_weight: important hyper-parameter.
 # If too large, decoder cannot recon good results from latent space.
 # If too small, latent space will not be regularized enough for the diffusion model
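Raising `perceptual_weight` from 0.1 to 0.3 shifts the generator objective toward perceptual fidelity. A hedged sketch of how such weights are typically combined in a VAE-GAN objective of this kind (the exact wiring lives in the bundle's training scripts, and `kl_weight` below is a placeholder, not the bundle's value):

```python
adv_weight = 0.1         # from scripts/losses.py
perceptual_weight = 0.3  # raised from 0.1 in this change
kl_weight = 1e-7         # placeholder only; see scripts/losses.py for the real value

def generator_loss(l1: float, perceptual: float, kl: float, adversarial: float) -> float:
    # Weighted sum typical of VAE-GAN training: reconstruction + perceptual
    # similarity + KL regularization + adversarial term.
    return l1 + perceptual_weight * perceptual + kl_weight * kl + adv_weight * adversarial

print(generator_loss(0.05, 0.10, 1500.0, 0.8))
```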
+ ) + parser.add_argument("--train_size", type=int, default=200, help="number of training samples.") + args = parser.parse_args() + + main(args)