Skip to content

Commit

Permalink
Torch: map_binary_to_indices, map_classes_to_indices, `correct_cr…
Browse files Browse the repository at this point in the history
…op_centers`, `generate_pos_neg_label_crop_centers`, `generate_label_classes_crop_centers` (Project-MONAI#2958)

torch map_binary_to_indices, map_classes_to_indices, correct_crop_centers, generate_pos_neg_label_crop_centers, generate_label_classes_crop_centers
  • Loading branch information
rijobro authored Sep 16, 2021
1 parent 3b6f479 commit 71ebd91
Show file tree
Hide file tree
Showing 11 changed files with 430 additions and 176 deletions.
13 changes: 12 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,15 @@
weighted_patch_samples,
zero_margins,
)
from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where
from .utils_pytorch_numpy_unification import (
any_np_pt,
clip,
floor_divide,
in1d,
moveaxis,
nonzero,
percentile,
ravel,
unravel_index,
where,
)
12 changes: 7 additions & 5 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,16 +910,18 @@ def randomize(
image: Optional[np.ndarray] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
fg_indices_: np.ndarray
bg_indices_: np.ndarray
if fg_indices is None or bg_indices is None:
if self.fg_indices is not None and self.bg_indices is not None:
fg_indices_ = self.fg_indices
bg_indices_ = self.bg_indices
else:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

Expand Down Expand Up @@ -1052,15 +1054,15 @@ def randomize(
image: Optional[np.ndarray] = None,
) -> None:
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
indices_: Sequence[np.ndarray]
if indices is None:
if self.indices is not None:
indices_ = self.indices
else:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.centers = generate_label_classes_crop_centers( # type: ignore
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

Expand Down
10 changes: 6 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,13 +1100,15 @@ def randomize(
bg_indices: Optional[np.ndarray] = None,
image: Optional[np.ndarray] = None,
) -> None:
fg_indices_: np.ndarray
bg_indices_: np.ndarray
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
if fg_indices is None or bg_indices is None:
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
else:
fg_indices_ = fg_indices
bg_indices_ = bg_indices
self.centers = generate_pos_neg_label_crop_centers(
self.centers = generate_pos_neg_label_crop_centers( # type: ignore
self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R
)

Expand Down Expand Up @@ -1283,10 +1285,10 @@ def randomize(
self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:])
indices_: List[np.ndarray]
if indices is None:
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
else:
indices_ = indices
self.centers = generate_label_classes_crop_centers(
self.centers = generate_label_classes_crop_centers( # type: ignore
self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R
)

Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,12 +810,14 @@ def __call__(
output_shape: expected shape of output indices. if None, use `self.output_shape` instead.
"""
fg_indices: np.ndarray
bg_indices: np.ndarray
label, *_ = convert_data_type(label, np.ndarray) # type: ignore
if image is not None:
image, *_ = convert_data_type(image, np.ndarray) # type: ignore
if output_shape is None:
output_shape = self.output_shape
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold)
fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) # type: ignore
if output_shape is not None:
fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices])
bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices])
Expand Down Expand Up @@ -868,7 +870,8 @@ def __call__(

if output_shape is None:
output_shape = self.output_shape
indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold)
indices: List[np.ndarray]
indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore
if output_shape is not None:
indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices]

Expand Down
74 changes: 41 additions & 33 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.compose import Compose, OneOf
from monai.transforms.transform import MapTransform, Transform
from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index
from monai.utils import (
GridSampleMode,
InterpolateMode,
Expand Down Expand Up @@ -261,10 +262,10 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa


def map_binary_to_indices(
label: np.ndarray,
image: Optional[np.ndarray] = None,
label: NdarrayOrTensor,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
) -> Tuple[np.ndarray, np.ndarray]:
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
"""
Compute the foreground and background of input label data, return the indices after fattening.
For example:
Expand All @@ -277,28 +278,31 @@ def map_binary_to_indices(
to define background. so the output items will not map to all the voxels in the label.
image_threshold: if enabled `image`, use ``image > image_threshold`` to
determine the valid image content area and select background only in this area.
"""

# Prepare fg/bg indices
if label.shape[0] > 1:
label = label[1:] # for One-Hot format data, remove the background channel
label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions
fg_indices = np.nonzero(label_flat)[0]
label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions
fg_indices = nonzero(label_flat)
if image is not None:
img_flat = np.any(image > image_threshold, axis=0).ravel()
bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0]
img_flat = ravel(any_np_pt(image > image_threshold, 0))
img_flat, *_ = convert_data_type(
img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None
)
bg_indices = nonzero(img_flat & ~label_flat)
else:
bg_indices = np.nonzero(~label_flat)[0]
bg_indices = nonzero(~label_flat)

return fg_indices, bg_indices


def map_classes_to_indices(
label: np.ndarray,
label: NdarrayOrTensor,
num_classes: Optional[int] = None,
image: Optional[np.ndarray] = None,
image: Optional[NdarrayOrTensor] = None,
image_threshold: float = 0.0,
) -> List[np.ndarray]:
) -> List[NdarrayOrTensor]:
"""
Filter out indices of every class of the input label data, return the indices after fattening.
It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for
Expand All @@ -318,11 +322,11 @@ def map_classes_to_indices(
determine the valid image content area and select class indices only in this area.
"""
img_flat: Optional[np.ndarray] = None
img_flat: Optional[NdarrayOrTensor] = None
if image is not None:
img_flat = np.any(image > image_threshold, axis=0).ravel()
img_flat = ravel((image > image_threshold).any(0))

indices: List[np.ndarray] = []
indices: List[NdarrayOrTensor] = []
# assuming the first dimension is channel
channels = len(label)

Expand All @@ -333,9 +337,9 @@ def map_classes_to_indices(
num_classes_ = num_classes

for c in range(num_classes_):
label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel()
label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat
indices.append(np.nonzero(label_flat)[0])
label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0))
label_flat = img_flat & label_flat if img_flat is not None else label_flat
indices.append(nonzero(label_flat))

return indices

Expand Down Expand Up @@ -385,8 +389,10 @@ def weighted_patch_samples(


def correct_crop_centers(
centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int]
) -> List[np.ndarray]:
centers: List[Union[int, torch.Tensor]],
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
) -> List[int]:
"""
Utility to correct the crop center if the crop size is bigger than the image size.
Expand Down Expand Up @@ -419,18 +425,20 @@ def correct_crop_centers(
center_i = valid_end[i] - 1
centers[i] = center_i

return centers
corrected_centers: List[int] = [c.item() if isinstance(c, torch.Tensor) else c for c in centers] # type: ignore

return corrected_centers


def generate_pos_neg_label_crop_centers(
spatial_size: Union[Sequence[int], int],
num_samples: int,
pos_ratio: float,
label_spatial_shape: Sequence[int],
fg_indices: np.ndarray,
bg_indices: np.ndarray,
fg_indices: NdarrayOrTensor,
bg_indices: NdarrayOrTensor,
rand_state: Optional[np.random.RandomState] = None,
) -> List[List[np.ndarray]]:
) -> List[List[int]]:
"""
Generate valid sample locations based on the label with option for specifying foreground ratio
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]
Expand All @@ -453,11 +461,12 @@ def generate_pos_neg_label_crop_centers(
rand_state = np.random.random.__self__ # type: ignore

centers = []
fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices)
if fg_indices.size == 0 and bg_indices.size == 0:
fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices
bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices
if len(fg_indices) == 0 and len(bg_indices) == 0:
raise ValueError("No sampling location available.")

if fg_indices.size == 0 or bg_indices.size == 0:
if len(fg_indices) == 0 or len(bg_indices) == 0:
warnings.warn(
f"N foreground {len(fg_indices)}, N background {len(bg_indices)},"
"unable to generate class balanced samples."
Expand All @@ -467,7 +476,8 @@ def generate_pos_neg_label_crop_centers(
for _ in range(num_samples):
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices
random_int = rand_state.randint(len(indices_to_use))
center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
idx = indices_to_use[random_int]
center = unravel_index(idx, label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
Expand All @@ -479,10 +489,10 @@ def generate_label_classes_crop_centers(
spatial_size: Union[Sequence[int], int],
num_samples: int,
label_spatial_shape: Sequence[int],
indices: List[np.ndarray],
indices: Sequence[NdarrayOrTensor],
ratios: Optional[List[Union[float, int]]] = None,
rand_state: Optional[np.random.RandomState] = None,
) -> List[List[np.ndarray]]:
) -> List[List[int]]:
"""
Generate valid sample locations based on the specified ratios of label classes.
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]
Expand All @@ -508,8 +518,6 @@ def generate_label_classes_crop_centers(
if any(i < 0 for i in ratios_):
raise ValueError("ratios should not contain negative number.")

# ensure indices are numpy array
indices = [np.asarray(i) for i in indices]
for i, array in enumerate(indices):
if len(array) == 0:
warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.")
Expand All @@ -521,7 +529,7 @@ def generate_label_classes_crop_centers(
# randomly select the indices of a class based on the ratios
indices_to_use = indices[i]
random_int = rand_state.randint(len(indices_to_use))
center = np.unravel_index(indices_to_use[random_int], label_spatial_shape)
center = unravel_index(indices_to_use[random_int], label_spatial_shape)
# shift center to range of valid centers
center_ori = list(center)
centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape))
Expand Down
Loading

0 comments on commit 71ebd91

Please sign in to comment.