
Commit aea7312

change to Dataset as in brain segmentation bundle, add support for amp, add validate in train_autoencoder.json

Signed-off-by: Can-Zhao <[email protected]>
Can-Zhao committed Dec 13, 2023
1 parent f9334c1 commit aea7312
Showing 3 changed files with 186 additions and 29 deletions.
135 changes: 108 additions & 27 deletions models/brats_mri_generative_diffusion/configs/train_autoencoder.json
@@ -2,21 +2,29 @@
     "imports": [
         "$import functools",
         "$import glob",
-        "$import scripts"
+        "$import scripts",
+        "$import generative"
     ],
     "bundle_root": ".",
     "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
     "ckpt_dir": "$@bundle_root + '/models'",
     "tf_dir": "$@bundle_root + '/eval'",
-    "dataset_dir": "/workspace/data/medical",
+    "data_list_file_path": "$@bundle_root + '/configs/datalist.json'",
+    "dataset_dir": "/datasets/brats18",
+    "train_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='training', base_dir=@dataset_dir)",
+    "val_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, data_list_key='validation', base_dir=@dataset_dir)",
     "pretrained": false,
     "perceptual_loss_model_weights_path": null,
     "train_batch_size": 2,
-    "lr": 1e-05,
+    "val_batch_size": 2,
+    "epochs": 4000,
+    "val_interval": 10,
+    "lr": 1e-04,
+    "amp": true,
     "train_patch_size": [
         112,
-        128,
-        80
+        128,
+        128
     ],
     "channel": 0,
     "spacing": [
@@ -26,7 +34,7 @@
     ],
     "spatial_dims": 3,
     "image_channels": 1,
-    "latent_channels": 8,
+    "latent_channels": 4,
     "discriminator_def": {
         "_target_": "generative.networks.nets.PatchDiscriminator",
         "spatial_dims": "@spatial_dims",
@@ -56,7 +64,9 @@
             false
         ],
         "with_encoder_nonlocal_attn": false,
-        "with_decoder_nonlocal_attn": false
+        "with_decoder_nonlocal_attn": false,
+        "use_checkpointing": true,
+        "use_convtranspose": false
     },
     "perceptual_loss_def": {
         "_target_": "generative.losses.PerceptualLoss",
Expand Down Expand Up @@ -140,13 +150,8 @@
"transforms": "$@preprocessing_transforms + @train#crop_transforms + @final_transforms"
},
"dataset": {
"_target_": "monai.apps.DecathlonDataset",
"root_dir": "@dataset_dir",
"task": "Task01_BrainTumour",
"section": "training",
"cache_rate": 1.0,
"num_workers": 8,
"download": false,
"_target_": "Dataset",
"data": "@train_datalist",
"transform": "@train#preprocessing"
},
"dataloader": {
@@ -158,32 +163,33 @@
         },
         "handlers": [
             {
-                "_target_": "CheckpointSaver",
-                "save_dir": "@ckpt_dir",
-                "save_dict": {
-                    "model": "@gnetwork"
-                },
-                "save_interval": 0,
-                "save_final": true,
+                "_target_": "ValidationHandler",
+                "validator": "@validate#evaluator",
                 "epoch_level": true,
-                "final_filename": "model_autoencoder.pt"
+                "interval": "@val_interval"
             },
             {
                 "_target_": "StatsHandler",
                 "tag_name": "train_loss",
-                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
+                "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]+monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
             },
             {
                 "_target_": "TensorBoardStatsHandler",
                 "log_dir": "@tf_dir",
-                "tag_name": "train_loss",
+                "tag_name": "train_generator_loss",
                 "output_transform": "$lambda x: monai.handlers.from_engine(['g_loss'], first=True)(x)[0]"
             },
+            {
+                "_target_": "TensorBoardStatsHandler",
+                "log_dir": "@tf_dir",
+                "tag_name": "train_discriminator_loss",
+                "output_transform": "$lambda x: monai.handlers.from_engine(['d_loss'], first=True)(x)[0]"
+            }
         ],
         "trainer": {
             "_target_": "scripts.ldm_trainer.VaeGanTrainer",
             "device": "@device",
-            "max_epochs": 1500,
+            "max_epochs": "@epochs",
             "train_data_loader": "@train#dataloader",
             "g_network": "@gnetwork",
             "g_optimizer": "@goptimizer",
@@ -195,7 +201,82 @@
             "g_update_latents": true,
             "latent_shape": "@latent_channels",
             "key_train_metric": "$None",
-            "train_handlers": "@train#handlers"
+            "train_handlers": "@train#handlers",
+            "amp": "@amp"
         }
     },
+    "validate": {
+        "crop_transforms": [
+            {
+                "_target_": "DivisiblePadd",
+                "keys": "image",
+                "k": 16
+            }
+        ],
+        "preprocessing": {
+            "_target_": "Compose",
+            "transforms": "$@preprocessing_transforms + @validate#crop_transforms + @final_transforms"
+        },
+        "dataset": {
+            "_target_": "Dataset",
+            "data": "@val_datalist",
+            "transform": "@validate#preprocessing"
+        },
+        "dataloader": {
+            "_target_": "DataLoader",
+            "dataset": "@validate#dataset",
+            "batch_size": "@val_batch_size",
+            "shuffle": false,
+            "num_workers": 4
+        },
+        "postprocessing": {
+            "_target_": "Compose",
+            "transforms": [
+                {
+                    "_target_": "Lambdad",
+                    "keys": "pred",
+                    "func": "$lambda x: x[0]"
+                }
+            ]
+        },
+        "handlers": [
+            {
+                "_target_": "StatsHandler",
+                "iteration_log": false
+            },
+            {
+                "_target_": "TensorBoardStatsHandler",
+                "log_dir": "@tf_dir",
+                "iteration_log": false
+            },
+            {
+                "_target_": "CheckpointSaver",
+                "save_dir": "@ckpt_dir",
+                "save_dict": {
+                    "model": "@gnetwork"
+                },
+                "save_interval": 0,
+                "save_final": true,
+                "epoch_level": true,
+                "final_filename": "model_autoencoder.pt"
+            }
+        ],
+        "key_metric": {
+            "val_mean_l2": {
+                "_target_": "MeanSquaredError",
+                "output_transform": "$monai.handlers.from_engine(['pred', 'image'])"
+            }
+        },
+        "evaluator": {
+            "_target_": "SupervisedEvaluator",
+            "device": "@device",
+            "val_data_loader": "@validate#dataloader",
+            "network": "@gnetwork",
+            "postprocessing": "@validate#postprocessing",
+            "key_val_metric": "$@validate#key_metric",
+            "metric_cmp_fn": "$lambda current_metric,prev_best: current_metric < prev_best",
+            "val_handlers": "@validate#handlers",
+            "amp": "@amp"
+        }
+    },
     "initialize": [
@@ -204,4 +285,4 @@
     "run": [
         "$@train#trainer.run()"
     ]
-}
\ No newline at end of file
+}
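
The dataset change above swaps `monai.apps.DecathlonDataset` for a plain `Dataset` fed by an explicit datalist, as in the brain segmentation bundle. A minimal sketch of what the new config references resolve to at runtime; the paths and batch size come from the config above, while the transform and num_workers values are illustrative placeholders:

# Illustrative sketch only: MONAI's bundle runner builds these objects from the
# JSON references; "transform" and "num_workers" below are assumed values.
import monai
from monai.data import DataLoader, Dataset

dataset_dir = "/datasets/brats18"
data_list_file_path = "./configs/datalist.json"

# Resolves "@train_datalist": a list of {"image": [...], "label": ...} dicts.
train_datalist = monai.data.load_decathlon_datalist(
    data_list_file_path, data_list_key="training", base_dir=dataset_dir
)

# Unlike DecathlonDataset, a plain Dataset neither downloads nor caches; it
# applies the preprocessing transform lazily to each datalist entry.
train_ds = Dataset(data=train_datalist, transform=None)  # stand-in for "@train#preprocessing"
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)

Note also that the new validation evaluator flips `metric_cmp_fn` to `current_metric < prev_best` because the key metric is a mean L2 error, where lower is better, and that `CheckpointSaver` moved from the train handlers to the validation handlers accordingly.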
8 changes: 6 additions & 2 deletions models/brats_mri_generative_diffusion/scripts/ldm_trainer.py

@@ -11,14 +11,14 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence

 import torch
 from monai.config import IgniteInfo
 from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
 from monai.inferers import Inferer, SimpleInferer
 from monai.transforms import Transform
-from monai.utils import min_version, optional_import
+from monai.utils import GanKeys, min_version, optional_import
 from monai.utils.enums import CommonKeys, GanKeys
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -81,6 +81,7 @@ class VaeGanTrainer(Trainer):
             `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
         train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
             CheckpointHandler, StatsHandler, etc.
+        amp: whether to enable auto-mixed-precision training, default is False.
         decollate: whether to decollate the batch-first data to a list of data after model computation,
             recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
             default to `True`.
@@ -118,6 +119,7 @@ def __init__(
         additional_metrics: dict[str, Metric] | None = None,
         metric_cmp_fn: Callable = default_metric_cmp_fn,
         train_handlers: Sequence | None = None,
+        amp: bool = False,
         decollate: bool = True,
         optim_set_to_none: bool = False,
         to_kwargs: dict | None = None,
Expand All @@ -139,6 +141,7 @@ def __init__(
additional_metrics=additional_metrics,
metric_cmp_fn=metric_cmp_fn,
handlers=train_handlers,
amp=amp,
postprocessing=postprocessing,
decollate=decollate,
to_kwargs=to_kwargs,
@@ -173,6 +176,7 @@ def _iteration(
             raise ValueError("must provide batch data for current iteration.")

         d_input = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)[0]
+        batch_size = engine.data_loader.batch_size  # type: ignore
         g_input = d_input
         g_output, z_mu, z_sigma = engine.g_inferer(g_input, engine.g_network)
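
The new `amp` argument is simply forwarded to the base `Trainer` via `super().__init__`, which exposes it as `engine.amp`. A hedged sketch of how such a flag is typically honored in a generator update follows; `VaeGanTrainer._iteration` differs in its actual losses and optimizer handling, and `make_generator_step` is a hypothetical helper for illustration only:

# Illustrative sketch only, not the bundle's actual implementation.
import torch
import torch.nn.functional as F

def make_generator_step(g_network, g_optimizer, amp: bool):
    # One GradScaler for the whole run; it degrades to a no-op when enabled=False.
    scaler = torch.cuda.amp.GradScaler(enabled=amp)

    def step(images: torch.Tensor) -> float:
        g_optimizer.zero_grad(set_to_none=True)
        # Run the forward pass in reduced precision when amp is enabled.
        with torch.cuda.amp.autocast(enabled=amp):
            recon, z_mu, z_sigma = g_network(images)
            g_loss = F.l1_loss(recon, images)  # placeholder for the VAE-GAN losses
        # Scale the loss so small float16 gradients do not underflow.
        scaler.scale(g_loss).backward()
        scaler.step(g_optimizer)
        scaler.update()
        return g_loss.item()

    return step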
72 changes: 72 additions & 0 deletions models/brats_mri_generative_diffusion/scripts/prepare_datalist.py

@@ -0,0 +1,72 @@
import argparse
import glob
import json
import os

import monai
from sklearn.model_selection import train_test_split


def produce_sample_dict(line: str):
    names = os.listdir(line)
    seg, t1ce, t1, t2, flair = [], [], [], [], []
    for name in names:
        name = os.path.join(line, name)
        if "_seg.nii" in name:
            seg.append(name)
        elif "_t1ce.nii" in name:
            t1ce.append(name)
        elif "_t1.nii" in name:
            t1.append(name)
        elif "_t2.nii" in name:
            t2.append(name)
        elif "_flair.nii" in name:
            flair.append(name)

    return {"label": seg[0], "image": t1ce + t1 + t2 + flair}


def produce_datalist(dataset_dir: str, train_size: int = 200):
    """
    This function is used to split the dataset.
    It will produce "train_size" number of samples for training, and the other samples
    are divided equally into val and test sets.
    """

    samples = sorted(glob.glob(os.path.join(dataset_dir, "*", "*"), recursive=True))
    datalist = []
    for line in samples:
        datalist.append(produce_sample_dict(line))
    train_list, other_list = train_test_split(datalist, train_size=train_size)
    val_list, test_list = train_test_split(other_list, train_size=0.5)

    return {"training": train_list, "validation": val_list, "testing": test_list}


def main(args):
    """
    split the dataset and output the data list into a json file.
    """
    data_file_base_dir = os.path.join(os.path.abspath(args.path), "training")
    # produce deterministic data splits
    monai.utils.set_determinism(seed=123)
    datalist = produce_datalist(dataset_dir=data_file_base_dir, train_size=args.train_size)
    with open(args.output, "w") as f:
        json.dump(datalist, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--path",
        type=str,
        default="/workspace/data/medical/brats2018challenge",
        help="root path of brats 2018 dataset.",
    )
    parser.add_argument(
        "--output", type=str, default="configs/datalist.json", help="relative path of output datalist json file."
    )
    parser.add_argument("--train_size", type=int, default=200, help="number of training samples.")
    args = parser.parse_args()

    main(args)
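
Assuming the BraTS 2018 data are unpacked so that subject folders sit under `<path>/training`, the datalist consumed by `train_autoencoder.json` can be generated with (the arguments shown are the defaults above):

python scripts/prepare_datalist.py --path /workspace/data/medical/brats2018challenge --output configs/datalist.json --train_size 200

The resulting JSON holds "training", "validation", and "testing" lists whose entries have the form {"label": <path to _seg.nii>, "image": [<t1ce>, <t1>, <t2>, <flair>]}, matching what `load_decathlon_datalist` expects in the config.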
