Skip to content

Improve Orientation transform to use the "space" (LPS vs RAS) of a metatensor by default #8473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
GridSamplePadMode,
InterpolateMode,
NumpyPadMode,
SpaceKeys,
convert_to_cupy,
convert_to_dst_type,
convert_to_numpy,
Expand Down Expand Up @@ -560,7 +561,7 @@ def __init__(
self,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
lazy: bool = False,
) -> None:
"""
Expand All @@ -573,7 +574,9 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
Defaults to using the ``"space"`` attribute of a metatensor,
where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
otherwise (i.e. for plain tensors).
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False

Expand Down Expand Up @@ -619,9 +622,15 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
affine_: np.ndarray
affine_np: np.ndarray
labels = self.labels
if isinstance(data_array, MetaTensor):
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
affine_ = to_affine_nd(sr, affine_np)

# Set up "labels" such that LPS tensors are handled correctly by default
if self.labels is None and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS:
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

else:
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
# default to identity
Expand All @@ -640,7 +649,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
"please make sure the input is in the channel-first format."
)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
if len(dst) < sr:
raise ValueError(
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
Expand All @@ -653,8 +662,14 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
# Create inverse transform
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
labels = self.labels

# Set up "labels" such that LPS tensors are handled correctly by default
if isinstance(data, MetaTensor) and self.labels is None and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS:
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS

orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
# Apply inverse
with inverse_transform.trace_transform(False):
data = inverse_transform(data)
Expand Down
6 changes: 4 additions & 2 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def __init__(
keys: KeysCollection,
axcodes: str | None = None,
as_closest_canonical: bool = False,
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
labels: Sequence[tuple[str, str]] | None = None,
allow_missing_keys: bool = False,
lazy: bool = False,
) -> None:
Expand All @@ -564,7 +564,9 @@ def __init__(
as_closest_canonical: if True, load the image as closest to canonical axis format.
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
Defaults to using the ``"space"`` attribute of a metatensor,
where appliable, or (('L', 'R'), ('P', 'A'), ('I', 'S'))``
otherwise (i.e. for plain tensors).
allow_missing_keys: don't raise exception if key is missing.
lazy: a flag to indicate whether this transform should execute lazily or not.
Defaults to False
Expand Down
96 changes: 87 additions & 9 deletions tests/transforms/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import unittest
from typing import cast

import nibabel as nib
import numpy as np
Expand All @@ -21,6 +22,7 @@
from monai.data.meta_obj import set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.transforms import Orientation, create_rotate, create_translate
from monai.utils import SpaceKeys
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.test_utils import TEST_DEVICES, assert_allclose

Expand All @@ -33,6 +35,18 @@
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.eye(4),
torch.arange(12).reshape((2, 1, 2, 3)),
"LPS",
True,
*device,
]
)
Expand All @@ -43,6 +57,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"ALS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PRS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
"PRS",
True,
*device,
]
)
Expand All @@ -53,6 +79,18 @@
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"RAS",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "LPS"},
torch.arange(12).reshape((2, 1, 2, 3)),
torch.as_tensor(np.diag([-1, -1, 1, 1])),
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
"LPS",
True,
*device,
]
)
Expand All @@ -63,6 +101,18 @@
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"AL",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "PR"},
torch.arange(6).reshape((2, 1, 3)),
torch.eye(3),
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
"PR",
True,
*device,
]
)
Expand All @@ -73,6 +123,18 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
TESTS.append(
[
{"axcodes": "R"},
torch.arange(6).reshape((2, 3)),
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"R",
True,
*device,
]
)
Expand All @@ -83,6 +145,7 @@
torch.eye(2),
torch.tensor([[2, 1, 0], [5, 4, 3]]),
"L",
False,
*device,
]
)
Expand All @@ -93,6 +156,7 @@
torch.as_tensor(np.diag([-1, 1])),
torch.arange(6).reshape((2, 3)),
"L",
False,
*device,
]
)
Expand All @@ -107,6 +171,7 @@
),
torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
"LPS",
False,
*device,
]
)
Expand All @@ -121,6 +186,7 @@
),
torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
"RAS",
False,
*device,
]
)
Expand All @@ -131,6 +197,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
"RA",
False,
*device,
]
)
Expand All @@ -141,6 +208,7 @@
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
"LP",
False,
*device,
]
)
Expand All @@ -151,6 +219,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"LPID",
False,
*device,
]
)
Expand All @@ -161,6 +230,7 @@
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
torch.zeros((1, 2, 3, 4, 5)),
"RASD",
False,
*device,
]
)
Expand All @@ -175,6 +245,11 @@
[{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
]

TESTS_INVERSE = []
for device in TEST_DEVICES:
TESTS_INVERSE.append([True, *device])
TESTS_INVERSE.append([False, *device])


class TestOrientationCase(unittest.TestCase):
@parameterized.expand(TESTS)
Expand All @@ -185,17 +260,20 @@ def test_ornt_meta(
affine: torch.Tensor,
expected_data: torch.Tensor,
expected_code: str,
lps_convention: bool,
device,
):
img = MetaTensor(img, affine=affine).to(device)
meta = {"space": SpaceKeys.LPS} if lps_convention else None
img = MetaTensor(img, affine=affine, meta=meta).to(device)
ornt = Orientation(**init_param)
call_param = {"data_array": img}
res = ornt(**call_param) # type: ignore[arg-type]
if img.ndim in (3, 4):
test_resampler_lazy(ornt, res, init_param, call_param)

assert_allclose(res, expected_data.to(device))
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) # type: ignore
labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels) # type: ignore
self.assertEqual("".join(new_code), expected_code)

@parameterized.expand(TESTS_TORCH)
Expand Down Expand Up @@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
with self.assertRaises(ValueError):
Orientation(**init_param)(img)

@parameterized.expand(TEST_DEVICES)
def test_inverse(self, device):
@parameterized.expand(TESTS_INVERSE)
def test_inverse(self, lps_convention: bool, device):
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
affine = torch.tensor(
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
)
meta = {"fname": "somewhere"}
meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
img = MetaTensor(img_t, affine=affine, meta=meta)
tr = Orientation("LPS")
# check that image and affine have changed
img = tr(img)
img = cast(MetaTensor, tr(img))
self.assertNotEqual(img.shape, img_t.shape)
self.assertGreater((affine - img.affine).max(), 0.5)
self.assertGreater(float((affine - img.affine).max()), 0.5)
# check that with inverse, image affine are back to how they were
img = tr.inverse(img)
img = cast(MetaTensor, tr.inverse(img))
self.assertEqual(img.shape, img_t.shape)
self.assertLess((affine - img.affine).max(), 1e-2)
self.assertLess(float((affine - img.affine).max()), 1e-2)


if __name__ == "__main__":
Expand Down
Loading
Loading