diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 4bc09b59a5..9109fb04c5 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -33,17 +33,18 @@
 )
 from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
 from monai.utils import (
+    convert_data_type,
     convert_to_cupy,
     convert_to_numpy,
     convert_to_tensor,
     ensure_tuple,
+    get_equivalent_dtype,
     look_up_option,
     min_version,
     optional_import,
 )
 from monai.utils.enums import TransformBackends
 from monai.utils.misc import is_module_ver_at_least
-from monai.utils.type_conversion import convert_data_type

 PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
 pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -342,15 +343,16 @@ class ToTensor(Transform):

     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

-    def __init__(self, device: Optional[torch.device] = None) -> None:
+    def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None:
         super().__init__()
+        self.dtype = dtype
         self.device = device

     def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
         """
         Apply the transform to `img` and make it contiguous.
         """
-        return convert_to_tensor(img, wrap_sequence=True, device=self.device)  # type: ignore
+        return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=True)  # type: ignore


 class EnsureType(Transform):
@@ -362,14 +364,21 @@ class EnsureType(Transform):

     Args:
         data_type: target data type to convert, should be "tensor" or "numpy".
+        dtype: target data content type to convert, for example: np.float32, torch.float, etc.
         device: for Tensor data type, specify the target device.
     """

     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

-    def __init__(self, data_type: str = "tensor", device: Optional[torch.device] = None) -> None:
+    def __init__(
+        self,
+        data_type: str = "tensor",
+        dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
+        device: Optional[torch.device] = None,
+    ) -> None:
         self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
+        self.dtype = dtype
         self.device = device

     def __call__(self, data: NdarrayOrTensor):
@@ -381,7 +390,12 @@ def __call__(self, data: NdarrayOrTensor):
             if applicable.

         """
-        return convert_to_tensor(data, device=self.device) if self.data_type == "tensor" else convert_to_numpy(data)
+        if self.data_type == "tensor":
+            dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor)
+            return convert_to_tensor(data, dtype=dtype_, device=self.device)
+        else:
+            dtype_ = get_equivalent_dtype(self.dtype, np.ndarray)
+            return convert_to_numpy(data, dtype=dtype_)


 class ToNumpy(Transform):
@@ -391,11 +405,15 @@ class ToNumpy(Transform):

     backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

+    def __init__(self, dtype: Optional[DtypeLike] = None) -> None:
+        super().__init__()
+        self.dtype = dtype
+
     def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
         """
         Apply the transform to `img` and make it contiguous.
         """
-        return convert_to_numpy(img)  # type: ignore
+        return convert_to_numpy(img, dtype=self.dtype)  # type: ignore


 class ToCupy(Transform):
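Illustrative usage (not part of the patch): a minimal sketch of the new `dtype`/`device` arguments on the array transforms above, assuming this change is applied; the sample array is made up for the example.

    import numpy as np
    import torch

    from monai.transforms import EnsureType, ToNumpy, ToTensor

    img = np.array([[1, 2], [3, 4]])

    # cast and move to a device while converting to a tensor
    t = ToTensor(dtype=torch.float32, device="cpu")(img)

    # EnsureType accepts either a numpy or a torch dtype, matched to the target data_type
    n = EnsureType(data_type="numpy", dtype=np.float32)(t)

    # ToNumpy can now cast while converting
    a = ToNumpy(dtype=np.float32)(torch.as_tensor(img))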
""" - return convert_to_numpy(img) # type: ignore + return convert_to_numpy(img, dtype=self.dtype) # type: ignore class ToCupy(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 3a5be20e8b..a3c51fe3f2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -443,15 +443,23 @@ class ToTensord(MapTransform, InvertibleTransform): backend = ToTensor.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data content type to convert, for example: torch.float, etc. + device: specify the target device to put the Tensor data. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToTensor() + self.converter = ToTensor(dtype=dtype, device=device) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -491,6 +499,7 @@ def __init__( self, keys: KeysCollection, data_type: str = "tensor", + dtype: Optional[Union[DtypeLike, torch.dtype]] = None, device: Optional[torch.device] = None, allow_missing_keys: bool = False, ) -> None: @@ -499,11 +508,12 @@ def __init__( keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` data_type: target data type to convert, should be "tensor" or "numpy". + dtype: target data content type to convert, for example: np.float32, torch.float, etc. device: for Tensor data type, specify the target device. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type, device=device) + self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -530,15 +540,21 @@ class ToNumpyd(MapTransform): backend = ToNumpy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + def __init__( + self, + keys: KeysCollection, + dtype: Optional[DtypeLike] = None, + allow_missing_keys: bool = False, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` + dtype: target data type when converting to numpy array. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) - self.converter = ToNumpy() + self.converter = ToNumpy(dtype=dtype) def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: d = dict(data) @@ -554,13 +570,13 @@ class ToCupyd(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - allow_missing_keys: don't raise exception if key is missing. dtype: data type specifier. It is inferred from the input by default. + allow_missing_keys: don't raise exception if key is missing. 
""" backend = ToCupy.backend - def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, dtype=None) -> None: + def __init__(self, keys: KeysCollection, dtype=None, allow_missing_keys: bool = False) -> None: super().__init__(keys, allow_missing_keys) self.converter = ToCupy(dtype=dtype) diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 8a54986633..773f1cbc37 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -61,6 +61,8 @@ def get_equivalent_dtype(dtype, data_type): im = torch.tensor(1) dtype = get_equivalent_dtype(np.float32, type(im)) """ + if dtype is None: + return None if data_type is torch.Tensor: if type(dtype) is torch.dtype: return dtype @@ -84,7 +86,12 @@ def get_dtype(data: Any): return type(data) -def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None): +def convert_to_tensor( + data, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + wrap_sequence: bool = False, +): """ Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple, recursively check every item and convert it to PyTorch Tensor. @@ -93,36 +100,41 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch. data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original. for dictionary, list or tuple, convert every item to a Tensor if applicable. - wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. - If `True`, then `[1, 2]` -> `tensor([1, 2])`. + dtype: target data type to when converting to Tensor. + device: target device to put the converted Tensor data. + wrap_sequence: if `False`, then lists will recursively call this function. + E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. 
""" if isinstance(data, torch.Tensor): - return data.contiguous().to(device) + return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13 if re.search(r"[SaUO]", data.dtype.str) is None: # numpy array with 0 dims is also sequence iterable, # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims - return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device) - elif has_cp and isinstance(data, cp_ndarray): - return torch.as_tensor(data, device=device) - elif isinstance(data, (float, int, bool)): - return torch.as_tensor(data, device=device) - elif isinstance(data, Sequence) and wrap_sequence: - return torch.as_tensor(data, device=device) + if data.ndim > 0: + data = np.ascontiguousarray(data) + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore + elif ( + has_cp + and isinstance(data, cp_ndarray) + or isinstance(data, (float, int, bool)) + or (isinstance(data, Sequence) and wrap_sequence) + ): + return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore elif isinstance(data, list): - return [convert_to_tensor(i, device=device) for i in data] + return [convert_to_tensor(i, dtype=dtype, device=device) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_tensor(i, device=device) for i in data) + return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data) elif isinstance(data, dict): - return {k: convert_to_tensor(v, device=device) for k, v in data.items()} + return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()} return data -def convert_to_numpy(data, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. @@ -131,23 +143,22 @@ def convert_to_numpy(data, wrap_sequence: bool = False): data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc. will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original. for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to numpy array. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. 
""" if isinstance(data, torch.Tensor): - data = data.detach().cpu().numpy() + data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy() elif has_cp and isinstance(data, cp_ndarray): - data = cp.asnumpy(data) - elif isinstance(data, (float, int, bool)): - data = np.asarray(data) - elif isinstance(data, Sequence) and wrap_sequence: - return np.asarray(data) + data = cp.asnumpy(data).astype(dtype) + elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence): + data = np.asarray(data, dtype=dtype) elif isinstance(data, list): - return [convert_to_numpy(i) for i in data] + return [convert_to_numpy(i, dtype=dtype) for i in data] elif isinstance(data, tuple): - return tuple(convert_to_numpy(i) for i in data) + return tuple(convert_to_numpy(i, dtype=dtype) for i in data) elif isinstance(data, dict): - return {k: convert_to_numpy(v) for k, v in data.items()} + return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()} if isinstance(data, np.ndarray) and data.ndim > 0: data = np.ascontiguousarray(data) @@ -165,16 +176,16 @@ def convert_to_cupy(data, dtype, wrap_sequence: bool = True): Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays for dictionary, list or tuple, convert every item to a numpy array if applicable. + dtype: target data type when converting to Cupy array. wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`. """ # direct calls - if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)): + if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or ( + isinstance(data, Sequence) and wrap_sequence + ): data = cp.asarray(data, dtype) - # recursive calls - elif isinstance(data, Sequence) and wrap_sequence: - return cp.asarray(data, dtype) elif isinstance(data, list): return [convert_to_cupy(i, dtype) for i in data] elif isinstance(data, tuple): @@ -224,24 +235,14 @@ def convert_data_type( output_type = output_type or orig_type - dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type) + dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type) if output_type is torch.Tensor: - if orig_type is not torch.Tensor: - data = convert_to_tensor(data) - if dtype != data.dtype: - data = data.to(dtype) - if device is not None: - data = data.to(device) + data = convert_to_tensor(data, dtype=dtype_, device=device) elif output_type is np.ndarray: - if orig_type is not np.ndarray: - data = convert_to_numpy(data) - if data is not None and dtype != data.dtype: - data = data.astype(dtype) + data = convert_to_numpy(data, dtype=dtype_) elif has_cp and output_type is cp.ndarray: - if data is not None: - data = convert_to_cupy(data, dtype) - + data = convert_to_cupy(data, dtype=dtype_) else: raise ValueError(f"Unsupported output type: {output_type}") return data, orig_type, orig_device diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py index 86bc3db703..f09c022f74 100644 --- a/tests/test_ensure_type.py +++ b/tests/test_ensure_type.py @@ -25,7 +25,9 @@ def test_array_input(self): test_datas.append(test_datas[-1].cuda()) for test_data in test_datas: for dtype in ("tensor", "NUMPY"): - result = EnsureType(data_type=dtype)(test_data) + result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data) + if dtype == "NUMPY": + self.assertTrue(result.dtype == np.float32) 
diff --git a/tests/test_ensure_type.py b/tests/test_ensure_type.py
index 86bc3db703..f09c022f74 100644
--- a/tests/test_ensure_type.py
+++ b/tests/test_ensure_type.py
@@ -25,7 +25,9 @@ def test_array_input(self):
             test_datas.append(test_datas[-1].cuda())
         for test_data in test_datas:
             for dtype in ("tensor", "NUMPY"):
-                result = EnsureType(data_type=dtype)(test_data)
+                result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data)
+                if dtype == "NUMPY":
+                    self.assertTrue(result.dtype == np.float32)
                 self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
                 assert_allclose(result, test_data)
                 self.assertTupleEqual(result.shape, (2, 2))
diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py
index e4c72d37e2..5e3e941f59 100644
--- a/tests/test_ensure_typed.py
+++ b/tests/test_ensure_typed.py
@@ -25,7 +25,14 @@ def test_array_input(self):
             test_datas.append(test_datas[-1].cuda())
         for test_data in test_datas:
             for dtype in ("tensor", "NUMPY"):
-                result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
+                result = EnsureTyped(
+                    keys="data",
+                    data_type=dtype,
+                    dtype=np.float32 if dtype == "NUMPY" else None,
+                    device="cpu",
+                )({"data": test_data})["data"]
+                if dtype == "NUMPY":
+                    self.assertTrue(result.dtype == np.float32)
                 self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
                 assert_allclose(result, test_data)
                 self.assertTupleEqual(result.shape, (2, 2))
diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py
index b48727c01d..09940e33ba 100644
--- a/tests/test_to_numpy.py
+++ b/tests/test_to_numpy.py
@@ -37,8 +37,9 @@ def test_numpy_input(self):
         test_data = np.array([[1, 2], [3, 4]])
         test_data = np.rot90(test_data)
         self.assertFalse(test_data.flags["C_CONTIGUOUS"])
-        result = ToNumpy()(test_data)
+        result = ToNumpy(dtype="float32")(test_data)
         self.assertTrue(isinstance(result, np.ndarray))
+        self.assertTrue(result.dtype == np.float32)
         self.assertTrue(result.flags["C_CONTIGUOUS"])
         assert_allclose(result, test_data)
diff --git a/tests/test_to_tensor.py b/tests/test_to_tensor.py
index 3d187a1dba..74acb1016c 100644
--- a/tests/test_to_tensor.py
+++ b/tests/test_to_tensor.py
@@ -11,8 +11,8 @@

 import unittest

+import torch
 from parameterized import parameterized
-from torch import Tensor

 from monai.transforms import ToTensor
 from tests.utils import TEST_NDARRAYS, assert_allclose, optional_import
@@ -35,15 +35,15 @@ class TestToTensor(unittest.TestCase):

     @parameterized.expand(TESTS)
     def test_array_input(self, test_data, expected_shape):
-        result = ToTensor()(test_data)
-        self.assertTrue(isinstance(result, Tensor))
+        result = ToTensor(dtype=torch.float32, device="cpu")(test_data)
+        self.assertTrue(isinstance(result, torch.Tensor))
         assert_allclose(result, test_data)
         self.assertTupleEqual(result.shape, expected_shape)

     @parameterized.expand(TESTS_SINGLE)
     def test_single_input(self, test_data):
         result = ToTensor()(test_data)
-        self.assertTrue(isinstance(result, Tensor))
+        self.assertTrue(isinstance(result, torch.Tensor))
         assert_allclose(result, test_data)
         self.assertEqual(result.ndim, 0)

@@ -52,7 +52,7 @@ def test_cupy(self):
         test_data = [[1, 2], [3, 4]]
         cupy_array = cp.ascontiguousarray(cp.asarray(test_data))
         result = ToTensor()(cupy_array)
-        self.assertTrue(isinstance(result, Tensor))
+        self.assertTrue(isinstance(result, torch.Tensor))
         assert_allclose(result, test_data)