57
57
"""
58
58
59
59
import argparse
60
- import numpy as np
61
60
import os
62
61
import sys
63
62
import time
89
88
RandSpatialCropd ,
90
89
Spacingd ,
91
90
ToDeviced ,
92
- EnsureTyped ,
93
- EnsureType ,
94
91
)
95
92
from monai .utils import set_determinism
96
93
@@ -110,16 +107,16 @@ def __call__(self, data):
110
107
for key in self .keys :
111
108
result = []
112
109
# merge label 2 and label 3 to construct TC
113
- result .append (np .logical_or (d [key ] == 2 , d [key ] == 3 ))
110
+ result .append (torch .logical_or (d [key ] == 2 , d [key ] == 3 ))
114
111
# merge labels 1, 2 and 3 to construct WT
115
112
result .append (
116
- np .logical_or (
117
- np .logical_or (d [key ] == 2 , d [key ] == 3 ), d [key ] == 1
113
+ torch .logical_or (
114
+ torch .logical_or (d [key ] == 2 , d [key ] == 3 ), d [key ] == 1
118
115
)
119
116
)
120
117
# label 2 is ET
121
118
result .append (d [key ] == 2 )
122
- d [key ] = np .stack (result , axis = 0 ). astype ( np . float32 )
119
+ d [key ] = torch .stack (result , dim = 0 )
123
120
return d
124
121
125
122
@@ -132,7 +129,7 @@ def __init__(
132
129
self ,
133
130
root_dir ,
134
131
section ,
135
- transform = LoadImaged ([ "image" , "label" ]) ,
132
+ transform = None ,
136
133
cache_rate = 1.0 ,
137
134
num_workers = 0 ,
138
135
shuffle = False ,
@@ -187,6 +184,7 @@ def main_worker(args):
187
184
[
188
185
# load 4 Nifti images and stack them together
189
186
LoadImaged (keys = ["image" , "label" ]),
187
+ ToDeviced (keys = ["image" , "label" ], device = device ),
190
188
EnsureChannelFirstd (keys = "image" ),
191
189
ConvertToMultiChannelBasedOnBratsClassesd (keys = "label" ),
192
190
Orientationd (keys = ["image" , "label" ], axcodes = "RAS" ),
@@ -195,8 +193,6 @@ def main_worker(args):
195
193
pixdim = (1.0 , 1.0 , 1.0 ),
196
194
mode = ("bilinear" , "nearest" ),
197
195
),
198
- EnsureTyped (keys = ["image" , "label" ]),
199
- ToDeviced (keys = ["image" , "label" ], device = device ),
200
196
RandSpatialCropd (keys = ["image" , "label" ], roi_size = [224 , 224 , 144 ], random_size = False ),
201
197
RandFlipd (keys = ["image" , "label" ], prob = 0.5 , spatial_axis = 0 ),
202
198
RandFlipd (keys = ["image" , "label" ], prob = 0.5 , spatial_axis = 1 ),
@@ -223,6 +219,7 @@ def main_worker(args):
223
219
val_transforms = Compose (
224
220
[
225
221
LoadImaged (keys = ["image" , "label" ]),
222
+ ToDeviced (keys = ["image" , "label" ], device = device ),
226
223
EnsureChannelFirstd (keys = "image" ),
227
224
ConvertToMultiChannelBasedOnBratsClassesd (keys = "label" ),
228
225
Orientationd (keys = ["image" , "label" ], axcodes = "RAS" ),
@@ -232,8 +229,6 @@ def main_worker(args):
232
229
mode = ("bilinear" , "nearest" ),
233
230
),
234
231
NormalizeIntensityd (keys = "image" , nonzero = True , channel_wise = True ),
235
- EnsureTyped (keys = ["image" , "label" ]),
236
- ToDeviced (keys = ["image" , "label" ], device = device ),
237
232
]
238
233
)
239
234
val_ds = BratsCacheDataset (
@@ -283,7 +278,7 @@ def main_worker(args):
283
278
dice_metric = DiceMetric (include_background = True , reduction = "mean" )
284
279
dice_metric_batch = DiceMetric (include_background = True , reduction = "mean_batch" )
285
280
286
- post_trans = Compose ([EnsureType (), Activations (sigmoid = True ), AsDiscrete (threshold = 0.5 )])
281
+ post_trans = Compose ([Activations (sigmoid = True ), AsDiscrete (threshold = 0.5 )])
287
282
288
283
# start a typical PyTorch training
289
284
best_metric = - 1
0 commit comments