
v0.5 to v0.6 migration guide

Nic Ma edited this page Jul 9, 2021 · 13 revisions

Migrating your v0.5 code to v0.6

In MONAI v0.6, we enhanced the design of metrics and postprocessing transforms to provide more flexible and advanced features, which also introduced some minor breaking changes.

See What's new in 0.6 for more details on the new features.

To help users smoothly migrate existing code from MONAI v0.5 to v0.6, this document walks through the required steps with example code.

Decollate batch-first Tensor to list of channel-first Tensors

  1. After the model forward pass and loss backward pass, execute decollate_batch to convert the batch Tensor into a list of Tensors, so that postprocessing transforms can be applied independently to every item in the batch.
  2. All MONAI postprocessing transforms now handle channel-first Tensors instead of batch-first Tensors, so preprocessing and postprocessing transforms operate on the same data shape. Simply apply the postprocessing transforms to every item of the list.
  3. All postprocessing transforms expect Tensor input. To ensure the data is a Tensor after decollate_batch, we suggest adding an EnsureType or EnsureTyped transform.
  4. Use the from_engine() utility to extract the expected data from the decollated list. Set first=True for scalar values (such as the loss) so that only the value from the first item is returned.
  5. Code examples:

(1) If you are using the array-based postprocessing transforms, a v0.5 classification program can be:

from monai.transforms import Activations, AsDiscrete

pred_trans = Activations(softmax=True)
label_trans = AsDiscrete(to_onehot=True, n_classes=5)

pred = model(image)
pred = pred_trans(pred)
label = label_trans(label)

metric(y_pred=pred, y=label)

The corresponding v0.6 code can be:

from monai.data import decollate_batch
from monai.transforms import Activations, AsDiscrete, Compose, EnsureType

pred_trans = Compose([EnsureType(), Activations(softmax=True)])
label_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=5)])

pred = model(image)
pred = [pred_trans(i) for i in decollate_batch(pred)]
label = [label_trans(i) for i in decollate_batch(label)]

metric(y_pred=pred, y=label)
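To see why each postprocessing transform now receives a channel-first item rather than the whole batch, here is a toy, pure-Python sketch of the same flow. It is an illustration only, not MONAI's implementation: nested lists stand in for Tensors, and `toy_decollate`, `toy_softmax`, and `toy_one_hot` are hypothetical helpers that mimic the roles of decollate_batch, Activations(softmax=True), and AsDiscrete(to_onehot=True).

```python
import math
from typing import List

# Toy stand-ins for tensors: nested Python lists (hypothetical, for illustration).
# A batch-first prediction has shape [B, C]; each decollated item has shape [C].


def toy_decollate(batch: List[list]) -> List[list]:
    """Split a batch-first nested list into a list of per-sample items."""
    return [list(sample) for sample in batch]


def toy_softmax(item: List[float]) -> List[float]:
    """Softmax over the channel dimension of a single channel-first item."""
    m = max(item)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in item]
    total = sum(exps)
    return [e / total for e in exps]


def toy_one_hot(item: List[int], num_classes: int) -> List[float]:
    """Convert a channel-first label item [class_index] into a one-hot vector."""
    return [1.0 if c == item[0] else 0.0 for c in range(num_classes)]


logits = [[2.0, 1.0, 0.1, 0.0, -1.0], [0.0, 0.0, 3.0, 0.0, 0.0]]  # [B=2, C=5]
labels = [[0], [2]]  # [B=2, 1]

# each transform only ever sees one channel-first sample
preds = [toy_softmax(i) for i in toy_decollate(logits)]
onehots = [toy_one_hot(i, num_classes=5) for i in toy_decollate(labels)]
```

Because every transform handles a single sample, the same channel-first logic serves both preprocessing and postprocessing, which is the motivation given in step 2.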

(2) If you are using the dictionary-based postprocessing transforms, a v0.5 classification program can be:

from monai.transforms import Activationsd, AsDiscreted, Compose

postprocessing = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])

data["pred"] = model(data["image"])
data = postprocessing(data)

metric(y_pred=data["pred"], y=data["label"])

The corresponding v0.6 code can be:

from monai.data import decollate_batch
from monai.handlers import from_engine
from monai.transforms import Activationsd, AsDiscreted, Compose, EnsureTyped

postprocessing = Compose([
    EnsureTyped(keys=["pred", "label"]),
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])

data["pred"] = model(data["image"])
# decollate data into a list of dictionaries, then postprocess every item
data = [postprocessing(i) for i in decollate_batch(data)]

# extract the `pred` and `label` to compute metric
pred, label = from_engine(["pred", "label"])(data)
metric(y_pred=pred, y=label)
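The dictionary case can be pictured with another toy, pure-Python sketch. `toy_decollate_dict` is a hypothetical helper, not MONAI's decollate_batch: nested lists stand in for Tensors, and we assume for illustration that list values are batched along their first dimension while any other value, such as a scalar loss, is copied into every item. That replication is why step 4 recommends first=True when extracting scalars.

```python
from typing import Any, Dict, List


def toy_decollate_dict(batch: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Toy version: split {"k": [v0, v1]} into [{"k": v0}, {"k": v1}].

    Assumption for illustration: list/tuple values are batched along their
    first dimension; any other value (e.g. a scalar loss) is copied into
    every item.
    """
    batched = {k: v for k, v in batch.items() if isinstance(v, (list, tuple))}
    sizes = {len(v) for v in batched.values()}
    if len(sizes) != 1:
        raise ValueError("batched entries must share exactly one batch size")
    size = sizes.pop()
    return [
        {k: (v[i] if k in batched else v) for k, v in batch.items()}
        for i in range(size)
    ]


batch = {
    "pred": [[0.1, 0.9], [0.8, 0.2]],  # batch-first: 2 samples x 2 channels
    "label": [[1], [0]],  # one channel-first class index per sample
    "loss": 0.35,  # a scalar shared by the whole batch
}
items = toy_decollate_dict(batch)
```

After each dictionary has been postprocessed, from_engine(["pred", "label"]) gathers the values back per key, as in the v0.6 example.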

For a more detailed tutorial on decollate_batch, see the decollate_batch tutorial.

Adopt the new metric APIs, which automatically support data parallelism

  1. The metrics accept both a batch-first Tensor and a list of channel-first Tensors as input.
  2. The metrics support data parallelism in multi-GPU and multi-node cases.
  3. Example code for validation during training of a segmentation task:

A typical v0.5 code example:

from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean")

metric_sum = 0.0
metric_count = 0
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = val_post_tran(sliding_window_inference(images, (96, 96, 96), 4, model))
    value, not_nans = dice_metric(y_pred=preds, y=labels)
    metric_count += not_nans.item()
    metric_sum += value.item() * not_nans.item()
metric = metric_sum / metric_count

The corresponding v0.6 code can be:

from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric

dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = sliding_window_inference(images, (96, 96, 96), 4, model)
    # decollate predictions into a list and execute postprocessing for every item;
    # val_post_tran must now operate on channel-first items
    preds = [val_post_tran(i) for i in decollate_batch(preds)]
    labels = decollate_batch(labels)
    # compute metric for current iteration
    dice_metric(y_pred=preds, y=labels)

# aggregate and compute the final result of metric
metric = dice_metric.aggregate().item()
dice_metric.reset()
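Conceptually, the v0.6 metric object takes over the running sum and count that the v0.5 loop maintained by hand. The toy, pure-Python class below is a hypothetical sketch of that call/aggregate/reset pattern (`ToyCumulativeDice` and `toy_dice` are not MONAI classes). It buffers per-item Dice scores across iterations and averages the non-NaN ones, which reproduces the arithmetic of the v0.5 metric_sum/metric_count loop. It does not attempt the cross-process synchronization that the real metrics use for data parallelism.

```python
import math
from typing import List, Sequence


def toy_dice(pred: Sequence[int], label: Sequence[int]) -> float:
    """Dice score of two flat binary masks; NaN when both masks are empty."""
    inter = sum(p * y for p, y in zip(pred, label))
    denom = sum(pred) + sum(label)
    return float("nan") if denom == 0 else 2.0 * inter / denom


class ToyCumulativeDice:
    """Toy sketch: call once per iteration, then aggregate() and reset()."""

    def __init__(self) -> None:
        self._buffer: List[float] = []

    def __call__(self, y_pred: List[Sequence[int]], y: List[Sequence[int]]) -> List[float]:
        scores = [toy_dice(p, t) for p, t in zip(y_pred, y)]
        self._buffer.extend(scores)  # keep per-item scores across iterations
        return scores

    def aggregate(self) -> float:
        valid = [s for s in self._buffer if not math.isnan(s)]  # skip NaN items
        return sum(valid) / len(valid) if valid else float("nan")

    def reset(self) -> None:
        self._buffer.clear()


metric = ToyCumulativeDice()
for preds, labels in [
    ([[1, 1, 0, 0]], [[1, 0, 0, 0]]),  # iteration 1: one item, Dice = 2/3
    ([[0, 0, 0, 0], [1, 1, 1, 1]], [[0, 0, 0, 0], [1, 1, 1, 1]]),  # NaN, then 1.0
]:
    metric(y_pred=preds, y=labels)
result = metric.aggregate()  # mean of 2/3 and 1.0; the NaN item is skipped
metric.reset()
```

Keeping the state inside the metric object is what lets the loop body shrink to a single call per iteration.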

Update the batch_transform or output_transform of several event handlers

In v0.5, the batch_transform and output_transform args can be written as:

from monai.handlers import SegmentationSaver, StatsHandler

handlers = [
    StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
    SegmentationSaver(
        output_dir=root_dir,
        batch_transform=lambda batch: batch["image_meta_dict"],
        output_transform=lambda output: output["pred"],
    ),
]

The corresponding v0.6 code can be:

from monai.handlers import SegmentationSaver, StatsHandler, from_engine

handlers = [
    StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
    SegmentationSaver(
        output_dir=root_dir,
        batch_transform=from_engine("image_meta_dict"),
        output_transform=from_engine("pred"),
    ),
]
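The extraction that from_engine performs on a decollated engine output can be sketched in pure Python. `toy_from_engine` below is a hypothetical stand-in, not MONAI's implementation, and it only covers the list-of-dictionaries case. It assumes, as step 4 implies, that a scalar such as the loss appears in every decollated item, so first=True returns a single copy.

```python
from typing import Any, Callable, Dict, List, Sequence, Union


def toy_from_engine(
    keys: Union[str, Sequence[str]], first: bool = False
) -> Callable[[List[Dict[str, Any]]], Any]:
    """Toy stand-in for from_engine: pull `keys` out of a list of dicts."""

    def _extract(data: List[Dict[str, Any]]) -> Any:
        if isinstance(keys, str):
            values = [d[keys] for d in data]
            # first=True: every item holds the same scalar, so keep one copy
            return values[0] if first else values
        per_key = [[d[k] for d in data] for k in keys]
        return tuple(v[0] for v in per_key) if first else tuple(per_key)

    return _extract


# a decollated engine output: one dict per sample, loss repeated in each
output = [
    {"pred": [0.1, 0.9], "label": [1], "loss": 0.35},
    {"pred": [0.8, 0.2], "label": [0], "loss": 0.35},
]
loss = toy_from_engine("loss", first=True)(output)
pred, label = toy_from_engine(["pred", "label"])(output)
```

Returning a callable keeps the handler args declarative: the same extractor can be passed as output_transform or batch_transform without writing a lambda per key.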

Rename post_transform args to postprocessing

In v0.6, some args named post_transform were renamed to postprocessing, for example, the arg of SupervisedTrainer:

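from ignite.metrics import Accuracy
from monai.engines import SupervisedTrainer
from monai.handlers import from_engine
from monai.inferers import SimpleInferer

trainer = SupervisedTrainer(
    device=device,
    max_epochs=5,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=train_postprocessing,  # this arg was named post_transform in v0.5
    key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
    train_handlers=train_handlers,
)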

DynUNet

In v0.6, DynUNet has been updated; the previous version is still available:

from monai.networks.nets.dynunet_v1 import DynUNetV1 as DynUNet

dynunet_v1 will be removed in a future release.
