From aea7312905dedd0b88ba8229b1a02a9bd37793f5 Mon Sep 17 00:00:00 2001 From: Can-Zhao Date: Wed, 13 Dec 2023 13:16:18 -0800 Subject: [PATCH] change to Dataset as in brain segmentation bundle, add support for amp, add validate in train_autoencoder.json Signed-off-by: Can-Zhao --- .../configs/train_autoencoder.json | 135 ++++++++++++++---- .../scripts/ldm_trainer.py | 8 +- .../scripts/prepare_datalist.py | 72 ++++++++++ 3 files changed, 186 insertions(+), 29 deletions(-) create mode 100644 models/brats_mri_generative_diffusion/scripts/prepare_datalist.py diff --git a/models/brats_mri_generative_diffusion/configs/train_autoencoder.json b/models/brats_mri_generative_diffusion/configs/train_autoencoder.json index ddaa8c41..e7be6686 100644 --- a/models/brats_mri_generative_diffusion/configs/train_autoencoder.json +++ b/models/brats_mri_generative_diffusion/configs/train_autoencoder.json @@ -2,21 +2,29 @@ "imports": [ "$import functools", "$import glob", - "$import scripts" + "$import scripts", + "$import generative" ], "bundle_root": ".", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "ckpt_dir": "$@bundle_root + '/models'", "tf_dir": "$@bundle_root + '/eval'", - "dataset_dir": "/workspace/data/medical", + "data_list_file_path": "$@bundle_root + '/configs/datalist.json'", + "dataset_dir": "/datasets/brats18", + "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)", + "val_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)", "pretrained": false, "perceptual_loss_model_weights_path": null, "train_batch_size": 2, - "lr": 1e-05, + "val_batch_size": 2, + "epochs": 4000, + "val_interval": 10, + "lr": 1e-04, + "amp": true, "train_patch_size": [ - 112, 128, - 80 + 128, + 128 ], "channel": 0, "spacing": [ @@ -26,7 +34,7 @@ ], "spatial_dims": 3, "image_channels": 1, - "latent_channels": 8, + "latent_channels": 4, "discriminator_def": { "_target_": "generative.networks.nets.PatchDiscriminator", "spatial_dims": "@spatial_dims", @@ -56,7 +64,9 @@ false ], "with_encoder_nonlocal_attn": false, - "with_decoder_nonlocal_attn": false + "with_decoder_nonlocal_attn": false, + "use_checkpointing": true, + "use_convtranspose": false }, "perceptual_loss_def": { "_target_": "generative.losses.PerceptualLoss", @@ -140,13 +150,8 @@ "transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms" }, "dataset": { - "_target_": "monai.apps.DecathlonDataset", - "root_dir": "@dataset_dir", - "task": "Task01_BrainTumour", - "section": "training", - "cache_rate": 1.0, - "num_workers": 8, - "download": false, + "_target_": "Dataset", + "data": "@train_datalist", "transform": "@train#preprocessing" }, "dataloader": { @@ -158,32 +163,33 @@ }, "handlers": [ { - "_target_": "CheckpointSaver", - "save_dir": "@ckpt_dir", - "save_dict": { - "model": "@gnetwork" - }, - "save_interval": 0, - "save_final": true, + "_target_": "ValidationHandler", + "validator": "@validate#evaluator", "epoch_level": true, - "final_filename": "model_autoencoder.pt" + "interval": "@val_interval" }, { "_target_": "StatsHandler", "tag_name": "train_loss", - "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]" + "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]+monai.handlers.from_engine(['d_loss'], first=True)(x)[0]" }, { "_target_": "TensorBoardStatsHandler", "log_dir": "@tf_dir", - "tag_name": "train_loss", + "tag_name": "train_generator_loss", "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]" + }, + { + "_target_": "TensorBoardStatsHandler", + "log_dir": "@tf_dir", + "tag_name": "train_discriminator_loss", + "output_transform": "$lambda x: monai.handlers.from_engine(['d_loss'], first=True)(x)[0]" } ], "trainer": { "_target_": "scripts.ldm_trainer.VaeGanTrainer", "device": "@device", - "max_epochs": 1500, + "max_epochs": "@epochs", "train_data_loader": "@train#dataloader", "g_network": "@gnetwork", "g_optimizer": "@goptimizer", @@ -195,7 +201,82 @@ "g_update_latents": true, "latent_shape": "@latent_channels", "key_train_metric": "$None", - "train_handlers": "@train#handlers" + "train_handlers": "@train#handlers", + "amp": "@amp" + } + }, + "validate": { + "crop_transforms": [ + { + "_target_": "DivisiblePadd", + "keys": "image", + "k": 16 + } + ], + "preprocessing": { + "_target_": "Compose", + "transforms": "$@preprocessing_transforms + @validate#crop_transforms + @final_transforms" + }, + "dataset": { + "_target_": "Dataset", + "data": "@val_datalist", + "transform": "@validate#preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@validate#dataset", + "batch_size": "@val_batch_size", + "shuffle": false, + "num_workers": 4 + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Lambdad", + "keys": "pred", + "func": "$lambda x: x[0]" + } + ] + }, + "handlers": [ + { + "_target_": "StatsHandler", + "iteration_log": false + }, + { + "_target_": "TensorBoardStatsHandler", + "log_dir": "@tf_dir", + "iteration_log": false + }, + { + "_target_": "CheckpointSaver", + "save_dir": "@ckpt_dir", + "save_dict": { + "model": "@gnetwork" + }, + "save_interval": 0, + "save_final": true, + "epoch_level": true, + "final_filename": "model_autoencoder.pt" + } + ], + "key_metric": { + "val_mean_l2": { + "_target_": "MeanSquaredError", + "output_transform": "$monai.handlers.from_engine(['pred', 'image'])" + } + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@validate#dataloader", + "network": "@gnetwork", + "postprocessing": "@validate#postprocessing", + "key_val_metric": "$@validate#key_metric", + "metric_cmp_fn": "$lambda current_metric,prev_best: current_metric < prev_best", + "val_handlers": "@validate#handlers", + "amp": "@amp" } }, "initialize": [ @@ -204,4 +285,4 @@ "run": [ "$@train#trainer.run()" ] -} +} \ No newline at end of file diff --git a/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py b/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py index c1a21bfa..04952923 100644 --- a/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py +++ b/models/brats_mri_generative_diffusion/scripts/ldm_trainer.py @@ -11,14 +11,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence import torch from monai.config import IgniteInfo from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.inferers import Inferer, SimpleInferer from monai.transforms import Transform -from monai.utils import min_version, optional_import +from monai.utils import GanKeys, min_version, optional_import from monai.utils.enums import CommonKeys, GanKeys from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -81,6 +81,7 @@ class VaeGanTrainer(Trainer): `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. decollate: whether to decollate the batch-first data to a list of data after model computation, recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`. @@ -118,6 +119,7 @@ def __init__( additional_metrics: dict[str, Metric] | None = None, metric_cmp_fn: Callable = default_metric_cmp_fn, train_handlers: Sequence | None = None, + amp: bool = False, decollate: bool = True, optim_set_to_none: bool = False, to_kwargs: dict | None = None, @@ -139,6 +141,7 @@ def __init__( additional_metrics=additional_metrics, metric_cmp_fn=metric_cmp_fn, handlers=train_handlers, + amp=amp, postprocessing=postprocessing, decollate=decollate, to_kwargs=to_kwargs, @@ -173,6 +176,7 @@ def _iteration( raise ValueError("must provide batch data for current iteration.") d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)[0] + batch_size = engine.data_loader.batch_size # type: ignore g_input = d_input g_output, z_mu, z_sigma = engine.g_inferer(g_input, engine.g_network) diff --git a/models/brats_mri_generative_diffusion/scripts/prepare_datalist.py b/models/brats_mri_generative_diffusion/scripts/prepare_datalist.py new file mode 100644 index 00000000..e48edbb9 --- /dev/null +++ b/models/brats_mri_generative_diffusion/scripts/prepare_datalist.py @@ -0,0 +1,72 @@ +import argparse +import glob +import json +import os + +import monai +from sklearn.model_selection import train_test_split + + +def produce_sample_dict(line: str): + names = os.listdir(line) + seg, t1ce, t1, t2, flair = [], [], [], [], [] + for name in names: + name = os.path.join(line, name) + if "_seg.nii" in name: + seg.append(name) + elif "_t1ce.nii" in name: + t1ce.append(name) + elif "_t1.nii" in name: + t1.append(name) + elif "_t2.nii" in name: + t2.append(name) + elif "_flair.nii" in name: + flair.append(name) + + return {"label": seg[0], "image": t1ce + t1 + t2 + flair} + + +def produce_datalist(dataset_dir: str, train_size: int = 200): + """ + This function is used to split the dataset. + It will produce "train_size" number of samples for training, and the other samples + are divided equally into val and test sets. + """ + + samples = sorted(glob.glob(os.path.join(dataset_dir, "*", "*"), recursive=True)) + datalist = [] + for line in samples: + datalist.append(produce_sample_dict(line)) + train_list, other_list = train_test_split(datalist, train_size=train_size) + val_list, test_list = train_test_split(other_list, train_size=0.5) + + return {"training": train_list, "validation": val_list, "testing": test_list} + + +def main(args): + """ + split the dataset and output the data list into a json file. + """ + data_file_base_dir = os.path.join(os.path.abspath(args.path), "training") + # produce deterministic data splits + monai.utils.set_determinism(seed=123) + datalist = produce_datalist(dataset_dir=data_file_base_dir, train_size=args.train_size) + with open(args.output, "w") as f: + json.dump(datalist, f) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--path", + type=str, + default="/workspace/data/medical/brats2018challenge", + help="root path of brats 2018 dataset.", + ) + parser.add_argument( + "--output", type=str, default="configs/datalist.json", help="relative path of output datalist json file." + ) + parser.add_argument("--train_size", type=int, default=200, help="number of training samples.") + args = parser.parse_args() + + main(args)