From 47b16cc132cb333536e42f717dec2e236acd79ab Mon Sep 17 00:00:00 2001 From: Adama Sorho Date: Mon, 21 Apr 2025 17:10:58 -0500 Subject: [PATCH 1/2] Enhance ExtractDataKeyFromMetaKeyd to work with MetaTensor Signed-off-by: Adama Sorho --- .../reconstruction/transforms/dictionary.py | 15 ++++++- .../test_extract_data_key_from_meta_keyd.py | 39 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index c166740768..cddbefc9d1 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -20,6 +20,7 @@ from monai.apps.reconstruction.transforms.array import EquispacedKspaceMask, RandomKspaceMask from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor +from monai.data import MetaTensor from monai.transforms import InvertibleTransform from monai.transforms.croppad.array import SpatialCrop from monai.transforms.intensity.array import NormalizeIntensity @@ -57,15 +58,25 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T Returns: the new data dictionary """ + d = dict(data) + + if isinstance(d[self.meta_key], MetaTensor): + # meta tensor + meta = d[self.meta_key].meta + else: + # meta dict + meta = d[self.meta_key] + for key in self.keys: - if key in d[self.meta_key]: - d[key] = d[self.meta_key][key] # type: ignore + if key in meta: + d[key] = meta[key] # type: ignore elif not self.allow_missing_keys: raise KeyError( f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the meta data" " and allow_missing_keys==False." ) + return d # type: ignore diff --git a/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py b/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py new file mode 100644 index 0000000000..56d62bab3f --- /dev/null +++ b/tests/apps/reconstruction/transforms/test_extract_data_key_from_meta_keyd.py @@ -0,0 +1,39 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from monai.apps.reconstruction.transforms.dictionary import ExtractDataKeyFromMetaKeyd +from monai.data import MetaTensor + + +class TestExtractDataKeyFromMetaKeyd(unittest.TestCase): + def test_extract_data_key_from_dic(self): + data = {"image_data": MetaTensor([1, 2, 3]), "foo_meta_dict": {"filename_or_obj": "test_image.nii.gz"}} + + extract = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="foo_meta_dict") + result = extract(data) + + assert data["foo_meta_dict"]["filename_or_obj"] == result["filename_or_obj"] + + def test_extract_data_key_from_meta_tensor(self): + data = {"image_data": MetaTensor([1, 2, 3], meta={"filename_or_obj": 1})} + + extract = ExtractDataKeyFromMetaKeyd("filename_or_obj", meta_key="image_data") + result = extract(data) + + assert data["image_data"].meta["filename_or_obj"] == result["filename_or_obj"] + + +if __name__ == "__main__": + unittest.main() From b8ac64e8a39b7ca7c3c7d757428bd202784591a4 Mon Sep 17 00:00:00 2001 From: Adama Sorho Date: Tue, 22 Apr 2025 15:21:28 -0500 Subject: [PATCH 2/2] add type annotation to meta variable --- monai/apps/reconstruction/transforms/dictionary.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/monai/apps/reconstruction/transforms/dictionary.py b/monai/apps/reconstruction/transforms/dictionary.py index cddbefc9d1..518049f8e0 100644 --- a/monai/apps/reconstruction/transforms/dictionary.py +++ b/monai/apps/reconstruction/transforms/dictionary.py @@ -12,6 +12,7 @@ from __future__ import annotations from collections.abc import Hashable, Mapping, Sequence +from typing import Any import numpy as np from numpy import ndarray @@ -61,9 +62,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, T d = dict(data) + meta: dict[str, Any] if isinstance(d[self.meta_key], MetaTensor): # meta tensor - meta = d[self.meta_key].meta + meta = d[self.meta_key].meta # type: ignore else: # meta dict meta = d[self.meta_key]