Merge branch 'main' into ddpm_tutorials
KumoLiu authored Sep 6, 2024
2 parents ece5ad5 + cb81ed2 commit 0098607
Showing 3 changed files with 125 additions and 41 deletions.
149 changes: 114 additions & 35 deletions detection/generate_transforms.py
@@ -24,21 +24,28 @@
RandRotated,
RandScaleIntensityd,
RandShiftIntensityd,
RandCropByPosNegLabeld,
RandZoomd,
RandFlipd,
RandRotate90d,
MapTransform,
)
from monai.transforms.utility.dictionary import ApplyTransformToPointsd
from monai.transforms.spatial.dictionary import ConvertBoxToPointsd, ConvertPointsToBoxesd
from monai.apps.detection.transforms.dictionary import (
AffineBoxToImageCoordinated,
AffineBoxToWorldCoordinated,
BoxToMaskd,
ClipBoxToImaged,
ConvertBoxToStandardModed,
MaskToBoxd,
RandCropBoxByPosNegLabeld,
RandFlipBoxd,
RandRotateBox90d,
RandZoomBoxd,
ConvertBoxModed,
StandardizeEmptyBoxd,
)
from monai.config import KeysCollection
from monai.utils.type_conversion import convert_data_type
from monai.data.box_utils import clip_boxes_to_image
from monai.apps.detection.transforms.box_ops import convert_box_to_mask


def generate_detection_train_transform(
@@ -49,6 +56,7 @@ def generate_detection_train_transform(
intensity_transform,
patch_size,
batch_size,
point_key="points",
affine_lps_to_ras=False,
amp=True,
):
@@ -59,6 +67,7 @@
image_key: the key to represent images in the input json files
box_key: the key to represent boxes in the input json files
label_key: the key to represent box labels in the input json files
point_key: the key used to store the box corner points while the spatial transforms are applied
gt_box_mode: ground truth box mode in the input json files
intensity_transform: transform to scale image intensities,
usually ScaleIntensityRanged for CT images, and NormalizeIntensityd for MR images.
@@ -87,67 +96,67 @@
intensity_transform,
EnsureTyped(keys=[image_key], dtype=torch.float16),
ConvertBoxToStandardModed(box_keys=[box_key], mode=gt_box_mode),
ConvertBoxToPointsd(keys=[box_key]),
AffineBoxToImageCoordinated(
box_keys=[box_key],
box_ref_image_keys=image_key,
image_meta_key_postfix="meta_dict",
affine_lps_to_ras=affine_lps_to_ras,
),
RandCropBoxByPosNegLabeld(
image_keys=[image_key],
box_keys=box_key,
label_keys=label_key,
# generate a box mask from the input boxes, which is used for cropping
GenerateExtendedBoxMask(
keys=box_key,
image_key=image_key,
spatial_size=patch_size,
whole_box=True,
),
RandCropByPosNegLabeld(
keys=[image_key],
label_key="mask_image",
spatial_size=patch_size,
num_samples=batch_size,
pos=1,
neg=1,
),
RandZoomBoxd(
image_keys=[image_key],
box_keys=[box_key],
box_ref_image_keys=[image_key],
RandZoomd(
keys=[image_key],
prob=0.2,
min_zoom=0.7,
max_zoom=1.4,
padding_mode="constant",
keep_size=True,
),
ClipBoxToImaged(
box_keys=box_key,
label_keys=[label_key],
box_ref_image_keys=image_key,
remove_empty=True,
),
RandFlipBoxd(
image_keys=[image_key],
box_keys=[box_key],
box_ref_image_keys=[image_key],
RandFlipd(
keys=[image_key],
prob=0.5,
spatial_axis=0,
),
RandFlipBoxd(
image_keys=[image_key],
box_keys=[box_key],
box_ref_image_keys=[image_key],
RandFlipd(
keys=[image_key],
prob=0.5,
spatial_axis=1,
),
RandFlipBoxd(
image_keys=[image_key],
box_keys=[box_key],
box_ref_image_keys=[image_key],
RandFlipd(
keys=[image_key],
prob=0.5,
spatial_axis=2,
),
RandRotateBox90d(
image_keys=[image_key],
box_keys=[box_key],
box_ref_image_keys=[image_key],
RandRotate90d(
keys=[image_key],
prob=0.75,
max_k=3,
spatial_axes=(0, 1),
),
# apply the same affine matrix that was already applied to the image to the points
ApplyTransformToPointsd(keys=[point_key], refer_key=image_key, affine_lps_to_ras=affine_lps_to_ras),
# convert points back to boxes
ConvertPointsToBoxesd(keys=[point_key]),
ClipBoxToImaged(
box_keys=box_key,
label_keys=[label_key],
box_ref_image_keys=image_key,
remove_empty=True,
),
BoxToMaskd(
box_keys=[box_key],
label_keys=[label_key],
@@ -184,7 +193,7 @@ def generate_detection_train_transform(
RandScaleIntensityd(keys=[image_key], prob=0.15, factors=0.25),
RandShiftIntensityd(keys=[image_key], prob=0.15, offsets=0.1),
RandAdjustContrastd(keys=[image_key], prob=0.3, gamma=(0.7, 1.5)),
EnsureTyped(keys=[image_key, box_key], dtype=compute_dtype),
EnsureTyped(keys=[image_key], dtype=compute_dtype),
EnsureTyped(keys=[label_key], dtype=torch.long),
]
)
@@ -307,3 +316,73 @@ def generate_detection_inference_transform(
]
)
return test_transforms, post_transforms
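
The refactor above round-trips boxes through points: ConvertBoxToPointsd turns each box into its corner points, the stock spatial transforms (RandZoomd, RandFlipd, RandRotate90d) record their affines on the image, ApplyTransformToPointsd replays those affines on the points, and ConvertPointsToBoxesd rebuilds boxes from the transformed points. A minimal usage sketch of the updated factory, assuming LUNA16-style keys; the intensity window, box mode, and patch size are illustrative, not taken from this commit:

from monai.transforms import ScaleIntensityRanged
from generate_transforms import generate_detection_train_transform

# illustrative CT window; adjust a_min/a_max for your data
intensity_transform = ScaleIntensityRanged(
    keys=["image"], a_min=-1024.0, a_max=300.0, b_min=0.0, b_max=1.0, clip=True
)
train_transforms = generate_detection_train_transform(
    image_key="image",
    box_key="box",
    label_key="label",
    gt_box_mode="cccwhd",  # assumed center-size ground-truth box mode
    intensity_transform=intensity_transform,
    patch_size=(192, 192, 80),  # illustrative patch size
    batch_size=4,  # crops sampled per image (num_samples in RandCropByPosNegLabeld)
    point_key="points",
    affine_lps_to_ras=True,
    amp=True,
)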


class GenerateExtendedBoxMask(MapTransform):
"""
Generate a box mask from the input boxes, marking candidate foreground crop centers.
"""

def __init__(
self,
keys: KeysCollection,
image_key: str,
spatial_size: tuple[int, int, int],
whole_box: bool,
mask_image_key: str = "mask_image",
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
image_key: key for the image data in the dictionary.
spatial_size: size of the spatial dimensions of the mask.
whole_box: whether to use the whole box for generating the mask.
mask_image_key: key to store the generated box mask.
"""
super().__init__(keys)
self.image_key = image_key
self.spatial_size = spatial_size
self.whole_box = whole_box
self.mask_image_key = mask_image_key

def generate_fg_center_boxes_np(self, boxes, image_size, whole_box=True):
# We don't require the crop center to be within the boxes.
# As long as the cropped patch contains a box, it is considered a foreground patch.
# Positions within extended_boxes are valid crop centers for foreground patches.
spatial_dims = len(image_size)
boxes_np, *_ = convert_data_type(boxes, np.ndarray)

extended_boxes = np.zeros_like(boxes_np, dtype=int)
boxes_start = np.ceil(boxes_np[:, :spatial_dims]).astype(int)
boxes_stop = np.floor(boxes_np[:, spatial_dims:]).astype(int)
for axis in range(spatial_dims):
if not whole_box:
extended_boxes[:, axis] = boxes_start[:, axis] - self.spatial_size[axis] // 2 + 1
extended_boxes[:, axis + spatial_dims] = boxes_stop[:, axis] + self.spatial_size[axis] // 2 - 1
else:
# extended box start
extended_boxes[:, axis] = boxes_stop[:, axis] - self.spatial_size[axis] // 2 - 1
extended_boxes[:, axis] = np.minimum(extended_boxes[:, axis], boxes_start[:, axis])
# extended box stop
extended_boxes[:, axis + spatial_dims] = extended_boxes[:, axis] + self.spatial_size[axis] // 2
extended_boxes[:, axis + spatial_dims] = np.maximum(
extended_boxes[:, axis + spatial_dims], boxes_stop[:, axis]
)
extended_boxes, _ = clip_boxes_to_image(extended_boxes, image_size, remove_empty=True) # type: ignore
return extended_boxes

def generate_mask_img(self, boxes, image_size, whole_box=True):
extended_boxes_np = self.generate_fg_center_boxes_np(boxes, image_size, whole_box)
mask_img = convert_box_to_mask(
extended_boxes_np, np.ones(extended_boxes_np.shape[0]), image_size, bg_label=0, ellipse_mask=False
)
mask_img = np.amax(mask_img, axis=0, keepdims=True)[0:1, ...]
return mask_img

def __call__(self, data):
d = dict(data)
for key in self.key_iterator(d):
image = d[self.image_key]
boxes = d[key]
d[self.mask_image_key] = self.generate_mask_img(boxes, image.shape[1:], whole_box=self.whole_box)
return d
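
To see what the new helper produces, here is a small synthetic sketch (box coordinates and shapes are illustrative). With whole_box=True, each box is extended so that any crop of spatial_size centered inside the extended box still contains the whole original box, and the returned mask marks those candidate centers:

import numpy as np

data = {
    "image": np.zeros((1, 64, 64, 64), dtype=np.float32),
    # one box in standard (x1, y1, z1, x2, y2, z2) mode
    "box": np.array([[10.0, 10.0, 10.0, 30.0, 30.0, 30.0]]),
}
mask_gen = GenerateExtendedBoxMask(
    keys="box", image_key="image", spatial_size=(40, 40, 40), whole_box=True
)
data = mask_gen(data)
# per axis: start = min(30 - 40 // 2 - 1, 10) = 9; stop = max(9 + 40 // 2, 30) = 30,
# so the mask is nonzero for crop centers in [9, 30) along each axis
print(data["mask_image"].shape)  # (1, 64, 64, 64)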
6 changes: 4 additions & 2 deletions detection/luna16_prepare_images.py
@@ -47,8 +47,10 @@ def main():

monai.config.print_config()

env_dict = json.load(open(args.environment_file, "r"))
config_dict = json.load(open(args.config_file, "r"))
with open(args.environment_file, "r") as env_file:
env_dict = json.load(env_file)
with open(args.config_file, "r") as config_file:
config_dict = json.load(config_file)

for k, v in env_dict.items():
setattr(args, k, v)
11 changes: 7 additions & 4 deletions detection/luna16_training.py
@@ -78,8 +78,10 @@ def main():
torch.backends.cudnn.benchmark = True
torch.set_num_threads(4)

env_dict = json.load(open(args.environment_file, "r"))
config_dict = json.load(open(args.config_file, "r"))
with open(args.environment_file, "r") as env_file:
env_dict = json.load(env_file)
with open(args.config_file, "r") as config_file:
config_dict = json.load(config_file)

for k, v in env_dict.items():
setattr(args, k, v)
@@ -103,6 +105,7 @@ def main():
intensity_transform,
args.patch_size,
args.batch_size,
point_key="points",
affine_lps_to_ras=True,
amp=amp,
)
@@ -233,7 +236,7 @@ def main():
)
after_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=10, after_scheduler=after_scheduler)
scaler = torch.cuda.amp.GradScaler() if amp else None
scaler = torch.amp.GradScaler("cuda") if amp else None
optimizer.zero_grad()
optimizer.step()

@@ -279,7 +282,7 @@ def main():
param.grad = None

if amp and (scaler is not None):
with torch.cuda.amp.autocast():
with torch.amp.autocast("cuda"):
outputs = detector(inputs, targets)
loss = w_cls * outputs[detector.cls_key] + outputs[detector.box_reg_key]
scaler.scale(loss).backward()
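
The hunk is truncated here; for reference, this is the usual shape of a device-agnostic torch.amp training step (a generic sketch assuming torch >= 2.3, not the remainder of this file):

import torch

model = torch.nn.Linear(8, 1).cuda()  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()
inputs = torch.randn(4, 8, device="cuda")
targets = torch.randn(4, 1, device="cuda")

scaler = torch.amp.GradScaler("cuda")
with torch.amp.autocast("cuda"):
    loss = criterion(model(inputs), targets)
scaler.scale(loss).backward()
scaler.step(optimizer)  # unscales gradients; skips the step if any are inf/NaN
scaler.update()  # adjusts the loss scale for the next iteration
optimizer.zero_grad(set_to_none=True)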
