Skip to content

Commit

Permalink
2951 Optimize type_conversion logic (Project-MONAI#2955)
Browse files Browse the repository at this point in the history
* [DLMED] enhance type conversion

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix CI test

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add more unit tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <[email protected]>
  • Loading branch information
Nic-Ma authored Sep 15, 2021
1 parent c624ffc commit 2f4b582
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 63 deletions.
30 changes: 24 additions & 6 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,18 @@
)
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
from monai.utils import (
convert_data_type,
convert_to_cupy,
convert_to_numpy,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
min_version,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -342,15 +343,16 @@ class ToTensor(Transform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, device: Optional[torch.device] = None) -> None:
def __init__(self, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> None:
super().__init__()
self.dtype = dtype
self.device = device

def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
"""
Apply the transform to `img` and make it contiguous.
"""
return convert_to_tensor(img, wrap_sequence=True, device=self.device) # type: ignore
return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=True) # type: ignore


class EnsureType(Transform):
Expand All @@ -362,14 +364,21 @@ class EnsureType(Transform):
Args:
data_type: target data type to convert, should be "tensor" or "numpy".
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
device: for Tensor data type, specify the target device.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, data_type: str = "tensor", device: Optional[torch.device] = None) -> None:
def __init__(
self,
data_type: str = "tensor",
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
device: Optional[torch.device] = None,
) -> None:
self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
self.dtype = dtype
self.device = device

def __call__(self, data: NdarrayOrTensor):
Expand All @@ -381,7 +390,12 @@ def __call__(self, data: NdarrayOrTensor):
if applicable.
"""
return convert_to_tensor(data, device=self.device) if self.data_type == "tensor" else convert_to_numpy(data)
if self.data_type == "tensor":
dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor)
return convert_to_tensor(data, dtype=dtype_, device=self.device)
else:
dtype_ = get_equivalent_dtype(self.dtype, np.ndarray)
return convert_to_numpy(data, dtype=dtype_)


class ToNumpy(Transform):
Expand All @@ -391,11 +405,15 @@ class ToNumpy(Transform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, dtype: Optional[DtypeLike] = None) -> None:
super().__init__()
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
"""
Apply the transform to `img` and make it contiguous.
"""
return convert_to_numpy(img) # type: ignore
return convert_to_numpy(img, dtype=self.dtype) # type: ignore


class ToCupy(Transform):
Expand Down
30 changes: 23 additions & 7 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,23 @@ class ToTensord(MapTransform, InvertibleTransform):

backend = ToTensor.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
def __init__(
self,
keys: KeysCollection,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
dtype: target data content type to convert, for example: torch.float, etc.
device: specify the target device to put the Tensor data.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = ToTensor()
self.converter = ToTensor(dtype=dtype, device=device)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down Expand Up @@ -491,6 +499,7 @@ def __init__(
self,
keys: KeysCollection,
data_type: str = "tensor",
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
device: Optional[torch.device] = None,
allow_missing_keys: bool = False,
) -> None:
Expand All @@ -499,11 +508,12 @@ def __init__(
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
data_type: target data type to convert, should be "tensor" or "numpy".
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
device: for Tensor data type, specify the target device.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = EnsureType(data_type=data_type, device=device)
self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand All @@ -530,15 +540,21 @@ class ToNumpyd(MapTransform):

backend = ToNumpy.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
def __init__(
self,
keys: KeysCollection,
dtype: Optional[DtypeLike] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
dtype: target data type when converting to numpy array.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = ToNumpy()
self.converter = ToNumpy(dtype=dtype)

def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
d = dict(data)
Expand All @@ -554,13 +570,13 @@ class ToCupyd(MapTransform):
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
allow_missing_keys: don't raise exception if key is missing.
dtype: data type specifier. It is inferred from the input by default.
allow_missing_keys: don't raise exception if key is missing.
"""

backend = ToCupy.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, dtype=None) -> None:
def __init__(self, keys: KeysCollection, dtype=None, allow_missing_keys: bool = False) -> None:
super().__init__(keys, allow_missing_keys)
self.converter = ToCupy(dtype=dtype)

Expand Down
85 changes: 43 additions & 42 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def get_equivalent_dtype(dtype, data_type):
im = torch.tensor(1)
dtype = get_equivalent_dtype(np.float32, type(im))
"""
if dtype is None:
return None
if data_type is torch.Tensor:
if type(dtype) is torch.dtype:
return dtype
Expand All @@ -84,7 +86,12 @@ def get_dtype(data: Any):
return type(data)


def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.device] = None):
def convert_to_tensor(
data,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = False,
):
"""
Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple,
recursively check every item and convert it to PyTorch Tensor.
Expand All @@ -93,36 +100,41 @@ def convert_to_tensor(data, wrap_sequence: bool = False, device: Optional[torch.
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a Tensor if applicable.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`.
If `True`, then `[1, 2]` -> `tensor([1, 2])`.
dtype: target data type to when converting to Tensor.
device: target device to put the converted Tensor data.
wrap_sequence: if `False`, then lists will recursively call this function.
E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`.
"""
if isinstance(data, torch.Tensor):
return data.contiguous().to(device)
return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore
if isinstance(data, np.ndarray):
# skip array of string classes and object, refer to:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/_utils/collate.py#L13
if re.search(r"[SaUO]", data.dtype.str) is None:
# numpy array with 0 dims is also sequence iterable,
# `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data), device=device)
elif has_cp and isinstance(data, cp_ndarray):
return torch.as_tensor(data, device=device)
elif isinstance(data, (float, int, bool)):
return torch.as_tensor(data, device=device)
elif isinstance(data, Sequence) and wrap_sequence:
return torch.as_tensor(data, device=device)
if data.ndim > 0:
data = np.ascontiguousarray(data)
return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore
elif (
has_cp
and isinstance(data, cp_ndarray)
or isinstance(data, (float, int, bool))
or (isinstance(data, Sequence) and wrap_sequence)
):
return torch.as_tensor(data, dtype=dtype, device=device) # type: ignore
elif isinstance(data, list):
return [convert_to_tensor(i, device=device) for i in data]
return [convert_to_tensor(i, dtype=dtype, device=device) for i in data]
elif isinstance(data, tuple):
return tuple(convert_to_tensor(i, device=device) for i in data)
return tuple(convert_to_tensor(i, dtype=dtype, device=device) for i in data)
elif isinstance(data, dict):
return {k: convert_to_tensor(v, device=device) for k, v in data.items()}
return {k: convert_to_tensor(v, dtype=dtype, device=device) for k, v in data.items()}

return data


def convert_to_numpy(data, wrap_sequence: bool = False):
def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False):
"""
Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
recursively check every item and convert it to numpy array.
Expand All @@ -131,23 +143,22 @@ def convert_to_numpy(data, wrap_sequence: bool = False):
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a numpy array if applicable.
dtype: target data type when converting to numpy array.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
"""
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy()
elif has_cp and isinstance(data, cp_ndarray):
data = cp.asnumpy(data)
elif isinstance(data, (float, int, bool)):
data = np.asarray(data)
elif isinstance(data, Sequence) and wrap_sequence:
return np.asarray(data)
data = cp.asnumpy(data).astype(dtype)
elif isinstance(data, (np.ndarray, float, int, bool)) or (isinstance(data, Sequence) and wrap_sequence):
data = np.asarray(data, dtype=dtype)
elif isinstance(data, list):
return [convert_to_numpy(i) for i in data]
return [convert_to_numpy(i, dtype=dtype) for i in data]
elif isinstance(data, tuple):
return tuple(convert_to_numpy(i) for i in data)
return tuple(convert_to_numpy(i, dtype=dtype) for i in data)
elif isinstance(data, dict):
return {k: convert_to_numpy(v) for k, v in data.items()}
return {k: convert_to_numpy(v, dtype=dtype) for k, v in data.items()}

if isinstance(data, np.ndarray) and data.ndim > 0:
data = np.ascontiguousarray(data)
Expand All @@ -165,16 +176,16 @@ def convert_to_cupy(data, dtype, wrap_sequence: bool = True):
Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays
for dictionary, list or tuple, convert every item to a numpy array if applicable.
dtype: target data type when converting to Cupy array.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
"""

# direct calls
if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)):
if isinstance(data, (cp_ndarray, np.ndarray, torch.Tensor, float, int, bool)) or (
isinstance(data, Sequence) and wrap_sequence
):
data = cp.asarray(data, dtype)
# recursive calls
elif isinstance(data, Sequence) and wrap_sequence:
return cp.asarray(data, dtype)
elif isinstance(data, list):
return [convert_to_cupy(i, dtype) for i in data]
elif isinstance(data, tuple):
Expand Down Expand Up @@ -224,24 +235,14 @@ def convert_data_type(

output_type = output_type or orig_type

dtype = get_equivalent_dtype(dtype or get_dtype(data), output_type)
dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type)

if output_type is torch.Tensor:
if orig_type is not torch.Tensor:
data = convert_to_tensor(data)
if dtype != data.dtype:
data = data.to(dtype)
if device is not None:
data = data.to(device)
data = convert_to_tensor(data, dtype=dtype_, device=device)
elif output_type is np.ndarray:
if orig_type is not np.ndarray:
data = convert_to_numpy(data)
if data is not None and dtype != data.dtype:
data = data.astype(dtype)
data = convert_to_numpy(data, dtype=dtype_)
elif has_cp and output_type is cp.ndarray:
if data is not None:
data = convert_to_cupy(data, dtype)

data = convert_to_cupy(data, dtype=dtype_)
else:
raise ValueError(f"Unsupported output type: {output_type}")
return data, orig_type, orig_device
Expand Down
4 changes: 3 additions & 1 deletion tests/test_ensure_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ def test_array_input(self):
test_datas.append(test_datas[-1].cuda())
for test_data in test_datas:
for dtype in ("tensor", "NUMPY"):
result = EnsureType(data_type=dtype)(test_data)
result = EnsureType(dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu")(test_data)
if dtype == "NUMPY":
self.assertTrue(result.dtype == np.float32)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
self.assertTupleEqual(result.shape, (2, 2))
Expand Down
9 changes: 8 additions & 1 deletion tests/test_ensure_typed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,14 @@ def test_array_input(self):
test_datas.append(test_datas[-1].cuda())
for test_data in test_datas:
for dtype in ("tensor", "NUMPY"):
result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
result = EnsureTyped(
keys="data",
data_type=dtype,
dtype=np.float32 if dtype == "NUMPY" else None,
device="cpu",
)({"data": test_data})["data"]
if dtype == "NUMPY":
self.assertTrue(result.dtype == np.float32)
self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
assert_allclose(result, test_data)
self.assertTupleEqual(result.shape, (2, 2))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ def test_numpy_input(self):
test_data = np.array([[1, 2], [3, 4]])
test_data = np.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToNumpy()(test_data)
result = ToNumpy(dtype="float32")(test_data)
self.assertTrue(isinstance(result, np.ndarray))
self.assertTrue(result.dtype == np.float32)
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data)

Expand Down
Loading

0 comments on commit 2f4b582

Please sign in to comment.