Commit 0098607

Merge branch 'main' into ddpm_tutorials

2 parents ece5ad5 + cb81ed2

3 files changed (+125 additions, −41 deletions)

detection/generate_transforms.py

Lines changed: 114 additions & 35 deletions
@@ -24,21 +24,28 @@
     RandRotated,
     RandScaleIntensityd,
     RandShiftIntensityd,
+    RandCropByPosNegLabeld,
+    RandZoomd,
+    RandFlipd,
+    RandRotate90d,
+    MapTransform,
 )
+from monai.transforms.utility.dictionary import ApplyTransformToPointsd
+from monai.transforms.spatial.dictionary import ConvertBoxToPointsd, ConvertPointsToBoxesd
 from monai.apps.detection.transforms.dictionary import (
     AffineBoxToImageCoordinated,
     AffineBoxToWorldCoordinated,
     BoxToMaskd,
     ClipBoxToImaged,
     ConvertBoxToStandardModed,
     MaskToBoxd,
-    RandCropBoxByPosNegLabeld,
-    RandFlipBoxd,
-    RandRotateBox90d,
-    RandZoomBoxd,
     ConvertBoxModed,
     StandardizeEmptyBoxd,
 )
+from monai.config import KeysCollection
+from monai.utils.type_conversion import convert_data_type
+from monai.data.box_utils import clip_boxes_to_image
+from monai.apps.detection.transforms.box_ops import convert_box_to_mask
 
 
 def generate_detection_train_transform(
@@ -49,6 +56,7 @@ def generate_detection_train_transform(
     intensity_transform,
     patch_size,
     batch_size,
+    point_key="points",
     affine_lps_to_ras=False,
     amp=True,
 ):
@@ -59,6 +67,7 @@ def generate_detection_train_transform(
         image_key: the key to represent images in the input json files
         box_key: the key to represent boxes in the input json files
         label_key: the key to represent box labels in the input json files
+        point_key: the key used to store the points converted from the box coordinates
         gt_box_mode: ground truth box mode in the input json files
         intensity_transform: transform to scale image intensities,
             usually ScaleIntensityRanged for CT images, and NormalizeIntensityd for MR images.
@@ -87,67 +96,67 @@ def generate_detection_train_transform(
             intensity_transform,
             EnsureTyped(keys=[image_key], dtype=torch.float16),
             ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
+            ConvertBoxToPointsd(keys=[box_key]),
             AffineBoxToImageCoordinated(
                 box_keys=[box_key],
                 box_ref_image_keys=image_key,
                 image_meta_key_postfix="meta_dict",
                 affine_lps_to_ras=affine_lps_to_ras,
             ),
-            RandCropBoxByPosNegLabeld(
-                image_keys=[image_key],
-                box_keys=box_key,
-                label_keys=label_key,
+            # generate a box mask from the input boxes, which is used for cropping
+            GenerateExtendedBoxMask(
+                keys=box_key,
+                image_key=image_key,
                 spatial_size=patch_size,
                 whole_box=True,
+            ),
+            RandCropByPosNegLabeld(
+                keys=[image_key],
+                label_key="mask_image",
+                spatial_size=patch_size,
                 num_samples=batch_size,
                 pos=1,
                 neg=1,
             ),
-            RandZoomBoxd(
-                image_keys=[image_key],
-                box_keys=[box_key],
-                box_ref_image_keys=[image_key],
+            RandZoomd(
+                keys=[image_key],
                 prob=0.2,
                 min_zoom=0.7,
                 max_zoom=1.4,
                 padding_mode="constant",
                 keep_size=True,
             ),
-            ClipBoxToImaged(
-                box_keys=box_key,
-                label_keys=[label_key],
-                box_ref_image_keys=image_key,
-                remove_empty=True,
-            ),
-            RandFlipBoxd(
-                image_keys=[image_key],
-                box_keys=[box_key],
-                box_ref_image_keys=[image_key],
+            RandFlipd(
+                keys=[image_key],
                 prob=0.5,
                 spatial_axis=0,
             ),
-            RandFlipBoxd(
-                image_keys=[image_key],
-                box_keys=[box_key],
-                box_ref_image_keys=[image_key],
+            RandFlipd(
+                keys=[image_key],
                 prob=0.5,
                 spatial_axis=1,
             ),
-            RandFlipBoxd(
-                image_keys=[image_key],
-                box_keys=[box_key],
-                box_ref_image_keys=[image_key],
+            RandFlipd(
+                keys=[image_key],
                 prob=0.5,
                 spatial_axis=2,
             ),
-            RandRotateBox90d(
-                image_keys=[image_key],
-                box_keys=[box_key],
-                box_ref_image_keys=[image_key],
+            RandRotate90d(
+                keys=[image_key],
                 prob=0.75,
                 max_k=3,
                 spatial_axes=(0, 1),
             ),
+            # apply the affine matrix that was already applied to the images to the points
+            ApplyTransformToPointsd(keys=[point_key], refer_key=image_key, affine_lps_to_ras=affine_lps_to_ras),
+            # convert the points back to boxes
+            ConvertPointsToBoxesd(keys=[point_key]),
+            ClipBoxToImaged(
+                box_keys=box_key,
+                label_keys=[label_key],
+                box_ref_image_keys=image_key,
+                remove_empty=True,
+            ),
             BoxToMaskd(
                 box_keys=[box_key],
                 label_keys=[label_key],
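The hunk above is the heart of the commit: the box-aware transforms (RandCropBoxByPosNegLabeld, RandZoomBoxd, RandFlipBoxd, RandRotateBox90d) are replaced by their generic image-only counterparts, and the boxes are kept consistent through a points round-trip. ConvertBoxToPointsd turns each box into its corner points before augmentation, ApplyTransformToPointsd replays the affine the image accumulated during augmentation onto those points, and ConvertPointsToBoxesd reduces the transformed corners back to axis-aligned boxes. The NumPy sketch below illustrates that geometry in isolation; the helper names and the flip affine are ours for illustration, not MONAI's API:

```python
import itertools

import numpy as np


def box_to_corners(box):
    # xyzxyz box -> (8, 3) array of its corner coordinates
    start, stop = box[:3], box[3:]
    corners = [[(stop if bit else start)[axis] for axis, bit in enumerate(bits)]
               for bits in itertools.product([0, 1], repeat=3)]
    return np.array(corners)


def corners_to_box(corners):
    # per-axis min/max -> tightest axis-aligned box around the corners
    return np.concatenate([corners.min(axis=0), corners.max(axis=0)])


box = np.array([2.0, 3.0, 4.0, 10.0, 12.0, 14.0])

# the same flip RandFlipd(spatial_axis=0) applies on a 32-voxel axis,
# written as a homogeneous affine: x' = 31 - x
affine = np.eye(4)
affine[0, 0], affine[0, 3] = -1.0, 31.0

corners = box_to_corners(box)
hom = np.concatenate([corners, np.ones((8, 1))], axis=1)  # homogeneous coordinates
moved = (affine @ hom.T).T[:, :3]

print(corners_to_box(moved))  # [21.  3.  4. 29. 12. 14.]
```

Taking the per-axis min/max after the affine is exact for flips and 90-degree rotations, which only permute or reflect the corners; that is why the pipeline above can defer box reconstruction until after all the random spatial transforms have run.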
@@ -184,7 +193,7 @@ def generate_detection_train_transform(
             RandScaleIntensityd(keys=[image_key], prob=0.15, factors=0.25),
             RandShiftIntensityd(keys=[image_key], prob=0.15, offsets=0.1),
             RandAdjustContrastd(keys=[image_key], prob=0.3, gamma=(0.7, 1.5)),
-            EnsureTyped(keys=[image_key, box_key], dtype=compute_dtype),
+            EnsureTyped(keys=[image_key], dtype=compute_dtype),
             EnsureTyped(keys=[label_key], dtype=torch.long),
         ]
     )
@@ -307,3 +316,73 @@ def generate_detection_inference_transform(
         ]
     )
     return test_transforms, post_transforms
+
+
+class GenerateExtendedBoxMask(MapTransform):
+    """
+    Generate a box mask based on the input boxes.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        image_key: str,
+        spatial_size: tuple[int, int, int],
+        whole_box: bool,
+        mask_image_key: str = "mask_image",
+    ) -> None:
+        """
+        Args:
+            keys: keys of the corresponding items to be transformed.
+            image_key: key for the image data in the dictionary.
+            spatial_size: size of the spatial dimensions of the mask.
+            whole_box: whether to use the whole box for generating the mask.
+            mask_image_key: key under which to store the generated box mask.
+        """
+        super().__init__(keys)
+        self.image_key = image_key
+        self.spatial_size = spatial_size
+        self.whole_box = whole_box
+        self.mask_image_key = mask_image_key
+
+    def generate_fg_center_boxes_np(self, boxes, image_size, whole_box=True):
+        # We don't require the crop center to be within the boxes.
+        # As long as the cropped patch contains a box, it is considered a foreground patch.
+        # Positions within extended_boxes are crop centers for foreground patches.
+        spatial_dims = len(image_size)
+        boxes_np, *_ = convert_data_type(boxes, np.ndarray)
+
+        extended_boxes = np.zeros_like(boxes_np, dtype=int)
+        boxes_start = np.ceil(boxes_np[:, :spatial_dims]).astype(int)
+        boxes_stop = np.floor(boxes_np[:, spatial_dims:]).astype(int)
+        for axis in range(spatial_dims):
+            if not whole_box:
+                extended_boxes[:, axis] = boxes_start[:, axis] - self.spatial_size[axis] // 2 + 1
+                extended_boxes[:, axis + spatial_dims] = boxes_stop[:, axis] + self.spatial_size[axis] // 2 - 1
+            else:
+                # extended box start
+                extended_boxes[:, axis] = boxes_stop[:, axis] - self.spatial_size[axis] // 2 - 1
+                extended_boxes[:, axis] = np.minimum(extended_boxes[:, axis], boxes_start[:, axis])
+                # extended box stop
+                extended_boxes[:, axis + spatial_dims] = extended_boxes[:, axis] + self.spatial_size[axis] // 2
+                extended_boxes[:, axis + spatial_dims] = np.maximum(
+                    extended_boxes[:, axis + spatial_dims], boxes_stop[:, axis]
+                )
+        extended_boxes, _ = clip_boxes_to_image(extended_boxes, image_size, remove_empty=True)  # type: ignore
+        return extended_boxes
+
+    def generate_mask_img(self, boxes, image_size, whole_box=True):
+        extended_boxes_np = self.generate_fg_center_boxes_np(boxes, image_size, whole_box)
+        mask_img = convert_box_to_mask(
+            extended_boxes_np, np.ones(extended_boxes_np.shape[0]), image_size, bg_label=0, ellipse_mask=False
+        )
+        mask_img = np.amax(mask_img, axis=0, keepdims=True)[0:1, ...]
+        return mask_img
+
+    def __call__(self, data):
+        d = dict(data)
+        for key in self.key_iterator(d):
+            image = d[self.image_key]
+            boxes = d[key]
+            d[self.mask_image_key] = self.generate_mask_img(boxes, image.shape[1:], whole_box=self.whole_box)
+        return d
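GenerateExtendedBoxMask is what lets the generic RandCropByPosNegLabeld stand in for the removed RandCropBoxByPosNegLabeld: it rasterizes the boxes, extended per axis by roughly half the patch size, into a binary "mask_image" from which positive crop centers are sampled, so that (with whole_box=True) a patch centered anywhere in the mask contains a whole box. A minimal standalone sketch, assuming the class above is in scope; the keys, shapes, and box values are illustrative:

```python
import numpy as np
import torch
from monai.transforms import Compose, RandCropByPosNegLabeld

patch_size = (16, 16, 16)
crop = Compose(
    [
        # rasterize the (extended) boxes into a binary "mask_image"
        GenerateExtendedBoxMask(
            keys="box",
            image_key="image",
            spatial_size=patch_size,
            whole_box=True,
        ),
        # sample crop centers from that mask; pos=1, neg=1 gives a 1:1 mix of
        # patches that contain a box and patches that do not
        RandCropByPosNegLabeld(
            keys=["image"],
            label_key="mask_image",
            spatial_size=patch_size,
            num_samples=2,
            pos=1,
            neg=1,
        ),
    ]
)

sample = {
    "image": torch.zeros(1, 64, 64, 64),  # channel-first 3D volume
    "box": np.array([[8.0, 8.0, 8.0, 20.0, 20.0, 20.0]]),  # xyzxyz, image coords
}
patches = crop(sample)  # a list of num_samples cropped dictionaries
print(len(patches), tuple(patches[0]["image"].shape))  # 2 (1, 16, 16, 16)
```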

detection/luna16_prepare_images.py

Lines changed: 4 additions & 2 deletions
@@ -47,8 +47,10 @@ def main():
 
     monai.config.print_config()
 
-    env_dict = json.load(open(args.environment_file, "r"))
-    config_dict = json.load(open(args.config_file, "r"))
+    with open(args.environment_file, "r") as env_file:
+        env_dict = json.load(env_file)
+    with open(args.config_file, "r") as config_file:
+        config_dict = json.load(config_file)
 
     for k, v in env_dict.items():
         setattr(args, k, v)

detection/luna16_training.py

Lines changed: 7 additions & 4 deletions
@@ -78,8 +78,10 @@ def main():
     torch.backends.cudnn.benchmark = True
     torch.set_num_threads(4)
 
-    env_dict = json.load(open(args.environment_file, "r"))
-    config_dict = json.load(open(args.config_file, "r"))
+    with open(args.environment_file, "r") as env_file:
+        env_dict = json.load(env_file)
+    with open(args.config_file, "r") as config_file:
+        config_dict = json.load(config_file)
 
     for k, v in env_dict.items():
         setattr(args, k, v)
@@ -103,6 +105,7 @@ def main():
         intensity_transform,
         args.patch_size,
         args.batch_size,
+        point_key="points",
         affine_lps_to_ras=True,
         amp=amp,
     )
@@ -233,7 +236,7 @@ def main():
     )
     after_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
     scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=10, after_scheduler=after_scheduler)
-    scaler = torch.cuda.amp.GradScaler() if amp else None
+    scaler = torch.amp.GradScaler("cuda") if amp else None
     optimizer.zero_grad()
     optimizer.step()
 
@@ -279,7 +282,7 @@ def main():
             param.grad = None
 
         if amp and (scaler is not None):
-            with torch.cuda.amp.autocast():
+            with torch.amp.autocast("cuda"):
                 outputs = detector(inputs, targets)
                 loss = w_cls * outputs[detector.cls_key] + outputs[detector.box_reg_key]
             scaler.scale(loss).backward()
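The last two hunks move off the deprecated torch.cuda.amp namespace onto the device-generic torch.amp API; the calls are otherwise equivalent. For reference, a self-contained mixed-precision step in the new style (the model, data, and loss here are placeholders, not the tutorial's detector):

```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler(device)  # was: torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device=device)
y = torch.randn(4, 1, device=device)

optimizer.zero_grad()
with torch.amp.autocast(device):       # was: torch.cuda.amp.autocast()
    loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()  # scale the loss so fp16 gradients don't underflow
scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
scaler.update()                # adapt the scale factor for the next iteration
```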
