Skip to content

Commit

Permalink
improved MONAI-based training
Browse files Browse the repository at this point in the history
- using EpochMetric for calculating validation loss after every epoch

- transforms
  - scaling intensity
  - randomly scaling intensity while training
  - thresholding net outputs and keeping largest island only
  • Loading branch information
che85 committed Jun 22, 2021
1 parent d1bc0f7 commit a1d5b47
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 56 deletions.
195 changes: 139 additions & 56 deletions MONAI/MONAI_based_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
"import monai\n",
"\n",
"from monai.engines import SupervisedEvaluator, SupervisedTrainer\n",
"from monai.handlers.iteration_metric import IterationMetric\n",
"from monai.metrics.utils import do_metric_reduction\n",
"\n",
"# from ignite.metrics.epoch_metric import EpochMetric\n",
"from monai.handlers import (\n",
" ValidationHandler,\n",
" CheckpointSaver,\n",
Expand All @@ -26,6 +30,7 @@
" GarbageCollector,\n",
" EarlyStopHandler\n",
")\n",
"\n",
"from monai.inferers import SimpleInferer\n",
"from monai.transforms import (\n",
" AddChanneld,\n",
Expand All @@ -34,21 +39,27 @@
" RandAffined,\n",
" ToTensord,\n",
" RandAdjustContrastd,\n",
" ScaleIntensityRangePercentilesd\n",
" ScaleIntensityd,\n",
" RandScaleIntensityd,\n",
" ScaleIntensityRangePercentilesd,\n",
" AsDiscreted, \n",
" KeepLargestConnectedComponentd\n",
")\n",
"\n",
"from monai.engines.utils import CommonKeys as Keys\n",
"from monai.data import DataLoader\n",
"from monai.engines.utils import IterationEvents\n",
"from ignite.contrib.handlers import ProgressBar\n",
"\n",
"from ignite.contrib.handlers import ProgressBar\n",
"from ignite.engine import Events\n",
"\n",
"# Local imports\n",
"from utils import generate_directory_name, get_list_of_file_names\n",
"from loss import DicePlusConstantCatCrossEntropyLoss\n",
"from optimizer import RAdam\n",
"from transforms import DistanceTransformd, OneHotTransformd\n",
"from model import VNet\n",
"from metric import LossMetric\n",
"\n",
"# monai.config.print_config()\n",
"\n",
Expand Down Expand Up @@ -147,31 +158,27 @@
" [\n",
" LoadImaged(keys=all_keys, reader=\"NibabelReader\"),\n",
" AddChanneld(keys=all_keys),\n",
" OneHotTransformd(keys=[\"labels\"]),\n",
" RandAffined(\n",
" keys=all_keys, \n",
" prob=0.5, \n",
" rotate_range=(-rot_rad,rot_rad), # radians! \n",
" translate_range=(-20,20),\n",
" scale_range=(-0.1,0.1), \n",
" translate_range=(-30,30),\n",
" scale_range=(-0.2,0.2), \n",
" mode=\"nearest\", \n",
" padding_mode=\"zeros\", \n",
" as_tensor_output=False\n",
" ),\n",
" RandAdjustContrastd(\n",
" keys=[\"mid-systolic-images\"], \n",
" prob=0.5, \n",
" gamma=(0.5, 1.5)\n",
" ),\n",
" OneHotTransformd(keys=[\"labels\"]),\n",
" DistanceTransformd(keys=[\"annuli\"]),\n",
" ScaleIntensityRangePercentilesd(\n",
" RandScaleIntensityd(\n",
" keys=[\"mid-systolic-images\"],\n",
" factors=0.3,\n",
" prob=0.5\n",
" ),\n",
" ScaleIntensityd(\n",
" keys=[\"mid-systolic-images\"], \n",
" lower=2, \n",
" upper=99, \n",
" b_min=0.0, \n",
" b_max=1.0, \n",
" clip=True, \n",
" relative=True\n",
" minv=0.0,\n",
" maxv=1.0\n",
" ),\n",
" ToTensord(keys=all_keys)\n",
" ]\n",
Expand All @@ -182,17 +189,30 @@
" AddChanneld(keys=all_keys),\n",
" OneHotTransformd(keys=[\"labels\"]),\n",
" DistanceTransformd(keys=[\"annuli\"]),\n",
" ScaleIntensityRangePercentilesd(\n",
"# ScaleIntensityRangePercentilesd(\n",
"# keys=[\"mid-systolic-images\"], \n",
"# lower=2, \n",
"# upper=99, \n",
"# b_min=0.0, \n",
"# b_max=1.0, \n",
"# clip=True, \n",
"# relative=True\n",
"# ),\n",
" ScaleIntensityd(\n",
" keys=[\"mid-systolic-images\"], \n",
" lower=2, \n",
" upper=99, \n",
" b_min=0.0, \n",
" b_max=1.0, \n",
" clip=True, \n",
" relative=True\n",
" minv=0.0,\n",
" maxv=1.0\n",
" ),\n",
" ToTensord(keys=all_keys)\n",
" ]\n",
")\n",
"\n",
"\n",
"post_transforms = Compose(\n",
" [\n",
" AsDiscreted(keys=Keys.PRED, threshold_values=True),\n",
" KeepLargestConnectedComponentd(keys=Keys.PRED, applied_labels=[1,2,3]),\n",
" ]\n",
")"
]
},
Expand Down Expand Up @@ -296,7 +316,7 @@
"# NB: as of now using local copy of RAdam optimizer but seems to be integrated into pytorch soon (https://github.com/pytorch/pytorch/pull/58968)\n",
"optimizer = RAdam(\n",
" params=trainable_params, \n",
" lr=0.05, \n",
" lr=0.02, \n",
" weight_decay=1e-05\n",
")\n",
"\n",
Expand All @@ -305,10 +325,9 @@
"lr_scheduler = ReduceLROnPlateau(\n",
" optimizer=optimizer,\n",
" factor=0.5,\n",
" patience=5,\n",
" patience=3,\n",
" verbose=True,\n",
" mode='max',\n",
" min_lr=0.001\n",
" mode='min'\n",
")"
]
},
Expand All @@ -331,6 +350,51 @@
"# optimizer.load_state_dict(checkpoint['optimizer'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class SupervisedValidator(SupervisedEvaluator):\n",
"\n",
" def __init__(self, **kwargs):\n",
" self.loss_function = kwargs.pop(\"loss_function\")\n",
" super(SupervisedValidator, self).__init__(**kwargs)\n",
" \n",
" def _iteration(self, engine, batchdata):\n",
" if batchdata is None:\n",
" raise ValueError(\"Must provide batch data for current iteration.\")\n",
" batch = self.prepare_batch(batchdata, engine.state.device, engine.non_blocking)\n",
" if len(batch) == 2:\n",
" inputs, targets = batch\n",
" args: Tuple = ()\n",
" kwargs: Dict = {}\n",
" else:\n",
" inputs, targets, args, kwargs = batch\n",
"\n",
" # put iteration outputs into engine.state\n",
" engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}\n",
" # execute forward computation\n",
" with self.mode(self.network):\n",
" if self.amp:\n",
" with torch.cuda.amp.autocast():\n",
" predictions = self.inferer(inputs, self.network, *args, **kwargs)\n",
" loss = self.loss_function(predictions, targets).mean()\n",
" else:\n",
" predictions = self.inferer(inputs, self.network, *args, **kwargs)\n",
" loss = self.loss_function(predictions, targets).mean()\n",
" \n",
" engine.state.output[Keys.PRED] = predictions\n",
" engine.state.output[\"val_loss\"] = loss.item()\n",
" engine.fire_event(IterationEvents.FORWARD_COMPLETED)\n",
" engine.fire_event(IterationEvents.MODEL_COMPLETED)\n",
" \n",
" torch.cuda.empty_cache()\n",
" \n",
" return engine.state.output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -348,16 +412,10 @@
"source": [
"val_handlers = [\n",
" ProgressBar(),\n",
" TensorBoardImageHandler(\n",
" log_dir=val_log_dir,\n",
" batch_transform=lambda x: (x[Keys.IMAGE], x[Keys.LABEL]),\n",
" output_transform=lambda x: x[Keys.PRED],\n",
" max_channels=10\n",
" ),\n",
" GarbageCollector(\"epoch\")\n",
"]\n",
"\n",
" \n",
"\n",
"validator = SupervisedEvaluator(\n",
" device=device,\n",
" val_data_loader=val_data_loader,\n",
Expand All @@ -368,11 +426,13 @@
" \"val_mean_dice\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL]))\n",
" },\n",
" additional_metrics={\n",
" \"val_mean_dice_anterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,1,:], x[Keys.LABEL][:,1,:])),\n",
" \"val_mean_dice_posterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,2,:], x[Keys.LABEL][:,2,:])),\n",
" \"val_mean_dice_septal\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,3,:], x[Keys.LABEL][:,3,:]))\n",
" \"val_loss\": LossMetric(metric_fn=loss_function, output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL])),\n",
"# \"val_mean_dice_anterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,1,:], x[Keys.LABEL][:,1,:])),\n",
"# \"val_mean_dice_posterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,2,:], x[Keys.LABEL][:,2,:])),\n",
"# \"val_mean_dice_septal\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,3,:], x[Keys.LABEL][:,3,:]))\n",
" },\n",
" val_handlers=val_handlers,\n",
" post_transform=post_transforms,\n",
" amp=use_amp\n",
")\n",
"\n",
Expand All @@ -383,20 +443,10 @@
" interval=1, \n",
" epoch_level=True\n",
" ),\n",
" LrScheduleHandler(\n",
" lr_scheduler=lr_scheduler, \n",
" print_lr=True, \n",
" step_transform=lambda x: x.state.metrics[\"train_mean_dice\"]\n",
" ),\n",
" StatsHandler(\n",
" tag_name=\"train_loss\", \n",
" output_transform=lambda x: x[Keys.LOSS]\n",
" ),\n",
" TensorBoardStatsHandler(\n",
" log_dir=train_log_dir, \n",
" tag_name=\"train_loss\", \n",
" output_transform=lambda x: x[Keys.LOSS]\n",
" ),\n",
" CheckpointSaver(\n",
" save_dir=checkpoint_dir, \n",
" save_dict={\"net\": net, \"opt\": optimizer}, \n",
Expand All @@ -417,35 +467,63 @@
" loss_function=loss_function,\n",
" inferer=SimpleInferer(),\n",
" key_train_metric={\n",
" \"train_mean_dice\": MeanDice(include_background=False,\n",
" output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL])),\n",
" \"train_mean_dice\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED], x[Keys.LABEL])),\n",
" },\n",
" additional_metrics={\n",
" \"train_mean_dice_anterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,1,:], x[Keys.LABEL][:,1,:])),\n",
" \"train_mean_dice_posterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,2,:], x[Keys.LABEL][:,2,:])),\n",
" \"train_mean_dice_septal\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,3,:], x[Keys.LABEL][:,3,:]))\n",
"# \"train_mean_dice_anterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,1,:], x[Keys.LABEL][:,1,:])),\n",
"# \"train_mean_dice_posterior\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,2,:], x[Keys.LABEL][:,2,:])),\n",
"# \"train_mean_dice_septal\": MeanDice(include_background=False, output_transform=lambda x: (x[Keys.PRED][:,3,:], x[Keys.LABEL][:,3,:]))\n",
" },\n",
" train_handlers=train_handlers,\n",
" post_transform=post_transforms,\n",
" amp=use_amp\n",
")\n",
"\n",
"\n",
"# more handlers\n",
"StatsHandler(\n",
" output_transform=lambda x: None,\n",
" global_epoch_transform=lambda x: trainer.state.epoch,\n",
" global_epoch_transform=lambda x: trainer.state.epoch\n",
").attach(validator)\n",
"\n",
"\n",
"LrScheduleHandler(\n",
" lr_scheduler=lr_scheduler, \n",
" print_lr=True, \n",
" step_transform=lambda x: x.state.metrics[\"val_loss\"]\n",
").attach(validator)\n",
"\n",
"\n",
"TensorBoardStatsHandler(\n",
" log_dir=val_log_dir,\n",
" output_transform=lambda x: None,\n",
" global_epoch_transform=lambda x: trainer.state.epoch,\n",
" global_epoch_transform=lambda x: trainer.state.epoch\n",
").attach(validator)\n",
"\n",
"\n",
"# add handler to draw the first image and the corresponding label and model output in the last batch\n",
"# here we draw the 3D output as GIF format along Depth axis, at every validation epoch\n",
"val_tensorboard_image_handler = TensorBoardImageHandler(\n",
" log_dir=val_log_dir,\n",
" batch_transform=lambda x: (x[Keys.IMAGE], x[Keys.LABEL]),\n",
" output_transform=lambda x: x[Keys.PRED],\n",
" max_channels=10,\n",
" global_iter_transform=lambda x: trainer.state.epoch,\n",
")\n",
"validator.add_event_handler(\n",
" event_name=Events.EPOCH_COMPLETED, handler=val_tensorboard_image_handler\n",
")\n",
"\n",
"TensorBoardStatsHandler(\n",
" log_dir=train_log_dir, \n",
" tag_name=\"train_loss\", \n",
" output_transform=lambda x: x[Keys.LOSS],\n",
" global_epoch_transform=lambda x: trainer.state.iteration\n",
").attach(trainer)\n",
"\n",
"\n",
"EarlyStopHandler(\n",
" patience=20,\n",
" patience=30,\n",
" score_function=lambda x: x.state.metrics[\"val_mean_dice\"],\n",
" trainer=trainer,\n",
" epoch_level=True,\n",
Expand Down Expand Up @@ -562,6 +640,11 @@
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
Expand Down
23 changes: 23 additions & 0 deletions MONAI/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch

from monai.handlers.iteration_metric import IterationMetric


class LossMetric(IterationMetric):
    """Ignite/MONAI handler that reports the mean per-iteration loss.

    Wraps an arbitrary loss callable as an ``IterationMetric`` so the loss
    can be tracked like any other validation metric (e.g. attached to a
    ``SupervisedEvaluator`` and consumed by LR schedulers / early stopping).
    """

    def __init__(
        self,
        metric_fn,
        output_transform=lambda x: x,
        device="cpu",
        save_details=True,
    ):
        """
        Args:
            metric_fn: callable ``(y_pred, y) -> loss`` evaluated each iteration.
            output_transform: maps ``engine.state.output`` to ``metric_fn`` inputs.
            device: device on which the per-iteration scores are kept.
            save_details: whether to store per-iteration details in engine state.
        """
        super().__init__(
            metric_fn=metric_fn,
            output_transform=output_transform,
            device=device,
            save_details=save_details,
        )

    def compute(self):
        """Return the mean of all loss values collected so far.

        Raises:
            ValueError: if no iterations have been recorded yet (the original
                ``torch.Tensor([])`` silently produced ``nan`` here).
        """
        if not self._scores:
            raise ValueError("LossMetric.compute() called before any iteration was recorded.")
        # torch.Tensor(...) is the legacy constructor and chokes when the
        # accumulated scores are (multi-element) tensors; normalise each entry
        # with as_tensor, flatten, and average over all collected values.
        scores = [torch.as_tensor(s, dtype=torch.float32).reshape(-1) for s in self._scores]
        return torch.cat(scores).mean()

0 comments on commit a1d5b47

Please sign in to comment.