From 71ebd912518582412772de8a9e96f69a406f1376 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 16 Sep 2021 14:03:43 +0100 Subject: [PATCH] Torch: `map_binary_to_indices`, `map_classes_to_indices`, `correct_crop_centers`, `generate_pos_neg_label_crop_centers`, `generate_label_classes_crop_centers` (#2958) torch map_binary_to_indices, map_classes_to_indices, correct_crop_centers, generate_pos_neg_label_crop_centers, generate_label_classes_crop_centers --- monai/transforms/__init__.py | 13 +- monai/transforms/croppad/array.py | 12 +- monai/transforms/croppad/dictionary.py | 10 +- monai/transforms/utility/array.py | 7 +- monai/transforms/utils.py | 74 ++++--- .../utils_pytorch_numpy_unification.py | 97 +++++++++ tests/test_correct_crop_centers.py | 39 ++++ ...est_generate_label_classes_crop_centers.py | 24 ++- ...est_generate_pos_neg_label_crop_centers.py | 57 ++++-- tests/test_map_binary_to_indices.py | 84 ++++---- tests/test_map_classes_to_indices.py | 189 ++++++++++++------ 11 files changed, 430 insertions(+), 176 deletions(-) create mode 100644 tests/test_correct_crop_centers.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 8a32b9e0b8..8f76a4b7a6 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -524,4 +524,15 @@ weighted_patch_samples, zero_margins, ) -from .utils_pytorch_numpy_unification import clip, in1d, moveaxis, percentile, where +from .utils_pytorch_numpy_unification import ( + any_np_pt, + clip, + floor_divide, + in1d, + moveaxis, + nonzero, + percentile, + ravel, + unravel_index, + where, +) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index d3cec35d93..fac14f4582 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -910,16 +910,18 @@ def randomize( image: Optional[np.ndarray] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) + fg_indices_: np.ndarray + bg_indices_: np.ndarray if fg_indices is None or bg_indices is None: if self.fg_indices is not None and self.bg_indices is not None: fg_indices_ = self.fg_indices bg_indices_ = self.bg_indices else: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore else: fg_indices_ = fg_indices bg_indices_ = bg_indices - self.centers = generate_pos_neg_label_crop_centers( + self.centers = generate_pos_neg_label_crop_centers( # type: ignore self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R ) @@ -1052,15 +1054,15 @@ def randomize( image: Optional[np.ndarray] = None, ) -> None: self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) - indices_: List[np.ndarray] + indices_: Sequence[np.ndarray] if indices is None: if self.indices is not None: indices_ = self.indices else: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore else: indices_ = indices - self.centers = generate_label_classes_crop_centers( + self.centers = generate_label_classes_crop_centers( # type: ignore self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R ) diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 
233f1b6edf..a504b23179 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -1100,13 +1100,15 @@ def randomize( bg_indices: Optional[np.ndarray] = None, image: Optional[np.ndarray] = None, ) -> None: + fg_indices_: np.ndarray + bg_indices_: np.ndarray self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) if fg_indices is None or bg_indices is None: - fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) + fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold) # type: ignore else: fg_indices_ = fg_indices bg_indices_ = bg_indices - self.centers = generate_pos_neg_label_crop_centers( + self.centers = generate_pos_neg_label_crop_centers( # type: ignore self.spatial_size, self.num_samples, self.pos_ratio, label.shape[1:], fg_indices_, bg_indices_, self.R ) @@ -1283,10 +1285,10 @@ def randomize( self.spatial_size = fall_back_tuple(self.spatial_size, default=label.shape[1:]) indices_: List[np.ndarray] if indices is None: - indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + indices_ = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore else: indices_ = indices - self.centers = generate_label_classes_crop_centers( + self.centers = generate_label_classes_crop_centers( # type: ignore self.spatial_size, self.num_samples, label.shape[1:], indices_, self.ratios, self.R ) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9109fb04c5..848a782533 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -810,12 +810,14 @@ def __call__( output_shape: expected shape of output indices. if None, use `self.output_shape` instead. 
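(For orientation: the `__call__` annotated in the hunk above belongs to the `FgBgToIndices` transform. A minimal usage sketch, illustrative only, with the label value borrowed from the binary-label test cases later in this patch:

import numpy as np
from monai.transforms import FgBgToIndices

label = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])  # single-channel binary label
fg_indices, bg_indices = FgBgToIndices()(label)
# fg_indices -> array([1, 2, 3, 5, 6, 7]), bg_indices -> array([0, 4, 8])
)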
""" + fg_indices: np.ndarray + bg_indices: np.ndarray label, *_ = convert_data_type(label, np.ndarray) # type: ignore if image is not None: image, *_ = convert_data_type(image, np.ndarray) # type: ignore if output_shape is None: output_shape = self.output_shape - fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) + fg_indices, bg_indices = map_binary_to_indices(label, image, self.image_threshold) # type: ignore if output_shape is not None: fg_indices = np.stack([np.unravel_index(i, output_shape) for i in fg_indices]) bg_indices = np.stack([np.unravel_index(i, output_shape) for i in bg_indices]) @@ -868,7 +870,8 @@ def __call__( if output_shape is None: output_shape = self.output_shape - indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) + indices: List[np.ndarray] + indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold) # type: ignore if output_shape is not None: indices = [np.stack([np.unravel_index(i, output_shape) for i in array]) for array in indices] diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index f0be87de0b..883d5e5faa 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -26,6 +26,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.compose import Compose, OneOf from monai.transforms.transform import MapTransform, Transform +from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index from monai.utils import ( GridSampleMode, InterpolateMode, @@ -261,10 +262,10 @@ def resize_center(img: np.ndarray, *resize_dims: Optional[int], fill_value: floa def map_binary_to_indices( - label: np.ndarray, - image: Optional[np.ndarray] = None, + label: NdarrayOrTensor, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, -) -> Tuple[np.ndarray, np.ndarray]: +) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]: """ Compute the foreground and background of input label data, return the indices after fattening. For example: @@ -277,28 +278,31 @@ def map_binary_to_indices( to define background. so the output items will not map to all the voxels in the label. image_threshold: if enabled `image`, use ``image > image_threshold`` to determine the valid image content area and select background only in this area. 
- """ + # Prepare fg/bg indices if label.shape[0] > 1: label = label[1:] # for One-Hot format data, remove the background channel - label_flat = np.any(label, axis=0).ravel() # in case label has multiple dimensions - fg_indices = np.nonzero(label_flat)[0] + label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions + fg_indices = nonzero(label_flat) if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() - bg_indices = np.nonzero(np.logical_and(img_flat, ~label_flat))[0] + img_flat = ravel(any_np_pt(image > image_threshold, 0)) + img_flat, *_ = convert_data_type( + img_flat, type(label), device=label.device if isinstance(label, torch.Tensor) else None + ) + bg_indices = nonzero(img_flat & ~label_flat) else: - bg_indices = np.nonzero(~label_flat)[0] + bg_indices = nonzero(~label_flat) return fg_indices, bg_indices def map_classes_to_indices( - label: np.ndarray, + label: NdarrayOrTensor, num_classes: Optional[int] = None, - image: Optional[np.ndarray] = None, + image: Optional[NdarrayOrTensor] = None, image_threshold: float = 0.0, -) -> List[np.ndarray]: +) -> List[NdarrayOrTensor]: """ Filter out indices of every class of the input label data, return the indices after fattening. It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for @@ -318,11 +322,11 @@ def map_classes_to_indices( determine the valid image content area and select class indices only in this area. """ - img_flat: Optional[np.ndarray] = None + img_flat: Optional[NdarrayOrTensor] = None if image is not None: - img_flat = np.any(image > image_threshold, axis=0).ravel() + img_flat = ravel((image > image_threshold).any(0)) - indices: List[np.ndarray] = [] + indices: List[NdarrayOrTensor] = [] # assuming the first dimension is channel channels = len(label) @@ -333,9 +337,9 @@ def map_classes_to_indices( num_classes_ = num_classes for c in range(num_classes_): - label_flat = np.any(label[c : c + 1] if channels > 1 else label == c, axis=0).ravel() - label_flat = np.logical_and(img_flat, label_flat) if img_flat is not None else label_flat - indices.append(np.nonzero(label_flat)[0]) + label_flat = ravel(any_np_pt(label[c : c + 1] if channels > 1 else label == c, 0)) + label_flat = img_flat & label_flat if img_flat is not None else label_flat + indices.append(nonzero(label_flat)) return indices @@ -385,8 +389,10 @@ def weighted_patch_samples( def correct_crop_centers( - centers: List[np.ndarray], spatial_size: Union[Sequence[int], int], label_spatial_shape: Sequence[int] -) -> List[np.ndarray]: + centers: List[Union[int, torch.Tensor]], + spatial_size: Union[Sequence[int], int], + label_spatial_shape: Sequence[int], +) -> List[int]: """ Utility to correct the crop center if the crop size is bigger than the image size. 
@@ -419,7 +425,9 @@ def correct_crop_centers( center_i = valid_end[i] - 1 centers[i] = center_i - return centers + corrected_centers: List[int] = [c.item() if isinstance(c, torch.Tensor) else c for c in centers] # type: ignore + + return corrected_centers def generate_pos_neg_label_crop_centers( @@ -427,10 +435,10 @@ def generate_pos_neg_label_crop_centers( num_samples: int, pos_ratio: float, label_spatial_shape: Sequence[int], - fg_indices: np.ndarray, - bg_indices: np.ndarray, + fg_indices: NdarrayOrTensor, + bg_indices: NdarrayOrTensor, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: +) -> List[List[int]]: """ Generate valid sample locations based on the label with option for specifying foreground ratio Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -453,11 +461,12 @@ def generate_pos_neg_label_crop_centers( rand_state = np.random.random.__self__ # type: ignore centers = [] - fg_indices, bg_indices = np.asarray(fg_indices), np.asarray(bg_indices) - if fg_indices.size == 0 and bg_indices.size == 0: + fg_indices = np.asarray(fg_indices) if isinstance(fg_indices, Sequence) else fg_indices + bg_indices = np.asarray(bg_indices) if isinstance(bg_indices, Sequence) else bg_indices + if len(fg_indices) == 0 and len(bg_indices) == 0: raise ValueError("No sampling location available.") - if fg_indices.size == 0 or bg_indices.size == 0: + if len(fg_indices) == 0 or len(bg_indices) == 0: warnings.warn( f"N foreground {len(fg_indices)}, N background {len(bg_indices)}," "unable to generate class balanced samples." @@ -467,7 +476,8 @@ def generate_pos_neg_label_crop_centers( for _ in range(num_samples): indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + idx = indices_to_use[random_int] + center = unravel_index(idx, label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) @@ -479,10 +489,10 @@ def generate_label_classes_crop_centers( spatial_size: Union[Sequence[int], int], num_samples: int, label_spatial_shape: Sequence[int], - indices: List[np.ndarray], + indices: Sequence[NdarrayOrTensor], ratios: Optional[List[Union[float, int]]] = None, rand_state: Optional[np.random.RandomState] = None, -) -> List[List[np.ndarray]]: +) -> List[List[int]]: """ Generate valid sample locations based on the specified ratios of label classes. 
Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W] @@ -508,8 +518,6 @@ if any(i < 0 for i in ratios_): raise ValueError("ratios should not contain negative number.") - # ensure indices are numpy array - indices = [np.asarray(i) for i in indices] for i, array in enumerate(indices): if len(array) == 0: warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.") @@ -521,7 +529,7 @@ # randomly select the indices of a class based on the ratios indices_to_use = indices[i] random_int = rand_state.randint(len(indices_to_use)) - center = np.unravel_index(indices_to_use[random_int], label_spatial_shape) + center = unravel_index(indices_to_use[random_int], label_spatial_shape) # shift center to range of valid centers center_ori = list(center) centers.append(correct_crop_centers(center_ori, spatial_size, label_spatial_shape)) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 0fb8e34ef0..8808e25265 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -15,6 +15,7 @@ import torch from monai.config.type_definitions import NdarrayOrTensor +from monai.utils.misc import is_module_ver_at_least __all__ = [ "moveaxis", @@ -22,6 +23,11 @@ "clip", "percentile", "where", + "nonzero", + "floor_divide", + "unravel_index", + "ravel", + "any_np_pt", ] @@ -119,3 +125,94 @@ def where(condition: NdarrayOrTensor, x, y) -> NdarrayOrTensor: y = torch.as_tensor(y, device=condition.device, dtype=x.dtype) result = torch.where(condition, x, y) return result + + +def nonzero(x: NdarrayOrTensor): + """`np.nonzero` with equivalent implementation for torch. + + Args: + x: array/tensor + + Returns: + Indices of the non-zero elements of the input + """ + if isinstance(x, np.ndarray): + return np.nonzero(x)[0] + return torch.nonzero(x).flatten() + + +def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: + """`np.floor_divide` with equivalent implementation for torch. + + As of pt1.8, use `torch.div(..., rounding_mode="floor")`, and + before that, use `torch.floor_divide`. + + Args: + a: first array/tensor + b: scalar to divide by + + Returns: + Element-wise floor division between two arrays/tensors. + """ + if isinstance(a, torch.Tensor): + if is_module_ver_at_least(torch, (1, 8, 0)): + return torch.div(a, b, rounding_mode="floor") + return torch.floor_divide(a, b) + else: + return np.floor_divide(a, b) + + +def unravel_index(idx, shape): + """`np.unravel_index` with equivalent implementation for torch. + + Args: + idx: index to unravel + shape: shape of array/tensor + + Returns: + Index unravelled for given shape + """ + if isinstance(idx, torch.Tensor): + coord = [] + for dim in reversed(shape): + coord.insert(0, idx % dim) + idx = floor_divide(idx, dim) + return torch.stack(coord) + return np.unravel_index(np.asarray(idx, dtype=int), shape) + + +def ravel(x: NdarrayOrTensor): + """`np.ravel` with equivalent implementation for torch. + + Args: + x: array/tensor to ravel + + Returns: + A contiguous flattened array/tensor. + """ + if isinstance(x, torch.Tensor): + if hasattr(torch, "ravel"): + return x.ravel() + return x.flatten().contiguous() + return np.ravel(x) + + +def any_np_pt(x: NdarrayOrTensor, axis: int): + """`np.any` with equivalent implementation for torch.
+ + For pytorch, convert to boolean for compatibility with older versions. + + Args: + x: input array/tensor + axis: axis to perform `any` over + + Returns: + Result of `any` applied over the given axis. + """ + if isinstance(x, torch.Tensor): + try: + return torch.any(x, axis) + except RuntimeError: + # older versions of pytorch require the input to be cast to boolean + return torch.any(x.bool(), axis) + return np.any(x, axis) diff --git a/tests/test_correct_crop_centers.py b/tests/test_correct_crop_centers.py new file mode 100644 index 0000000000..853b3d41d3 --- /dev/null +++ b/tests/test_correct_crop_centers.py @@ -0,0 +1,39 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms.utils import correct_crop_centers +from tests.utils import assert_allclose + +TESTS = [ + [ + [1, 5, 0], + [2, 2, 2], + [10, 10, 10], + ], +] + + +class TestCorrectCropCenters(unittest.TestCase): + @parameterized.expand(TESTS) + def test_torch(self, spatial_size, centers, label_spatial_shape): + result1 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + centers = [torch.tensor(i) for i in centers] + result2 = correct_crop_centers(centers, spatial_size, label_spatial_shape) + assert_allclose(result1, result2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generate_label_classes_crop_centers.py b/tests/test_generate_label_classes_crop_centers.py index 38f2a3e0d1..cc068504bf 100644 --- a/tests/test_generate_label_classes_crop_centers.py +++ b/tests/test_generate_label_classes_crop_centers.py @@ -10,11 +10,13 @@ # limitations under the License.
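(The new helpers in `utils_pytorch_numpy_unification.py` above are intended to behave identically across backends; an illustrative sketch with inputs chosen here rather than taken from the tests:

import numpy as np
import torch
from monai.transforms.utils_pytorch_numpy_unification import any_np_pt, nonzero, ravel, unravel_index

nonzero(np.array([0, 1, 2, 0]))               # array([1, 2])
nonzero(torch.tensor([0, 1, 2, 0]))           # tensor([1, 2])
any_np_pt(torch.tensor([[0, 1], [0, 2]]), 0)  # tensor([False, True]); cast to bool on older torch
unravel_index(torch.tensor(13), (3, 3, 3))    # tensor([1, 1, 1]); numpy input goes through np.unravel_index
ravel(torch.arange(4).reshape(2, 2))          # tensor([0, 1, 2, 3])
)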
import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_label_classes_crop_centers +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose TEST_CASE_1 = [ { @@ -23,7 +25,6 @@ "ratios": [1, 2], "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 2, @@ -37,7 +38,6 @@ "ratios": None, "label_spatial_shape": [3, 3, 3], "indices": [[3, 12, 21], [1, 9, 18]], - "rand_state": np.random.RandomState(), }, list, 1, @@ -48,10 +48,20 @@ class TestGenerateLabelClassesCropCenters(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_label_classes_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + input_data["indices"] = p(input_data["indices"]) + set_determinism(0) + result = generate_label_classes_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_generate_pos_neg_label_crop_centers.py b/tests/test_generate_pos_neg_label_crop_centers.py index 40181aa9ea..b263f10e55 100644 --- a/tests/test_generate_pos_neg_label_crop_centers.py +++ b/tests/test_generate_pos_neg_label_crop_centers.py @@ -10,35 +10,50 @@ # limitations under the License. 
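(The seeded pattern used by the updated test just above can also be reproduced directly; a sketch reusing TEST_CASE_1's arguments, assuming this patched tree:

from monai.transforms import generate_label_classes_crop_centers
from monai.utils.misc import set_determinism

set_determinism(0)  # fix the global numpy RNG, as the updated test does
centers = generate_label_classes_crop_centers(
    spatial_size=[2, 2, 2],
    num_samples=2,
    label_spatial_shape=[3, 3, 3],
    indices=[[3, 12, 21], [1, 9, 18]],
    ratios=[1, 2],
)
# centers -> a list of 2 crop centers, each a list of 3 ints
)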
import unittest +from copy import deepcopy -import numpy as np from parameterized import parameterized from monai.transforms import generate_pos_neg_label_crop_centers - -TEST_CASE_1 = [ - { - "spatial_size": [2, 2, 2], - "num_samples": 2, - "pos_ratio": 1.0, - "label_spatial_shape": [3, 3, 3], - "fg_indices": [1, 9, 18], - "bg_indices": [3, 12, 21], - "rand_state": np.random.RandomState(), - }, - list, - 2, - 3, -] +from monai.utils.misc import set_determinism +from tests.utils import TEST_NDARRAYS, assert_allclose + +TESTS = [] +TESTS.append( + [ + { + "spatial_size": [2, 2, 2], + "num_samples": 2, + "pos_ratio": 1.0, + "label_spatial_shape": [3, 3, 3], + "fg_indices": [1, 9, 18], + "bg_indices": [3, 12, 21], + }, + list, + 2, + 3, + ] +) class TestGeneratePosNegLabelCropCenters(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_type, expected_count, expected_shape): - result = generate_pos_neg_label_crop_centers(**input_data) - self.assertIsInstance(result, expected_type) - self.assertEqual(len(result), expected_count) - self.assertEqual(len(result[0]), expected_shape) + results = [] + for p in TEST_NDARRAYS + (None,): + input_data = deepcopy(input_data) + if p is not None: + for k in ["fg_indices", "bg_indices"]: + input_data[k] = p(input_data[k]) + set_determinism(0) + result = generate_pos_neg_label_crop_centers(**input_data) + self.assertIsInstance(result, expected_type) + self.assertEqual(len(result), expected_count) + self.assertEqual(len(result[0]), expected_shape) + # check for consistency between numpy, torch and torch.cuda + results.append(result) + if len(results) > 1: + assert_allclose(results[0], results[-1]) if __name__ == "__main__": diff --git a/tests/test_map_binary_to_indices.py b/tests/test_map_binary_to_indices.py index 1fafa6f446..2d29aa7c0d 100644 --- a/tests/test_map_binary_to_indices.py +++ b/tests/test_map_binary_to_indices.py @@ -15,50 +15,58 @@ from parameterized import parameterized from monai.transforms import map_binary_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - {"label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), "image": None, "image_threshold": 0.0}, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 4, 8]), -] - -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]]), - "image_threshold": 0.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_3 = [ - { - "label": np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] - -TEST_CASE_4 = [ - { - "label": np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]]), - "image": np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]]), - "image_threshold": 1.0, - }, - np.array([1, 2, 3, 5, 6, 7]), - np.array([0, 8]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + {"label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), "image": None, "image_threshold": 0.0}, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 4, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": p(np.array([[[1, 1, 1], [1, 0, 1], [1, 1, 1]]])), + "image_threshold": 0.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])), + "image": 
p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [3, 0, 4], [5, 6, 0]]])), + "image": p(np.array([[[3, 3, 3], [3, 1, 3], [3, 3, 3]]])), + "image_threshold": 1.0, + }, + np.array([1, 2, 3, 5, 6, 7]), + np.array([0, 8]), + ] + ) class TestMapBinaryToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand(TESTS) def test_type_shape(self, input_data, expected_fg, expected_bg): fg_indices, bg_indices = map_binary_to_indices(**input_data) - np.testing.assert_allclose(fg_indices, expected_fg) - np.testing.assert_allclose(bg_indices, expected_bg) + assert_allclose(fg_indices, expected_fg, type_test=False) + assert_allclose(bg_indices, expected_bg, type_test=False) if __name__ == "__main__": diff --git a/tests/test_map_classes_to_indices.py b/tests/test_map_classes_to_indices.py index 2320954520..a585bd006b 100644 --- a/tests/test_map_classes_to_indices.py +++ b/tests/test_map_classes_to_indices.py @@ -15,86 +15,145 @@ from parameterized import parameterized from monai.transforms import map_classes_to_indices +from tests.utils import TEST_NDARRAYS, assert_allclose -TEST_CASE_1 = [ - # test Argmax data - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 3, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + # test Argmax data + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) -TEST_CASE_2 = [ - { - "label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), - "num_classes": 3, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + TESTS.append( + [ + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 3, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, + [ + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_3 = [ - # test One-Hot data - { - "label": np.array( + TESTS.append( + [ + # test One-Hot data + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + ], + ] + ) -TEST_CASE_4 = [ - { - "label": np.array( + TESTS.append( + [ + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + ] + ) + ), + "num_classes": None, + "image": p(np.array([[[132, 1434, 51], [61, 0, 133], [523, 44, 232]]])), + "image_threshold": 60, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - ] - ), - "num_classes": None, - "image": np.array([[[132, 1434, 51], [61, 0, 133], [523, 
44, 232]]]), - "image_threshold": 60, - }, - [np.array([0, 8]), np.array([1, 5, 6]), np.array([3])], -] + np.array([0, 8]), + np.array([1, 5, 6]), + np.array([3]), + ], + ] + ) -TEST_CASE_5 = [ - # test empty class - {"label": np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]]), "num_classes": 5, "image": None, "image_threshold": 0.0}, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + TESTS.append( + [ + # test empty class + { + "label": p(np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])), + "num_classes": 5, + "image": None, + "image_threshold": 0.0, + }, + [ + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) -TEST_CASE_6 = [ - # test empty class - { - "label": np.array( + TESTS.append( + [ + # test empty class + { + "label": p( + np.array( + [ + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], + [[0, 1, 0], [0, 0, 1], [1, 0, 0]], + [[0, 0, 1], [1, 0, 0], [0, 1, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + ] + ) + ), + "image": None, + "image_threshold": 0.0, + }, [ - [[1, 0, 0], [0, 1, 0], [0, 0, 1]], - [[0, 1, 0], [0, 0, 1], [1, 0, 0]], - [[0, 0, 1], [1, 0, 0], [0, 1, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ), - "image": None, - "image_threshold": 0.0, - }, - [np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7]), np.array([]), np.array([])], -] + np.array([0, 4, 8]), + np.array([1, 5, 6]), + np.array([2, 3, 7]), + np.array([]), + np.array([]), + ], + ] + ) class TestMapClassesToIndices(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand(TESTS) def test_value(self, input_data, expected_indices): indices = map_classes_to_indices(**input_data) for i, e in zip(indices, expected_indices): - np.testing.assert_allclose(i, e) + assert_allclose(i, e, type_test=False) if __name__ == "__main__":