Skip to content

Commit b105849

Browse files
authored
Update distributed training for MetaTensor (Project-MONAI#777)
* [DLMED] update brats ddp Signed-off-by: Nic Ma <[email protected]> * [DLMED] restore dist for compatible Signed-off-by: Nic Ma <[email protected]> * [DLMED] update ddp examples Signed-off-by: Nic Ma <[email protected]> * [DLMED] update ignite workflows Signed-off-by: Nic Ma <[email protected]> * [DLMED] update smart cache ddp Signed-off-by: Nic Ma <[email protected]> * [DLMED] update horovod ddp Signed-off-by: Nic Ma <[email protected]>
1 parent b429c88 commit b105849

8 files changed

+15
-34
lines changed

Diff for: acceleration/distributed_training/brats_training_ddp.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
"""
5858

5959
import argparse
60-
import numpy as np
6160
import os
6261
import sys
6362
import time
@@ -89,8 +88,6 @@
8988
RandSpatialCropd,
9089
Spacingd,
9190
ToDeviced,
92-
EnsureTyped,
93-
EnsureType,
9491
)
9592
from monai.utils import set_determinism
9693

@@ -110,16 +107,16 @@ def __call__(self, data):
110107
for key in self.keys:
111108
result = []
112109
# merge label 2 and label 3 to construct TC
113-
result.append(np.logical_or(d[key] == 2, d[key] == 3))
110+
result.append(torch.logical_or(d[key] == 2, d[key] == 3))
114111
# merge labels 1, 2 and 3 to construct WT
115112
result.append(
116-
np.logical_or(
117-
np.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
113+
torch.logical_or(
114+
torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1
118115
)
119116
)
120117
# label 2 is ET
121118
result.append(d[key] == 2)
122-
d[key] = np.stack(result, axis=0).astype(np.float32)
119+
d[key] = torch.stack(result, dim=0)
123120
return d
124121

125122

@@ -132,7 +129,7 @@ def __init__(
132129
self,
133130
root_dir,
134131
section,
135-
transform=LoadImaged(["image", "label"]),
132+
transform=None,
136133
cache_rate=1.0,
137134
num_workers=0,
138135
shuffle=False,
@@ -187,6 +184,7 @@ def main_worker(args):
187184
[
188185
# load 4 Nifti images and stack them together
189186
LoadImaged(keys=["image", "label"]),
187+
ToDeviced(keys=["image", "label"], device=device),
190188
EnsureChannelFirstd(keys="image"),
191189
ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
192190
Orientationd(keys=["image", "label"], axcodes="RAS"),
@@ -195,8 +193,6 @@ def main_worker(args):
195193
pixdim=(1.0, 1.0, 1.0),
196194
mode=("bilinear", "nearest"),
197195
),
198-
EnsureTyped(keys=["image", "label"]),
199-
ToDeviced(keys=["image", "label"], device=device),
200196
RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
201197
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
202198
RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
@@ -223,6 +219,7 @@ def main_worker(args):
223219
val_transforms = Compose(
224220
[
225221
LoadImaged(keys=["image", "label"]),
222+
ToDeviced(keys=["image", "label"], device=device),
226223
EnsureChannelFirstd(keys="image"),
227224
ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
228225
Orientationd(keys=["image", "label"], axcodes="RAS"),
@@ -232,8 +229,6 @@ def main_worker(args):
232229
mode=("bilinear", "nearest"),
233230
),
234231
NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
235-
EnsureTyped(keys=["image", "label"]),
236-
ToDeviced(keys=["image", "label"], device=device),
237232
]
238233
)
239234
val_ds = BratsCacheDataset(
@@ -283,7 +278,7 @@ def main_worker(args):
283278
dice_metric = DiceMetric(include_background=True, reduction="mean")
284279
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")
285280

286-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
281+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
287282

288283
# start a typical PyTorch training
289284
best_metric = -1

Diff for: acceleration/distributed_training/unet_evaluation_ddp.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from monai.data import DataLoader, Dataset, create_test_image_3d, DistributedSampler, decollate_batch
6363
from monai.inferers import sliding_window_inference
6464
from monai.metrics import DiceMetric
65-
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, ScaleIntensityd, EnsureTyped, EnsureType
65+
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, ScaleIntensityd
6666

6767

6868
def evaluate(args):
@@ -92,7 +92,6 @@ def evaluate(args):
9292
LoadImaged(keys=["img", "seg"]),
9393
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
9494
ScaleIntensityd(keys="img"),
95-
EnsureTyped(keys=["img", "seg"]),
9695
]
9796
)
9897

@@ -103,7 +102,7 @@ def evaluate(args):
103102
# sliding window inference need to input 1 image in every iteration
104103
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True, sampler=val_sampler)
105104
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
106-
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
105+
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
107106
# create UNet, DiceLoss and Adam optimizer
108107
device = torch.device(f"cuda:{args.local_rank}")
109108
torch.cuda.set_device(device)

Diff for: acceleration/distributed_training/unet_evaluation_horovod.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
Example script to execute this program, only need to run on the master node:
3636
`horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python unet_evaluation_horovod.py -d "./testdata"`
3737
38-
This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3], [horovod 0.19.5].
38+
This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3], [horovod 0.25.0].
3939
4040
Referring to: https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
4141
@@ -56,7 +56,7 @@
5656
from monai.data import DataLoader, Dataset, create_test_image_3d, decollate_batch
5757
from monai.inferers import sliding_window_inference
5858
from monai.metrics import DiceMetric
59-
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, ScaleIntensityd, EnsureTyped, EnsureType
59+
from monai.transforms import Activations, AsChannelFirstd, AsDiscrete, Compose, LoadImaged, ScaleIntensityd, EnsureType
6060

6161

6262
def evaluate(args):
@@ -88,7 +88,6 @@ def evaluate(args):
8888
LoadImaged(keys=["img", "seg"]),
8989
AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
9090
ScaleIntensityd(keys="img"),
91-
EnsureTyped(keys=["img", "seg"]),
9291
]
9392
)
9493

@@ -156,7 +155,7 @@ def main():
156155
evaluate(args=args)
157156

158157

159-
# Example script to execute this program only on the master node:
158+
# Example script to execute this program on 4 nodes (only need to run below command on the master node):
160159
# horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python unet_evaluation_horovod.py -d "./testdata"
161160
if __name__ == "__main__":
162161
main()

Diff for: acceleration/distributed_training/unet_evaluation_workflows.py

-2
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979
KeepLargestConnectedComponentd,
8080
LoadImaged,
8181
ScaleIntensityd,
82-
EnsureTyped,
8382
SaveImaged,
8483
)
8584

@@ -113,7 +112,6 @@ def evaluate(args):
113112
LoadImaged(keys=["image", "label"]),
114113
AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
115114
ScaleIntensityd(keys="image"),
116-
EnsureTyped(keys=["image", "label"]),
117115
]
118116
)
119117

Diff for: acceleration/distributed_training/unet_training_ddp.py

-2
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
RandCropByPosNegLabeld,
6868
RandRotate90d,
6969
ScaleIntensityd,
70-
EnsureTyped,
7170
)
7271

7372

@@ -106,7 +105,6 @@ def train(args):
106105
keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
107106
),
108107
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
109-
EnsureTyped(keys=["img", "seg"]),
110108
]
111109
)
112110

Diff for: acceleration/distributed_training/unet_training_horovod.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
Example script to execute this program, only need to run on the master node:
4040
`horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python unet_training_horovod.py -d "./testdata"`
4141
42-
This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3], [horovod 0.19.5].
42+
This example was tested with [Ubuntu 16.04/20.04], [NCCL 2.6.3], [horovod 0.25.0].
4343
4444
Referring to: https://github.com/horovod/horovod/blob/master/examples/pytorch_mnist.py
4545
@@ -66,7 +66,6 @@
6666
RandCropByPosNegLabeld,
6767
RandRotate90d,
6868
ScaleIntensityd,
69-
EnsureTyped,
7069
)
7170

7271

@@ -106,7 +105,6 @@ def train(args):
106105
keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
107106
),
108107
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
109-
EnsureTyped(keys=["img", "seg"]),
110108
]
111109
)
112110

@@ -188,7 +186,7 @@ def main():
188186
train(args=args)
189187

190188

191-
# Example script to execute this program only on the master node:
189+
# Example script to execute this program on 4 nodes (only need to run below command on the master node):
192190
# horovodrun -np 16 -H server1:4,server2:4,server3:4,server4:4 python unet_training_horovod.py -d "./testdata"
193191
if __name__ == "__main__":
194192
main()

Diff for: acceleration/distributed_training/unet_training_smartcache.py

-2
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
RandCropByPosNegLabeld,
7474
RandRotate90d,
7575
ScaleIntensityd,
76-
EnsureTyped,
7776
)
7877

7978

@@ -112,7 +111,6 @@ def train(args):
112111
keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
113112
),
114113
RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
115-
EnsureTyped(keys=["img", "seg"]),
116114
]
117115
)
118116

Diff for: acceleration/distributed_training/unet_training_workflows.py

-4
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@
8181
RandCropByPosNegLabeld,
8282
RandRotate90d,
8383
ScaleIntensityd,
84-
EnsureTyped,
8584
)
8685

8786

@@ -118,7 +117,6 @@ def train(args):
118117
keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
119118
),
120119
RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
121-
EnsureTyped(keys=["image", "label"]),
122120
]
123121
)
124122

@@ -155,7 +153,6 @@ def train(args):
155153

156154
train_post_transforms = Compose(
157155
[
158-
EnsureTyped(keys="pred"),
159156
Activationsd(keys="pred", sigmoid=True),
160157
AsDiscreted(keys="pred", threshold=0.5),
161158
KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
@@ -198,7 +195,6 @@ def main():
198195
train(args=args)
199196

200197

201-
202198
# python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_PER_NODE
203199
# --nnodes=NUM_NODES --node_rank=INDEX_CURRENT_NODE
204200
# --master_addr="192.168.1.1" --master_port=1234

0 commit comments

Comments
 (0)