diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index a75bb390cd..3a2d792dde 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -64,6 +64,7 @@
     GridSamplePadMode,
     InterpolateMode,
     NumpyPadMode,
+    SpaceKeys,
     convert_to_cupy,
     convert_to_dst_type,
     convert_to_numpy,
@@ -560,7 +561,7 @@ def __init__(
         self,
         axcodes: str | None = None,
         as_closest_canonical: bool = False,
-        labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
+        labels: Sequence[tuple[str, str]] | None = None,
         lazy: bool = False,
     ) -> None:
         """
@@ -573,7 +574,9 @@ def __init__(
             as_closest_canonical: if True, load the image as closest to canonical axis format.
             labels: optional, None or sequence of (2,) sequences
                 (2,) sequences are labels for (beginning, end) of output axis.
-                Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
+                Defaults to using the ``"space"`` attribute of a ``MetaTensor``,
+                where applicable, or ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``
+                otherwise (i.e. for plain tensors).
             lazy: a flag to indicate whether this transform should execute lazily or not.
                 Defaults to False
 
@@ -619,9 +622,15 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
             raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
         affine_: np.ndarray
         affine_np: np.ndarray
+        labels = self.labels
         if isinstance(data_array, MetaTensor):
             affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
             affine_ = to_affine_nd(sr, affine_np)
+
+            # Set up "labels" such that LPS tensors are handled correctly by default
+            if self.labels is None and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS:
+                labels = (("R", "L"), ("A", "P"), ("I", "S"))  # value for LPS
+
         else:
             warnings.warn("`data_array` is not of type `MetaTensor`, assuming affine to be identity.")
             # default to identity
@@ -640,7 +649,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
                 f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
                 "please make sure the input is in the channel-first format."
             )
-        dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
+        dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
         if len(dst) < sr:
             raise ValueError(
                 f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
@@ -653,8 +662,14 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
         transform = self.pop_transform(data)
         # Create inverse transform
         orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
-        orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
-        inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
+        labels = self.labels
+
+        # Set up "labels" such that LPS tensors are handled correctly by default
+        if isinstance(data, MetaTensor) and self.labels is None and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS:
+            labels = (("R", "L"), ("A", "P"), ("I", "S"))  # value for LPS
+
+        orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
+        inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
         # Apply inverse
         with inverse_transform.trace_transform(False):
             data = inverse_transform(data)
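Reviewer note on the flipped label tuple used above: nibabel's orientation helpers report axis codes relative to the (beginning, end) label pairs they are given, so swapping each pair makes the same affine read in LPS terms. A minimal sketch (not part of the patch) illustrating why ``(("R", "L"), ("A", "P"), ("I", "S"))`` is the right value for LPS-space data:

```python
# Minimal sketch (not part of the patch): nibabel reports axis codes relative
# to the (beginning, end) label pairs it is given. With the default RAS pairs
# an identity affine reads "RAS"; with each pair swapped it reads "LPS".
import numpy as np
import nibabel as nib

affine = np.eye(4)  # identity affine

ras_labels = (("L", "R"), ("P", "A"), ("I", "S"))  # nibabel/MONAI default
lps_labels = (("R", "L"), ("A", "P"), ("I", "S"))  # value used for LPS above

print(nib.orientations.aff2axcodes(affine, labels=ras_labels))  # ('R', 'A', 'S')
print(nib.orientations.aff2axcodes(affine, labels=lps_labels))  # ('L', 'P', 'S')
```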
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 2b80034a07..8d1d3e147d 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -550,7 +550,7 @@ def __init__(
         keys: KeysCollection,
         axcodes: str | None = None,
         as_closest_canonical: bool = False,
-        labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
+        labels: Sequence[tuple[str, str]] | None = None,
         allow_missing_keys: bool = False,
         lazy: bool = False,
     ) -> None:
@@ -564,7 +564,9 @@ def __init__(
             as_closest_canonical: if True, load the image as closest to canonical axis format.
             labels: optional, None or sequence of (2,) sequences
                 (2,) sequences are labels for (beginning, end) of output axis.
-                Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
+                Defaults to using the ``"space"`` attribute of a ``MetaTensor``,
+                where applicable, or ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``
+                otherwise (i.e. for plain tensors).
             allow_missing_keys: don't raise exception if key is missing.
             lazy: a flag to indicate whether this transform should execute lazily or not.
                 Defaults to False
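For reviewers, a hedged usage sketch of what the new ``labels=None`` default means in practice, assuming this patch is applied: an image whose metadata declares its ``"space"`` as LPS is now reoriented with LPS-convention axis codes without any extra arguments.

```python
# Usage sketch, assuming this patch is applied. With labels=None (the new
# default), Orientation consults the MetaTensor's "space" metadata, so
# axcodes="LPS" is a no-op for an LPS-space image with an identity affine.
import torch
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation
from monai.utils import SpaceKeys

img = MetaTensor(
    torch.arange(12.0).reshape((2, 1, 2, 3)),
    affine=torch.eye(4),
    meta={"space": SpaceKeys.LPS},
)
out = Orientation(axcodes="LPS")(img)  # identity reorientation under LPS
assert torch.allclose(out.as_tensor(), img.as_tensor())
```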
diff --git a/tests/transforms/test_orientation.py b/tests/transforms/test_orientation.py
index fee287dd5b..31acc18c5d 100644
--- a/tests/transforms/test_orientation.py
+++ b/tests/transforms/test_orientation.py
@@ -12,6 +12,7 @@
 from __future__ import annotations
 
 import unittest
+from typing import cast
 
 import nibabel as nib
 import numpy as np
@@ -21,6 +22,7 @@
 from monai.data.meta_obj import set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import Orientation, create_rotate, create_translate
+from monai.utils import SpaceKeys
 from tests.lazy_transforms_utils import test_resampler_lazy
 from tests.test_utils import TEST_DEVICES, assert_allclose
 
@@ -33,6 +35,18 @@
             torch.eye(4),
             torch.arange(12).reshape((2, 1, 2, 3)),
             "RAS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "LPS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.eye(4),
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            "LPS",
+            True,
             *device,
         ]
     )
@@ -43,6 +57,18 @@
             torch.as_tensor(np.diag([-1, -1, 1, 1])),
             torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
             "ALS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "PRS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.as_tensor(np.diag([-1, -1, 1, 1])),
+            torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
+            "PRS",
+            True,
             *device,
         ]
     )
@@ -53,6 +79,18 @@
             torch.as_tensor(np.diag([-1, -1, 1, 1])),
             torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
             "RAS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "LPS"},
+            torch.arange(12).reshape((2, 1, 2, 3)),
+            torch.as_tensor(np.diag([-1, -1, 1, 1])),
+            torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
+            "LPS",
+            True,
             *device,
         ]
     )
@@ -63,6 +101,18 @@
             torch.eye(3),
             torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
             "AL",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "PR"},
+            torch.arange(6).reshape((2, 1, 3)),
+            torch.eye(3),
+            torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
+            "PR",
+            True,
             *device,
         ]
     )
@@ -73,6 +123,18 @@
             torch.eye(2),
             torch.tensor([[2, 1, 0], [5, 4, 3]]),
             "L",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"axcodes": "R"},
+            torch.arange(6).reshape((2, 3)),
+            torch.eye(2),
+            torch.tensor([[2, 1, 0], [5, 4, 3]]),
+            "R",
+            True,
             *device,
         ]
     )
@@ -83,6 +145,7 @@
             torch.eye(2),
             torch.tensor([[2, 1, 0], [5, 4, 3]]),
             "L",
+            False,
             *device,
         ]
     )
@@ -93,6 +156,7 @@
             torch.as_tensor(np.diag([-1, 1])),
             torch.arange(6).reshape((2, 3)),
             "L",
+            False,
             *device,
         ]
     )
@@ -107,6 +171,7 @@
             ),
             torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
             "LPS",
+            False,
             *device,
         ]
     )
@@ -121,6 +186,7 @@
             ),
             torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
             "RAS",
+            False,
             *device,
         ]
     )
@@ -131,6 +197,7 @@
             torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
             torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
             "RA",
+            False,
             *device,
         ]
     )
@@ -141,6 +208,7 @@
             torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
             torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
             "LP",
+            False,
             *device,
         ]
     )
@@ -151,6 +219,7 @@
             torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
             torch.zeros((1, 2, 3, 4, 5)),
             "LPID",
+            False,
             *device,
         ]
     )
@@ -161,6 +230,7 @@
             torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
             torch.zeros((1, 2, 3, 4, 5)),
             "RASD",
+            False,
             *device,
         ]
     )
@@ -175,6 +245,11 @@
     [{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
 ]
 
+TESTS_INVERSE = []
+for device in TEST_DEVICES:
+    TESTS_INVERSE.append([True, *device])
+    TESTS_INVERSE.append([False, *device])
+
 
 class TestOrientationCase(unittest.TestCase):
     @parameterized.expand(TESTS)
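Worth noting for reviewers: an explicit ``labels`` argument still takes precedence over the space-based default, since the new branches only fire when ``self.labels is None``. A brief sketch:

```python
# Sketch: an explicit labels argument disables the space-based default,
# because the new code paths only trigger when self.labels is None.
import torch
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation
from monai.utils import SpaceKeys

img = MetaTensor(torch.zeros((1, 2, 3, 4)), affine=torch.eye(4), meta={"space": SpaceKeys.LPS})
tr = Orientation(axcodes="RAS", labels=(("L", "R"), ("P", "A"), ("I", "S")))
out = tr(img)  # interpreted with RAS-convention labels despite the LPS metadata
```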
@@ -185,9 +260,11 @@ def test_ornt_meta(
         affine: torch.Tensor,
         expected_data: torch.Tensor,
         expected_code: str,
+        lps_convention: bool,
         device,
     ):
-        img = MetaTensor(img, affine=affine).to(device)
+        meta = {"space": SpaceKeys.LPS} if lps_convention else None
+        img = MetaTensor(img, affine=affine, meta=meta).to(device)
         ornt = Orientation(**init_param)
         call_param = {"data_array": img}
         res = ornt(**call_param)  # type: ignore[arg-type]
@@ -195,7 +272,8 @@ def test_ornt_meta(
             test_resampler_lazy(ornt, res, init_param, call_param)
 
         assert_allclose(res, expected_data.to(device))
-        new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels)  # type: ignore
+        labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
+        new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels)  # type: ignore
         self.assertEqual("".join(new_code), expected_code)
 
     @parameterized.expand(TESTS_TORCH)
@@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
         with self.assertRaises(ValueError):
             Orientation(**init_param)(img)
 
-    @parameterized.expand(TEST_DEVICES)
-    def test_inverse(self, device):
+    @parameterized.expand(TESTS_INVERSE)
+    def test_inverse(self, lps_convention: bool, device):
         img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
         affine = torch.tensor(
             [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
         )
-        meta = {"fname": "somewhere"}
+        meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
         img = MetaTensor(img_t, affine=affine, meta=meta)
         tr = Orientation("LPS")
         # check that image and affine have changed
-        img = tr(img)
+        img = cast(MetaTensor, tr(img))
         self.assertNotEqual(img.shape, img_t.shape)
-        self.assertGreater((affine - img.affine).max(), 0.5)
+        self.assertGreater(float((affine - img.affine).max()), 0.5)
         # check that with inverse, image affine are back to how they were
-        img = tr.inverse(img)
+        img = cast(MetaTensor, tr.inverse(img))
         self.assertEqual(img.shape, img_t.shape)
-        self.assertLess((affine - img.affine).max(), 1e-2)
+        self.assertLess(float((affine - img.affine).max()), 1e-2)
 
 
 if __name__ == "__main__":
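A round-trip sketch mirroring ``test_inverse`` above, again assuming the patch is applied: ``inverse()`` also consults the ``"space"`` metadata when ``labels`` is ``None``, so a forward-then-inverse application restores the original affine for LPS data.

```python
# Round-trip sketch mirroring test_inverse: inverse() also reads the "space"
# metadata when labels is None, so the original affine is restored.
import torch
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation
from monai.utils import SpaceKeys

affine = torch.tensor(
    [[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32
)
img = MetaTensor(torch.rand((1, 10, 9, 8)), affine=affine, meta={"space": SpaceKeys.LPS})

tr = Orientation("LPS")
reoriented = tr(img)               # shape and affine change
restored = tr.inverse(reoriented)  # both are restored
assert float((affine - restored.affine).max()) < 1e-2
```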
diff --git a/tests/transforms/test_orientationd.py b/tests/transforms/test_orientationd.py
index 3fe52b0b8a..445d139b1b 100644
--- a/tests/transforms/test_orientationd.py
+++ b/tests/transforms/test_orientationd.py
@@ -21,13 +21,17 @@
 from monai.data.meta_obj import set_track_meta
 from monai.data.meta_tensor import MetaTensor
 from monai.transforms import Orientationd
+from monai.utils import SpaceKeys
 from tests.lazy_transforms_utils import test_resampler_lazy
 from tests.test_utils import TEST_DEVICES
 
 TESTS = []
 for device in TEST_DEVICES:
     TESTS.append(
-        [{"keys": "seg", "axcodes": "RAS"}, torch.ones((2, 1, 2, 3)), torch.eye(4), (2, 1, 2, 3), "RAS", *device]
+        [{"keys": "seg", "axcodes": "RAS"}, torch.ones((2, 1, 2, 3)), torch.eye(4), (2, 1, 2, 3), "RAS", False, *device]
+    )
+    TESTS.append(
+        [{"keys": "seg", "axcodes": "RAS"}, torch.ones((2, 1, 2, 3)), torch.eye(4), (2, 1, 2, 3), "RAS", True, *device]
     )
     # 3d
     TESTS.append(
@@ -37,15 +41,51 @@
             torch.eye(4),
             (2, 2, 1, 3),
             "PLI",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"keys": ["img", "seg"], "axcodes": "PLI"},
+            torch.ones((2, 1, 2, 3)),
+            torch.eye(4),
+            (2, 2, 1, 3),
+            "PLI",
+            True,
             *device,
         ]
     )
     # 2d
     TESTS.append(
-        [{"keys": ["img", "seg"], "axcodes": "PLI"}, torch.ones((2, 1, 3)), torch.eye(4), (2, 3, 1), "PLS", *device]
+        [
+            {"keys": ["img", "seg"], "axcodes": "PLI"},
+            torch.ones((2, 1, 3)),
+            torch.eye(4),
+            (2, 3, 1),
+            "PLS",
+            False,
+            *device,
+        ]
+    )
+    TESTS.append(
+        [
+            {"keys": ["img", "seg"], "axcodes": "PLI"},
+            torch.ones((2, 1, 3)),
+            torch.eye(4),
+            (2, 3, 1),
+            "PLS",
+            True,
+            *device,
+        ]
     )
     # 1d
-    TESTS.append([{"keys": ["img", "seg"], "axcodes": "L"}, torch.ones((2, 3)), torch.eye(4), (2, 3), "LAS", *device])
+    TESTS.append(
+        [{"keys": ["img", "seg"], "axcodes": "L"}, torch.ones((2, 3)), torch.eye(4), (2, 3), "LAS", False, *device]
+    )
+    TESTS.append(
+        [{"keys": ["img", "seg"], "axcodes": "L"}, torch.ones((2, 3)), torch.eye(4), (2, 3), "LPS", True, *device]
+    )
     # canonical
     TESTS.append(
         [
@@ -54,6 +94,7 @@
             torch.eye(4),
             (2, 1, 2, 3),
             "RAS",
+            False,
             *device,
         ]
     )
@@ -67,11 +108,19 @@ class TestOrientationdCase(unittest.TestCase):
     @parameterized.expand(TESTS)
     def test_orntd(
-        self, init_param, img: torch.Tensor, affine: torch.Tensor | None, expected_shape, expected_code, device
+        self,
+        init_param,
+        img: torch.Tensor,
+        affine: torch.Tensor | None,
+        expected_shape,
+        expected_code,
+        lps_convention: bool,
+        device,
     ):
         ornt = Orientationd(**init_param)
         if affine is not None:
-            img = MetaTensor(img, affine=affine)
+            meta = {"space": SpaceKeys.LPS} if lps_convention else None
+            img = MetaTensor(img, affine=affine, meta=meta)
         img = img.to(device)
         call_param = {"data": {k: img.clone() for k in ornt.keys}}
         res = ornt(**call_param)  # type: ignore[arg-type]
@@ -81,7 +130,8 @@ def test_orntd(
             _im = res[k]
             self.assertIsInstance(_im, MetaTensor)
             np.testing.assert_allclose(_im.shape, expected_shape)
-            code = nib.aff2axcodes(_im.affine.cpu(), ornt.ornt_transform.labels)  # type: ignore
+            labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.ornt_transform.labels
+            code = nib.aff2axcodes(_im.affine.cpu(), labels)  # type: ignore
             self.assertEqual("".join(code), expected_code)
 
     @parameterized.expand(TESTS_TORCH)
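Finally, the dictionary variant behaves the same way, since ``Orientationd`` forwards ``labels`` to the underlying ``Orientation``. A short sketch under the same assumption (patch applied, ``labels`` left at the new ``None`` default):

```python
# Dictionary-transform sketch under the same assumption (patch applied):
# Orientationd forwards the labels=None default to Orientation, so the
# "space" metadata of each value is respected per key.
import nibabel as nib
import torch
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientationd
from monai.utils import SpaceKeys

data = {"img": MetaTensor(torch.ones((2, 1, 2, 3)), affine=torch.eye(4), meta={"space": SpaceKeys.LPS})}
out = Orientationd(keys="img", axcodes="RAS")(data)
# Read the resulting codes with LPS-convention labels, as the tests do:
code = nib.aff2axcodes(out["img"].affine.cpu(), (("R", "L"), ("A", "P"), ("I", "S")))
print("".join(code))  # "RAS"
```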