v0.5 to v0.6 migration guide
In MONAI v0.6, we enhanced the design of metrics and postprocessing transforms to provide more flexible and advanced features, which also introduced some slightly breaking changes.
Check What's new in 0.6 for more details of the new features.
To help users smoothly migrate the existing code from MONAI v0.5 to v0.6, this document shows the detailed steps with example code.
- After the model forward pass and the loss backward pass, execute `decollate_batch` to convert the batch Tensor into a list of Tensors, so that postprocessing transforms can be applied independently to every item in the batch.
- All MONAI postprocessing transforms now handle `channel-first` Tensors instead of `batch-first` Tensors, so the preprocessing and postprocessing transforms handle the same data shape. Simply execute the postprocessing transforms on every item of the list.
- All postprocessing transforms expect Tensor input. To ensure the data after `decollate_batch` is a Tensor, we suggest adding the `EnsureType` or `EnsureTyped` transform.
- Use the `from_engine()` utility to extract the expected data from the decollated list; set `first=True` for scalar values.
- Code examples:
(1) If you are using the array-based postprocessing transforms, a v0.5 classification program can be:
pred_trans = Activations(softmax=True)
label_trans = AsDiscrete(to_onehot=True, n_classes=5)
pred = model(image)
pred = pred_trans(pred)
label = label_trans(label)
metric(y_pred=pred, y=label)
And the corresponding code of v0.6 can be:
from monai.data import decollate_batch
pred_trans = Compose([EnsureType(), Activations(softmax=True)])
label_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=5)])
pred = model(image)
pred = [pred_trans(i) for i in decollate_batch(pred)]
label = [label_trans(i) for i in decollate_batch(label)]
metric(y_pred=pred, y=label)
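To make the batch-first versus channel-first distinction concrete, here is a minimal pure-Python sketch of the pattern above. It is not MONAI code: `decollate_sketch` and `softmax` are hypothetical stand-ins that work on nested lists instead of Tensors, and the logits are made-up values.

```python
import math


def decollate_sketch(batch):
    """Split a batch-first nested list of shape [B, C] into B channel-first items."""
    return [list(item) for item in batch]


def softmax(channels):
    """Softmax over the channel axis of a single channel-first item."""
    peak = max(channels)  # subtract the max for numerical stability
    exps = [math.exp(c - peak) for c in channels]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a batch of 2 samples with 3 classes each.
batch_logits = [[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]]

# v0.6 style: decollate first, then postprocess every item on its own.
probs = [softmax(item) for item in decollate_sketch(batch_logits)]
```

Each entry of `probs` has no batch dimension any more, which is why the postprocessing step can treat index 0 as the channel axis.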
(2) If you are using the dictionary-based postprocessing transforms, a v0.5 classification program can be:
# dictionary-based transforms carry the `d` suffix and take `keys`
postprocessing = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])
data["pred"] = model(data["image"])
data = postprocessing(data)
metric(y_pred=data["pred"], y=data["label"])
And the corresponding code of v0.6 can be:
from monai.data import decollate_batch
from monai.handlers import from_engine
# dictionary-based transforms carry the `d` suffix and take `keys`
postprocessing = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])
data["pred"] = model(data["image"])
# decollate data into a list of dictionaries
data = [postprocessing(i) for i in decollate_batch(data)]
# extract the `pred` and `label` to compute metric
pred, label = from_engine(["pred", "label"])(data)
metric(y_pred=pred, y=label)
For a more detailed tutorial on `decollate_batch`, please check the decollate_batch tutorial.
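The dictionary case can be pictured with a short pure-Python sketch. This is only an illustration of the idea, not MONAI's implementation: `decollate_dict_sketch` and `argmax_pred` are hypothetical helpers, and plain lists stand in for Tensors.

```python
def decollate_dict_sketch(batch):
    """Turn a dict of equally long lists into a list of per-sample dicts."""
    sizes = {len(values) for values in batch.values()}
    if len(sizes) != 1:
        raise ValueError("all entries must share the same batch size")
    batch_size = sizes.pop()
    return [{key: values[i] for key, values in batch.items()} for i in range(batch_size)]


def argmax_pred(item):
    """Hypothetical per-item postprocessing: replace `pred` by its argmax index."""
    out = dict(item)
    out["pred"] = max(range(len(item["pred"])), key=item["pred"].__getitem__)
    return out


# Made-up batch of 2 samples: class probabilities and integer labels.
batch = {"pred": [[0.9, 0.1], [0.2, 0.8]], "label": [0, 1]}

# One dictionary per sample, each postprocessed independently.
items = [argmax_pred(i) for i in decollate_dict_sketch(batch)]
```

After this step every element of `items` is a self-contained dictionary for one sample, which is the shape the dictionary-based postprocessing transforms operate on.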
- Metrics now support both `batch-first` Tensors and lists of `channel-first` Tensors as input.
- Metrics now support data parallel in multi-GPU or multi-node cases.
- Example code of validation during the training of a segmentation task:
A typical code example of v0.5:
dice_metric = DiceMetric(include_background=True, reduction="mean")
metric_sum = 0.0
metric_count = 0
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = val_post_tran(sliding_window_inference(images, (96, 96, 96), 4, model))
    value, not_nans = dice_metric(y_pred=preds, y=labels)
    metric_count += not_nans.item()
    metric_sum += value.item() * not_nans.item()
metric = metric_sum / metric_count
And the corresponding code of v0.6 can be:
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = sliding_window_inference(images, (96, 96, 96), 4, model)
    # decollate prediction into a list and execute postprocessing for every item
    preds = [postprocessing(i) for i in decollate_batch(preds)]
    # compute metric for current iteration
    dice_metric(y_pred=preds, y=labels)
# aggregate and compute the final result of metric
metric = dice_metric.aggregate().item()
dice_metric.reset()
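The call / `aggregate()` / `reset()` pattern replaces the manual `metric_sum` and `metric_count` bookkeeping of v0.5. The following toy class sketches that life cycle in plain Python. `CumulativeMeanSketch` is a hypothetical illustration, not the MONAI metric class, and it scores simple equality rather than Dice.

```python
class CumulativeMeanSketch:
    """Toy metric that buffers per-item scores until aggregate() is called."""

    def __init__(self):
        self._buffer = []

    def __call__(self, y_pred, y):
        # Score each (prediction, label) pair: 1.0 if equal, else 0.0.
        scores = [1.0 if p == t else 0.0 for p, t in zip(y_pred, y)]
        self._buffer.extend(scores)
        return scores

    def aggregate(self):
        # Reduce everything buffered so far to a single mean value.
        if not self._buffer:
            raise ValueError("no data buffered; call the metric first")
        return sum(self._buffer) / len(self._buffer)

    def reset(self):
        # Clear the buffer so the next epoch starts from scratch.
        self._buffer.clear()


metric = CumulativeMeanSketch()
# Two made-up validation iterations, each a (predictions, labels) pair.
iterations = [([1, 0], [1, 1]), ([2, 2], [2, 2])]
for preds, labels in iterations:
    metric(y_pred=preds, y=labels)
epoch_score = metric.aggregate()  # 3 matches out of 4 items
metric.reset()
```

Forgetting `reset()` in a setup like this would carry scores over from one validation round to the next, so it belongs right after `aggregate()`.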
The `batch_transform` and `output_transform` args of v0.5 can be:
StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
SegmentationSaver(
    output_dir=root_dir,
    batch_transform=lambda batch: batch["image_meta_dict"],
    output_transform=lambda output: output["pred"],
)
And the corresponding code of v0.6 can be:
from monai.handlers import from_engine
StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
SegmentationSaver(
    output_dir=root_dir,
    batch_transform=from_engine("image_meta_dict"),
    output_transform=from_engine("pred"),
),
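The reason a plain `lambda x: x["loss"]` no longer fits is that the data handed to the handler is a decollated list of dictionaries rather than a single dictionary. The sketch below imitates the extraction behavior described in this guide. `from_engine_sketch` is a simplified, hypothetical stand-in written for illustration, so the real utility may differ in details, and the values in `output` are made up.

```python
def from_engine_sketch(keys, first=False):
    """Build a callable that pulls `keys` out of a dict or a list of dicts."""
    if isinstance(keys, str):
        keys = (keys,)

    def _extract(data):
        if isinstance(data, dict):
            values = [data[k] for k in keys]
        elif first:
            # Scalar-like values are identical in every item: take the first.
            values = [data[0][k] for k in keys]
        else:
            # Collect one list per key across all decollated items.
            values = [[item[k] for item in data] for k in keys]
        return tuple(values) if len(values) > 1 else values[0]

    return _extract


# Hypothetical decollated output: one dictionary per sample in the batch.
output = [
    {"pred": "p0", "label": "l0", "loss": 0.25},
    {"pred": "p1", "label": "l1", "loss": 0.25},
]

loss = from_engine_sketch("loss", first=True)(output)
pred, label = from_engine_sketch(["pred", "label"])(output)
```

With `first=True` the scalar loss comes back as a single value instead of a list of identical copies, which is what a logging handler such as `StatsHandler` needs.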
Some args named `post_transform` were renamed to `postprocessing` in v0.6, for example, the arg of `SupervisedTrainer`:
trainer = SupervisedTrainer(
    device=device,
    max_epochs=5,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=train_postprocessing,
    key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
    train_handlers=train_handlers,
)
In v0.6, `DynUNet` has been updated; the previous version is still available:
from monai.networks.nets.dynunet_v1 import DynUNetV1 as DynUNet
`dynunet_v1` will be removed in a future release.