24
24
RandRotated ,
25
25
RandScaleIntensityd ,
26
26
RandShiftIntensityd ,
27
+ RandCropByPosNegLabeld ,
28
+ RandZoomd ,
29
+ RandFlipd ,
30
+ RandRotate90d ,
31
+ MapTransform ,
27
32
)
33
+ from monai .transforms .utility .dictionary import ApplyTransformToPointsd
34
+ from monai .transforms .spatial .dictionary import ConvertBoxToPointsd , ConvertPointsToBoxesd
28
35
from monai .apps .detection .transforms .dictionary import (
29
36
AffineBoxToImageCoordinated ,
30
37
AffineBoxToWorldCoordinated ,
31
38
BoxToMaskd ,
32
39
ClipBoxToImaged ,
33
40
ConvertBoxToStandardModed ,
34
41
MaskToBoxd ,
35
- RandCropBoxByPosNegLabeld ,
36
- RandFlipBoxd ,
37
- RandRotateBox90d ,
38
- RandZoomBoxd ,
39
42
ConvertBoxModed ,
40
43
StandardizeEmptyBoxd ,
41
44
)
45
+ from monai .config import KeysCollection
46
+ from monai .utils .type_conversion import convert_data_type
47
+ from monai .data .box_utils import clip_boxes_to_image
48
+ from monai .apps .detection .transforms .box_ops import convert_box_to_mask
42
49
43
50
44
51
def generate_detection_train_transform (
@@ -49,6 +56,7 @@ def generate_detection_train_transform(
49
56
intensity_transform ,
50
57
patch_size ,
51
58
batch_size ,
59
+ point_key = "points" ,
52
60
affine_lps_to_ras = False ,
53
61
amp = True ,
54
62
):
@@ -59,6 +67,7 @@ def generate_detection_train_transform(
59
67
image_key: the key to represent images in the input json files
60
68
box_key: the key to represent boxes in the input json files
61
69
label_key: the key to represent box labels in the input json files
70
+ point_key: the key to represent points to save the box coordinates
62
71
gt_box_mode: ground truth box mode in the input json files
63
72
intensity_transform: transform to scale image intensities,
64
73
usually ScaleIntensityRanged for CT images, and NormalizeIntensityd for MR images.
@@ -87,67 +96,67 @@ def generate_detection_train_transform(
87
96
intensity_transform ,
88
97
EnsureTyped (keys = [image_key ], dtype = torch .float16 ),
89
98
ConvertBoxToStandardModed (box_keys = [box_key ], mode = gt_box_mode ),
99
+ ConvertBoxToPointsd (keys = [box_key ]),
90
100
AffineBoxToImageCoordinated (
91
101
box_keys = [box_key ],
92
102
box_ref_image_keys = image_key ,
93
103
image_meta_key_postfix = "meta_dict" ,
94
104
affine_lps_to_ras = affine_lps_to_ras ,
95
105
),
96
- RandCropBoxByPosNegLabeld (
97
- image_keys = [ image_key ],
98
- box_keys = box_key ,
99
- label_keys = label_key ,
106
+ # generate box mask based on the input boxes which used for cropping
107
+ GenerateExtendedBoxMask (
108
+ keys = box_key ,
109
+ image_key = image_key ,
100
110
spatial_size = patch_size ,
101
111
whole_box = True ,
112
+ ),
113
+ RandCropByPosNegLabeld (
114
+ keys = [image_key ],
115
+ label_key = "mask_image" ,
116
+ spatial_size = patch_size ,
102
117
num_samples = batch_size ,
103
118
pos = 1 ,
104
119
neg = 1 ,
105
120
),
106
- RandZoomBoxd (
107
- image_keys = [image_key ],
108
- box_keys = [box_key ],
109
- box_ref_image_keys = [image_key ],
121
+ RandZoomd (
122
+ keys = [image_key ],
110
123
prob = 0.2 ,
111
124
min_zoom = 0.7 ,
112
125
max_zoom = 1.4 ,
113
126
padding_mode = "constant" ,
114
127
keep_size = True ,
115
128
),
116
- ClipBoxToImaged (
117
- box_keys = box_key ,
118
- label_keys = [label_key ],
119
- box_ref_image_keys = image_key ,
120
- remove_empty = True ,
121
- ),
122
- RandFlipBoxd (
123
- image_keys = [image_key ],
124
- box_keys = [box_key ],
125
- box_ref_image_keys = [image_key ],
129
+ RandFlipd (
130
+ keys = [image_key ],
126
131
prob = 0.5 ,
127
132
spatial_axis = 0 ,
128
133
),
129
- RandFlipBoxd (
130
- image_keys = [image_key ],
131
- box_keys = [box_key ],
132
- box_ref_image_keys = [image_key ],
134
+ RandFlipd (
135
+ keys = [image_key ],
133
136
prob = 0.5 ,
134
137
spatial_axis = 1 ,
135
138
),
136
- RandFlipBoxd (
137
- image_keys = [image_key ],
138
- box_keys = [box_key ],
139
- box_ref_image_keys = [image_key ],
139
+ RandFlipd (
140
+ keys = [image_key ],
140
141
prob = 0.5 ,
141
142
spatial_axis = 2 ,
142
143
),
143
- RandRotateBox90d (
144
- image_keys = [image_key ],
145
- box_keys = [box_key ],
146
- box_ref_image_keys = [image_key ],
144
+ RandRotate90d (
145
+ keys = [image_key ],
147
146
prob = 0.75 ,
148
147
max_k = 3 ,
149
148
spatial_axes = (0 , 1 ),
150
149
),
150
+ # apply the same affine matrix which already applied on the images to the points
151
+ ApplyTransformToPointsd (keys = [point_key ], refer_key = image_key , affine_lps_to_ras = affine_lps_to_ras ),
152
+ # convert points back to boxes
153
+ ConvertPointsToBoxesd (keys = [point_key ]),
154
+ ClipBoxToImaged (
155
+ box_keys = box_key ,
156
+ label_keys = [label_key ],
157
+ box_ref_image_keys = image_key ,
158
+ remove_empty = True ,
159
+ ),
151
160
BoxToMaskd (
152
161
box_keys = [box_key ],
153
162
label_keys = [label_key ],
@@ -184,7 +193,7 @@ def generate_detection_train_transform(
184
193
RandScaleIntensityd (keys = [image_key ], prob = 0.15 , factors = 0.25 ),
185
194
RandShiftIntensityd (keys = [image_key ], prob = 0.15 , offsets = 0.1 ),
186
195
RandAdjustContrastd (keys = [image_key ], prob = 0.3 , gamma = (0.7 , 1.5 )),
187
- EnsureTyped (keys = [image_key , box_key ], dtype = compute_dtype ),
196
+ EnsureTyped (keys = [image_key ], dtype = compute_dtype ),
188
197
EnsureTyped (keys = [label_key ], dtype = torch .long ),
189
198
]
190
199
)
@@ -307,3 +316,73 @@ def generate_detection_inference_transform(
307
316
]
308
317
)
309
318
return test_transforms , post_transforms
319
+
320
+
321
+ class GenerateExtendedBoxMask (MapTransform ):
322
+ """
323
+ Generate box mask based on the input boxes.
324
+ """
325
+
326
+ def __init__ (
327
+ self ,
328
+ keys : KeysCollection ,
329
+ image_key : str ,
330
+ spatial_size : tuple [int , int , int ],
331
+ whole_box : bool ,
332
+ mask_image_key : str = "mask_image" ,
333
+ ) -> None :
334
+ """
335
+ Args:
336
+ keys: keys of the corresponding items to be transformed.
337
+ image_key: key for the image data in the dictionary.
338
+ spatial_size: size of the spatial dimensions of the mask.
339
+ whole_box: whether to use the whole box for generating the mask.
340
+ mask_image_key: key to store the generated box mask.
341
+ """
342
+ super ().__init__ (keys )
343
+ self .image_key = image_key
344
+ self .spatial_size = spatial_size
345
+ self .whole_box = whole_box
346
+ self .mask_image_key = mask_image_key
347
+
348
+ def generate_fg_center_boxes_np (self , boxes , image_size , whole_box = True ):
349
+ # We don't require crop center to be within the boxes.
350
+ # As along as the cropped patch contains a box, it is considered as a foreground patch.
351
+ # Positions within extended_boxes are crop centers for foreground patches
352
+ spatial_dims = len (image_size )
353
+ boxes_np , * _ = convert_data_type (boxes , np .ndarray )
354
+
355
+ extended_boxes = np .zeros_like (boxes_np , dtype = int )
356
+ boxes_start = np .ceil (boxes_np [:, :spatial_dims ]).astype (int )
357
+ boxes_stop = np .floor (boxes_np [:, spatial_dims :]).astype (int )
358
+ for axis in range (spatial_dims ):
359
+ if not whole_box :
360
+ extended_boxes [:, axis ] = boxes_start [:, axis ] - self .spatial_size [axis ] // 2 + 1
361
+ extended_boxes [:, axis + spatial_dims ] = boxes_stop [:, axis ] + self .spatial_size [axis ] // 2 - 1
362
+ else :
363
+ # extended box start
364
+ extended_boxes [:, axis ] = boxes_stop [:, axis ] - self .spatial_size [axis ] // 2 - 1
365
+ extended_boxes [:, axis ] = np .minimum (extended_boxes [:, axis ], boxes_start [:, axis ])
366
+ # extended box stop
367
+ extended_boxes [:, axis + spatial_dims ] = extended_boxes [:, axis ] + self .spatial_size [axis ] // 2
368
+ extended_boxes [:, axis + spatial_dims ] = np .maximum (
369
+ extended_boxes [:, axis + spatial_dims ], boxes_stop [:, axis ]
370
+ )
371
+ extended_boxes , _ = clip_boxes_to_image (extended_boxes , image_size , remove_empty = True ) # type: ignore
372
+ return extended_boxes
373
+
374
+ def generate_mask_img (self , boxes , image_size , whole_box = True ):
375
+ extended_boxes_np = self .generate_fg_center_boxes_np (boxes , image_size , whole_box )
376
+ mask_img = convert_box_to_mask (
377
+ extended_boxes_np , np .ones (extended_boxes_np .shape [0 ]), image_size , bg_label = 0 , ellipse_mask = False
378
+ )
379
+ mask_img = np .amax (mask_img , axis = 0 , keepdims = True )[0 :1 , ...]
380
+ return mask_img
381
+
382
+ def __call__ (self , data ):
383
+ d = dict (data )
384
+ for key in self .key_iterator (d ):
385
+ image = d [self .image_key ]
386
+ boxes = d [key ]
387
+ data [self .mask_image_key ] = self .generate_mask_img (boxes , image .shape [1 :], whole_box = self .whole_box )
388
+ return data
0 commit comments