v0.5 to v0.6 migration guide
In MONAI v0.6, we enhanced the design of metrics and postprocessing transforms to provide more flexible and advanced features. These enhancements also introduce breaking changes.
Please check out What's new in 0.6 for details of the new features.
To help migrate existing code smoothly from MONAI v0.5 to v0.6, this document walks through the detailed steps with code examples.
- After model forward and loss backward, to apply postprocessing transforms independently to every single item in a batch, execute decollate_batch to convert the batch Tensor into a list of Tensors.
- All the MONAI postprocessing transforms are now updated to handle a channel-first Tensor instead of a batch-first Tensor, so the preprocessing and postprocessing transforms handle the same data shape. Just execute the postprocessing transforms on every item of the list.
- As all the postprocessing transforms expect Tensor input, we suggest adding the EnsureType or EnsureTyped transform to ensure that the data after decollate_batch is a Tensor.
- Use the from_engine() utility to extract the expected data from the decollated list. Set first=True for scalar values: scalar values have no batch dimension, so they are copied to every item of the decollated list.
- Code examples:
(1) If you are using array-based postprocessing transforms, a v0.5 classification program can be:
from monai.transforms import Activations, AsDiscrete
pred_trans = Activations(softmax=True)
label_trans = AsDiscrete(to_onehot=True, n_classes=5)
pred = model(image)
pred = pred_trans(pred)
label = label_trans(label)
metric(y_pred=pred, y=label)
And the corresponding code of v0.6 can be:
from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete, Compose, EnsureType
pred_trans = Compose([EnsureType(), Activations(softmax=True)])
label_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=5)])
pred = model(image)
pred = [pred_trans(i) for i in decollate_batch(pred)]
label = [label_trans(i) for i in decollate_batch(label)]
metric(y_pred=pred, y=label)
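To make the shape change concrete, the minimal pure-Python sketch below shows what decollating means for the array case. It is an illustration only, not MONAI's implementation: toy_decollate and toy_softmax are hypothetical helpers, and nested lists stand in for a batch-first Tensor of shape [B, C].

```python
import math
from typing import List


def toy_decollate(batch: List[List[float]]) -> List[List[float]]:
    """Split a batch-first [B, C] nested list into B channel-first [C] items.

    Hypothetical stand-in for monai.data.decollate_batch applied to a plain Tensor.
    """
    return [list(item) for item in batch]


def toy_softmax(item: List[float]) -> List[float]:
    """Channel-wise softmax of ONE channel-first item (no batch dimension)."""
    m = max(item)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in item]
    total = sum(exps)
    return [e / total for e in exps]


# a batch of 2 predictions with 3 channels each: shape [B=2, C=3]
pred_batch = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]

# v0.6 style: decollate first, then postprocess every channel-first item
pred = [toy_softmax(i) for i in toy_decollate(pred_batch)]
print(len(pred), [round(sum(p), 6) for p in pred])
```

Because each decollated item no longer has a batch dimension, the same per-item transform works regardless of the batch size, which is the point of making the postprocessing transforms channel-first.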
(2) If you are using dictionary-based postprocessing transforms, a v0.5 classification program can be:
from monai.transforms import Activationsd, AsDiscreted, Compose
postprocessing = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])
data["pred"] = model(data["image"])
data = postprocessing(data)
metric(y_pred=data["pred"], y=data["label"])
And the corresponding code of v0.6 can be:
from monai.data import decollate_batch
from monai.handlers import from_engine
from monai.transforms import Activationsd, AsDiscreted, Compose, EnsureTyped
postprocessing = Compose([
    EnsureTyped(keys=["pred", "label"]),
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])
data["pred"] = model(data["image"])
# decollate data into a list of dictionaries and postprocess every item
data = [postprocessing(i) for i in decollate_batch(data)]
# extract the `pred` and `label` to compute metric
pred, label = from_engine(["pred", "label"])(data)
metric(y_pred=pred, y=label)
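The dictionary workflow can be sketched the same way. The helpers below are hypothetical, simplified stand-ins written for illustration only: toy_decollate_dict splits every per-item list in a batch dictionary and copies values without a batch dimension (such as a scalar loss) into each item, and toy_from_engine mimics the extraction behaviour described above for from_engine(), including first=True. The real MONAI functions also handle Tensors, nested structures, and metadata.

```python
from typing import Any, Dict, List, Sequence, Union


def toy_decollate_dict(batch: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Turn {key: [item0, item1, ...]} into [{key: item0}, {key: item1}, ...].

    Values that are not lists have no batch dimension; they are copied into
    every item, mirroring how scalars are handled after decollation.
    """
    batch_size = max(len(v) for v in batch.values() if isinstance(v, list))
    return [
        {k: (v[i] if isinstance(v, list) else v) for k, v in batch.items()}
        for i in range(batch_size)
    ]


def toy_from_engine(keys: Union[str, Sequence[str]], first: bool = False):
    """Return a callable that extracts `keys` from a list of dictionaries."""
    key_list = [keys] if isinstance(keys, str) else list(keys)

    def _wrapper(data: List[Dict[str, Any]]):
        # first=True: take the value from the first item only (e.g. a copied scalar)
        ret = [data[0][k] if first else [d[k] for d in data] for k in key_list]
        return tuple(ret) if len(ret) > 1 else ret[0]

    return _wrapper


batch = {"pred": [[0.1, 0.9], [0.8, 0.2]], "label": [1, 0], "loss": 0.25}
data = toy_decollate_dict(batch)

pred, label = toy_from_engine(["pred", "label"])(data)
loss = toy_from_engine("loss", first=True)(data)
print(pred, label, loss)
```

With first=False, toy_from_engine("loss")(data) would return [0.25, 0.25], one copy per item, which is why first=True is recommended for scalars such as the loss passed to StatsHandler.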
For a more detailed tutorial on decollate_batch, please check: decollate_batch tutorial.
- Metrics now support both a batch-first Tensor and a list of channel-first Tensors as input.
- Metrics support data parallelism in multi-GPU or multi-node cases.
- Example code of validation during the training of a segmentation task:
A typical code example of v0.5:
dice_metric = DiceMetric(include_background=True, reduction="mean")
metric_sum = 0.0
metric_count = 0
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = val_post_tran(sliding_window_inference(images, (96, 96, 96), 4, model))
    value, not_nans = dice_metric(y_pred=preds, y=labels)
    metric_count += not_nans.item()
    metric_sum += value.item() * not_nans.item()
metric = metric_sum / metric_count
And the corresponding code of v0.6 can be:
from monai.data import decollate_batch
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = sliding_window_inference(images, (96, 96, 96), 4, model)
    # decollate prediction into a list and execute postprocessing for every item
    preds = [val_post_tran(i) for i in decollate_batch(preds)]
    # compute metric for current iteration
    dice_metric(y_pred=preds, y=labels)
# aggregate and compute the final result of metric
metric = dice_metric.aggregate().item()
# reset the status for the next validation round
dice_metric.reset()
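The iteration/aggregate/reset pattern can be illustrated with a small, hypothetical stand-in for a cumulative metric. ToyCumulativeMetric and toy_dice are not MONAI code: the class buffers one score per item on every call, aggregate() returns the mean over all buffered scores while skipping NaNs (loosely mimicking reduction="mean"), and reset() clears the buffer. The final comparison shows that this gives the same result as the v0.5 approach of weighting each batch mean by its not_nans count.

```python
import math
from typing import List


def toy_dice(pred: List[int], label: List[int]) -> float:
    """Dice score of two binary masks; NaN when both masks are empty."""
    inter = sum(p * l for p, l in zip(pred, label))
    denom = sum(pred) + sum(label)
    return 2.0 * inter / denom if denom > 0 else float("nan")


class ToyCumulativeMetric:
    """Hypothetical cumulative metric with __call__/aggregate/reset."""

    def __init__(self) -> None:
        self._buffer: List[float] = []

    def __call__(self, y_pred: List[List[int]], y: List[List[int]]) -> List[float]:
        # compute one score per (channel-first) item and keep it for later
        scores = [toy_dice(p, l) for p, l in zip(y_pred, y)]
        self._buffer.extend(scores)
        return scores

    def aggregate(self) -> float:
        # mean over every buffered score, ignoring NaNs
        valid = [s for s in self._buffer if not math.isnan(s)]
        return sum(valid) / len(valid)

    def reset(self) -> None:
        self._buffer = []


# two validation iterations, each with a batch of binary masks
batches = [
    ([[1, 1, 0, 0], [0, 0, 0, 0]], [[1, 0, 0, 0], [0, 0, 0, 0]]),  # 2nd item is NaN
    ([[1, 1, 1, 0]], [[1, 1, 1, 0]]),
]

# v0.6 style: call per iteration, aggregate once, then reset
metric = ToyCumulativeMetric()
for y_pred, y in batches:
    metric(y_pred=y_pred, y=y)
v06_result = metric.aggregate()
metric.reset()

# v0.5 style: per-batch mean weighted by the number of non-NaN items
metric_sum, metric_count = 0.0, 0
for y_pred, y in batches:
    scores = [s for s in (toy_dice(p, l) for p, l in zip(y_pred, y)) if not math.isnan(s)]
    if scores:
        metric_sum += (sum(scores) / len(scores)) * len(scores)
        metric_count += len(scores)
v05_result = metric_sum / metric_count

print(v06_result, v05_result)
```

Calling aggregate() once per validation round, rather than averaging inside the loop, is what removes the manual metric_sum/metric_count bookkeeping from the v0.5 code.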
For more details about how to compute metrics in multi-processing, please check: compute metrics example.
In v0.5, the batch_transform and output_transform args of handlers can be:
StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
SegmentationSaver(
output_dir=root_dir,
batch_transform=lambda batch: batch["image_meta_dict"],
output_transform=lambda output: output["pred"],
)
And the corresponding code of v0.6 can be:
from monai.handlers import from_engine
StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
SegmentationSaver(
output_dir=root_dir,
batch_transform=from_engine("image_meta_dict"),
output_transform=from_engine("pred"),
),
Some args named post_transform were renamed to postprocessing in v0.6, for example, the arg of SupervisedTrainer:
from monai.engines import SupervisedTrainer
from monai.handlers import Accuracy, from_engine
from monai.inferers import SimpleInferer
trainer = SupervisedTrainer(
    device=device,
    max_epochs=5,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=train_postprocessing,
    key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
    train_handlers=train_handlers,
)
In v0.6, DynUNet has been updated; the previous version is still available:
from monai.networks.nets.dynunet_v1 import DynUNetV1 as DynUNet
dynunet_v1 will be removed in a future release.
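The following diff migrates the 3d_segmentation/torch/unet_training_dict.py tutorial from v0.5 to v0.6:
diff --git a/3d_segmentation/torch/unet_training_dict.py b/3d_segmentation/torch/unet_training_dict.py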
index 0c85cbf..febecb4 100644
--- a/3d_segmentation/torch/unet_training_dict.py
+++ b/3d_segmentation/torch/unet_training_dict.py
@@ -22,7 +22,7 @@ from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import monai
-from monai.data import create_test_image_3d, list_data_collate
+from monai.data import create_test_image_3d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
@@ -34,7 +34,8 @@ from monai.transforms import (
RandCropByPosNegLabeld,
RandRotate90d,
ScaleIntensityd,
- ToTensord,
+ EnsureTyped,
+ EnsureType,
)
from monai.visualize import plot_2d_or_3d_image
@@ -69,7 +70,7 @@ def main(tempdir):
keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
),
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
- ToTensord(keys=["img", "seg"]),
+ EnsureTyped(keys=["img", "seg"]),
]
)
val_transforms = Compose(
@@ -77,7 +78,7 @@ def main(tempdir):
LoadImaged(keys=["img", "seg"]),
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
ScaleIntensityd(keys="img"),
- ToTensord(keys=["img", "seg"]),
+ EnsureTyped(keys=["img", "seg"]),
]
)
@@ -102,8 +103,8 @@ def main(tempdir):
# create a validation data loader
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
- dice_metric = DiceMetric(include_background=True, reduction="mean")
- post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+ dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
+ post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.UNet(
@@ -149,8 +150,6 @@ def main(tempdir):
if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
- metric_sum = 0.0
- metric_count = 0
val_images = None
val_labels = None
val_outputs = None
@@ -159,11 +158,14 @@ def main(tempdir):
roi_size = (96, 96, 96)
sw_batch_size = 4
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
- val_outputs = post_trans(val_outputs)
- value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
- metric_count += len(value)
- metric_sum += value.item() * len(value)
- metric = metric_sum / metric_count
+ val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
+ # compute metric for current iteration
+ dice_metric(y_pred=val_outputs, y=val_labels)
+ # aggregate the final mean dice result
+ metric = dice_metric.aggregate().item()
+ # reset the status for next validation round
+ dice_metric.reset()
+
metric_values.append(metric)
if metric > best_metric:
best_metric = metric