diff --git a/.gitattributes b/.gitattributes index 29bcbf6f3..d4a50a98e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -9,3 +9,4 @@ tests/eva/assets/**/*.npy filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.xml filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.mat filter=lfs diff=lfs merge=lfs -text tests/eva/assets/**/*.nii filter=lfs diff=lfs merge=lfs -text +tests/eva/assets/**/*.nii.gz filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index cd5856585..a9042f9b5 100644 --- a/.gitignore +++ b/.gitignore @@ -179,6 +179,3 @@ cython_debug/ # numpy data *.npy - -# NiFti data -*.nii.gz diff --git a/configs/vision/radiology/offline/segmentation/kits23.yaml b/configs/vision/radiology/offline/segmentation/kits23.yaml new file mode 100644 index 000000000..103c2f033 --- /dev/null +++ b/configs/vision/radiology/offline/segmentation/kits23.yaml @@ -0,0 +1,158 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 + log_images: false + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 5} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + - class_path: 
eva.callbacks.SegmentationEmbeddingsWriter + init_args: + output_dir: &DATASET_EMBEDDINGS_ROOT ${oc.env:EMBEDDINGS_ROOT, ./data/embeddings}/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23 + dataloader_idx_map: + 0: train + 1: val + 2: test + metadata_keys: ["slice_index"] + overwrite: false + backbone: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_kwargs: + out_indices: ${oc.env:OUT_INDICES, 1} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderMS + init_args: + in_features: ${oc.env:IN_FEATURES, 384} + num_classes: &NUM_CLASSES 4 + criterion: + class_path: eva.vision.losses.DiceLoss + init_args: + softmax: true + batch: true + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.002} + lr_scheduler: + class_path: torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + postprocess: + predictions_transforms: + - class_path: torch.argmax + init_args: + dim: 1 + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: torchmetrics.ClasswiseWrapper + init_args: + metric: + class_path: eva.vision.metrics.MonaiDiceScore + init_args: + include_background: true + num_classes: *NUM_CLASSES + reduction: none + labels: + - background + - kidney + - tumor + - cyst +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: &DATASET_ARGS + root: 
*DATASET_EMBEDDINGS_ROOT + manifest_file: manifest.csv + split: train + val: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.vision.datasets.EmbeddingsSegmentationDataset + init_args: + <<: *DATASET_ARGS + split: test + predict: + - class_path: eva.vision.datasets.KiTS23 + init_args: &PREDICT_DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/kits23} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset automatically from the official source. + # The KiTS23 dataset is distributed under the following license: + # "Attribution-NonCommercial-ShareAlike 4.0 International" + # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndClamp + init_args: + size: ${oc.env:RESIZE_DIM, 224} + mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *PREDICT_DATASET_ARGS + split: val + - class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *PREDICT_DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + shuffle: true + test: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + predict: + batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 64} + num_workers: *N_DATA_WORKERS diff --git a/configs/vision/radiology/online/segmentation/kits23.yaml b/configs/vision/radiology/online/segmentation/kits23.yaml new file mode 100644 index 000000000..d804e864c --- /dev/null +++ b/configs/vision/radiology/online/segmentation/kits23.yaml @@ -0,0 +1,132 @@ +--- +trainer: + class_path: eva.Trainer + init_args: + n_runs: &N_RUNS 
${oc.env:N_RUNS, 5} + default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:MODEL_NAME, vit_small_patch16_224_dino}/kits23} + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 40000} + checkpoint_type: ${oc.env:CHECKPOINT_TYPE, best} + callbacks: + - class_path: eva.callbacks.ConfigurationLogger + - class_path: lightning.pytorch.callbacks.TQDMProgressBar + init_args: + refresh_rate: ${oc.env:TQDM_REFRESH_RATE, 1} + - class_path: eva.vision.callbacks.SemanticSegmentationLogger + init_args: + log_every_n_epochs: 1 + mean: &NORMALIZE_MEAN ${oc.env:NORMALIZE_MEAN, [0.485, 0.456, 0.406]} + std: &NORMALIZE_STD ${oc.env:NORMALIZE_STD, [0.229, 0.224, 0.225]} + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + filename: best + save_last: true + save_top_k: 1 + monitor: &MONITOR_METRIC ${oc.env:MONITOR_METRIC, 'val/MonaiDiceScore'} + mode: &MONITOR_METRIC_MODE ${oc.env:MONITOR_METRIC_MODE, max} + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + min_delta: 0 + patience: ${oc.env:PATIENCE, 5} + monitor: *MONITOR_METRIC + mode: *MONITOR_METRIC_MODE + logger: + - class_path: lightning.pytorch.loggers.TensorBoardLogger + init_args: + save_dir: *OUTPUT_ROOT + name: "" +model: + class_path: eva.vision.models.modules.SemanticSegmentationModule + init_args: + encoder: + class_path: eva.vision.models.ModelFromRegistry + init_args: + model_name: ${oc.env:MODEL_NAME, universal/vit_small_patch16_224_dino} + model_kwargs: + out_indices: ${oc.env:OUT_INDICES, 1} + model_extra_kwargs: ${oc.env:MODEL_EXTRA_KWARGS, null} + decoder: + class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoderWithImage + init_args: + in_features: ${oc.env:IN_FEATURES, 384} + num_classes: &NUM_CLASSES 4 + criterion: + class_path: eva.vision.losses.DiceLoss + init_args: + softmax: true + batch: true + lr_multiplier_encoder: 0.0 + optimizer: + class_path: torch.optim.AdamW + init_args: + lr: ${oc.env:LR_VALUE, 0.002} + lr_scheduler: + class_path: 
torch.optim.lr_scheduler.PolynomialLR + init_args: + total_iters: *MAX_STEPS + power: 0.9 + postprocess: + predictions_transforms: + - class_path: torch.argmax + init_args: + dim: 1 + metrics: + common: + - class_path: eva.metrics.AverageLoss + evaluation: + - class_path: eva.vision.metrics.defaults.MulticlassSegmentationMetrics + init_args: + num_classes: *NUM_CLASSES + - class_path: torchmetrics.ClasswiseWrapper + init_args: + metric: + class_path: eva.vision.metrics.MonaiDiceScore + init_args: + include_background: true + num_classes: *NUM_CLASSES + reduction: none + labels: + - background + - kidney + - tumor + - cyst +data: + class_path: eva.DataModule + init_args: + datasets: + train: + class_path: eva.vision.datasets.KiTS23 + init_args: &DATASET_ARGS + root: ${oc.env:DATA_ROOT, ./data/kits23} + split: train + download: ${oc.env:DOWNLOAD_DATA, false} + # Set `download: true` to download the dataset automatically from the official source. + # The KiTS23 dataset is distributed under the following license: + # "Attribution-NonCommercial-ShareAlike 4.0 International" + # (see: https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en) + transforms: + class_path: eva.vision.data.transforms.common.ResizeAndClamp + init_args: + mean: *NORMALIZE_MEAN + std: *NORMALIZE_STD + val: + class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *DATASET_ARGS + split: val + test: + class_path: eva.vision.datasets.KiTS23 + init_args: + <<: *DATASET_ARGS + split: test + dataloaders: + train: + batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 64} + num_workers: &N_DATA_WORKERS ${oc.env:N_DATA_WORKERS, 4} + shuffle: true + val: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS + shuffle: true + test: + batch_size: *BATCH_SIZE + num_workers: *N_DATA_WORKERS diff --git a/src/eva/vision/data/datasets/__init__.py b/src/eva/vision/data/datasets/__init__.py index 5c31edc8d..e98eb838e 100644 --- a/src/eva/vision/data/datasets/__init__.py +++ 
b/src/eva/vision/data/datasets/__init__.py @@ -15,6 +15,7 @@ CoNSeP, EmbeddingsSegmentationDataset, ImageSegmentation, + KiTS23, LiTS, LiTSBalanced, MoNuSAC, @@ -36,6 +37,7 @@ "CoNSeP", "EmbeddingsSegmentationDataset", "ImageSegmentation", + "KiTS23", "LiTS", "LiTSBalanced", "MoNuSAC", diff --git a/src/eva/vision/data/datasets/segmentation/__init__.py b/src/eva/vision/data/datasets/segmentation/__init__.py index b954fa395..edabcf8a0 100644 --- a/src/eva/vision/data/datasets/segmentation/__init__.py +++ b/src/eva/vision/data/datasets/segmentation/__init__.py @@ -4,6 +4,7 @@ from eva.vision.data.datasets.segmentation.bcss import BCSS from eva.vision.data.datasets.segmentation.consep import CoNSeP from eva.vision.data.datasets.segmentation.embeddings import EmbeddingsSegmentationDataset +from eva.vision.data.datasets.segmentation.kits23 import KiTS23 from eva.vision.data.datasets.segmentation.lits import LiTS from eva.vision.data.datasets.segmentation.lits_balanced import LiTSBalanced from eva.vision.data.datasets.segmentation.monusac import MoNuSAC @@ -14,6 +15,7 @@ "BCSS", "CoNSeP", "EmbeddingsSegmentationDataset", + "KiTS23", "LiTS", "LiTSBalanced", "MoNuSAC", diff --git a/src/eva/vision/data/datasets/segmentation/kits23.py b/src/eva/vision/data/datasets/segmentation/kits23.py new file mode 100644 index 000000000..0ed351bd4 --- /dev/null +++ b/src/eva/vision/data/datasets/segmentation/kits23.py @@ -0,0 +1,276 @@ +"""KiTS23 dataset.""" + +import functools +import os +import time +from pathlib import Path +from typing import Any, Callable, Dict, List, Literal, Tuple +from urllib import request + +import nibabel as nib +import torch +from torchvision import tv_tensors +from typing_extensions import override + +from eva.core.data import splitting +from eva.core.utils import multiprocessing +from eva.core.utils.progress_bar import tqdm +from eva.vision.data.datasets import _utils, _validators +from eva.vision.data.datasets.segmentation import base +from eva.vision.utils 
import io +from eva.vision.utils.io import nifti + + +class KiTS23(base.ImageSegmentation): + """KiTS23 - The 2023 Kidney and Kidney Tumor Segmentation challenge. + + To optimize data loading, the dataset is preprocessed by reorienting the images + from IPL to LPS and uncompressing them. The reorientation is necessary, because + loading slices from the first dimension is significantly slower than from the last, + due to data not being stored in a contiguous manner on disk across all dimensions. + + Webpage: https://kits-challenge.org/kits23/ + """ + + _index_ranges: List[Tuple[int, int]] = [(0, 300), (400, 589)] + """Dataset index ranges.""" + + _train_ratio: float = 0.7 + _val_ratio: float = 0.15 + _test_ratio: float = 0.15 + """Ratios for dataset splits.""" + + _expected_dataset_lengths: Dict[str | None, int] = { + "train": 67582, + "val": 13751, + "test": 13888, + None: 95221, + } + """Dataset split to the expected size.""" + + _sample_every_n_slices: int | None = None + """The amount of slices to sub-sample per 3D CT scan image.""" + + _processed_dir: str = "processed" + """Directory where the processed data (reoriented to LPS & uncompressed) is stored.""" + + _license: str = "CC BY-NC-SA 4.0" + """Dataset license.""" + + def __init__( + self, + root: str, + split: Literal["train", "val", "test"] | None = None, + download: bool = False, + num_workers: int = 10, + transforms: Callable | None = None, + seed: int = 8, + ) -> None: + """Initialize dataset. + + Args: + root: Path to the root directory of the dataset. The dataset will + be downloaded and extracted here, if it does not already exist. + split: Dataset split to use. If `None`, the entire dataset will be used. + download: Whether to download the data for the specified split. + Note that the download will be executed only by additionally + calling the :meth:`prepare_data` method and if the data does + not yet exist on disk. 
+ num_workers: The number of workers to use for preprocessing the dataset. + transforms: A function/transforms that takes in an image and a target + mask and returns the transformed versions of both. + seed: Seed used for generating the dataset splits. + """ + super().__init__(transforms=transforms) + + self._root = root + self._split = split + self._download = download + self._num_workers = num_workers + self._seed = seed + + self._indices: List[Tuple[int, int]] = [] + + @property + @override + def classes(self) -> List[str]: + return ["background", "kidney", "tumor", "cyst"] + + @functools.cached_property + @override + def class_to_idx(self) -> Dict[str, int]: + return {label: index for index, label in enumerate(self.classes)} + + @property + def _processed_root(self) -> str: + return os.path.join(self._root, self._processed_dir) + + @override + def filename(self, index: int) -> str: + sample_index, _ = self._indices[index] + return self._volume_filename(sample_index) + + @override + def prepare_data(self) -> None: + if self._download: + self._download_dataset() + self._preprocess() + + @override + def configure(self) -> None: + self._indices = self._create_indices() + + @override + def validate(self) -> None: + _validators.check_dataset_integrity( + self, + length=self._expected_dataset_lengths.get(self._split, 0), + n_classes=4, + first_and_last_labels=("background", "cyst"), + ) + + @override + def load_image(self, index: int) -> tv_tensors.Image: + sample_index, slice_index = self._indices[index] + volume_path = self._volume_path(sample_index) + image_array = io.read_nifti(volume_path, slice_index) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] + + @override + def load_mask(self, index: int) -> tv_tensors.Mask: + sample_index, slice_index = self._indices[index] + segmentation_path = self._segmentation_path(sample_index) + semantic_labels = io.read_nifti(segmentation_path, slice_index) + return 
tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] + + @override + def load_metadata(self, index: int) -> Dict[str, Any]: + sample_index, slice_index = self._indices[index] + return {"case_id": f"{sample_index:05d}", "slice_index": slice_index} + + @override + def __len__(self) -> int: + return len(self._indices) + + def _create_indices(self) -> List[Tuple[int, int]]: + """Builds the dataset indices for the specified split. + + Returns: + A list of tuples, where the first value indicates the + sample index which the second its corresponding slice + index. + """ + indices = [ + (sample_idx, slide_idx) + for sample_idx in self._get_split_indices() + for slide_idx in range(self._get_number_of_slices_per_volume(sample_idx)) + if slide_idx % (self._sample_every_n_slices or 1) == 0 + ] + return indices + + def _get_split_indices(self) -> List[int]: + """Builds the dataset indices for the specified split.""" + indices = _utils.ranges_to_indices(self._index_ranges) + + train_indices, val_indices, test_indices = splitting.random_split( + indices, self._train_ratio, self._val_ratio, self._test_ratio, seed=self._seed + ) + split_indices_dict = { + "train": [indices[i] for i in train_indices], + "val": [indices[i] for i in val_indices], + "test": [indices[i] for i in test_indices], # type: ignore + None: indices, + } + if self._split not in split_indices_dict: + raise ValueError("Invalid data split. 
Use 'train', 'val', 'test' or `None`.") + + return list(split_indices_dict[self._split]) + + def _get_number_of_slices_per_volume(self, sample_index: int) -> int: + """Returns the total amount of slices of a volume.""" + volume_shape = io.fetch_nifti_shape(self._volume_path(sample_index)) + return volume_shape[-1] + + def _volume_filename(self, sample_index: int) -> str: + return f"case_{sample_index:05d}/master_{sample_index:05d}.nii" + + def _segmentation_filename(self, sample_index: int) -> str: + return f"case_{sample_index:05d}/segmentation.nii" + + def _volume_path(self, sample_index: int, processed: bool = True) -> str: + root = self._processed_root if processed else self._root + return os.path.join(root, self._volume_filename(sample_index)) + + def _segmentation_path(self, sample_index: int, processed: bool = True) -> str: + root = self._processed_root if processed else self._root + return os.path.join(root, self._segmentation_filename(sample_index)) + + def _download_dataset(self) -> None: + """Downloads the dataset.""" + self._print_license() + for case_id in tqdm( + self._get_split_indices(), + desc=">> Downloading dataset", + leave=False, + ): + image_path = self._volume_path(case_id, processed=False) + segmentation_path = self._segmentation_path(case_id, processed=False) + if os.path.isfile(image_path) and os.path.isfile(segmentation_path): + continue + + _download_case_with_retry(case_id, image_path, segmentation_path) + + def _preprocess(self): + """Reorienting the images to LPS and uncompressing them.""" + + def _reorient_and_save(path: Path) -> None: + relative_path = str(path.relative_to(self._root)) + save_path = os.path.join(self._processed_root, relative_path.rstrip(".gz")) + if os.path.isfile(save_path): + return + os.makedirs(os.path.dirname(save_path), exist_ok=True) + nifti.reorient(nib.load(path), "LPS").to_filename(str(save_path)) + + compressed_paths = list(Path(self._root).rglob("*.nii.gz")) + multiprocessing.run_with_threads( + 
_reorient_and_save, + [(path,) for path in compressed_paths], + num_workers=1, + progress_desc=">> Preprocessing dataset", + return_results=False, + ) + + processed_paths = list(Path(self._processed_root).rglob("*.nii")) + if len(compressed_paths) != len(processed_paths): + raise RuntimeError(f"Preprocessing failed, missing files in {self._processed_root}.") + + def _print_license(self) -> None: + """Prints the dataset license.""" + print(f"Dataset license: {self._license}") + + +def _download_case_with_retry( + case_id: int, + image_path: str, + segmentation_path: str, + *, + retries: int = 2, +) -> None: + for attempt in range(retries): + try: + os.makedirs(os.path.dirname(image_path), exist_ok=True) + request.urlretrieve( + url=f"https://kits19.sfo2.digitaloceanspaces.com/master_{case_id:05d}.nii.gz", # nosec + filename=image_path, + ) + request.urlretrieve( + url=f"https://raw.githubusercontent.com/neheller/kits23/e282208/dataset/case_{case_id:05d}/segmentation.nii.gz", # nosec + filename=segmentation_path, + ) + return + + except Exception as e: + if attempt < retries - 1: + time.sleep(5) + else: + raise e diff --git a/src/eva/vision/data/datasets/segmentation/lits.py b/src/eva/vision/data/datasets/segmentation/lits.py index 6add83b0a..0b88bf4af 100644 --- a/src/eva/vision/data/datasets/segmentation/lits.py +++ b/src/eva/vision/data/datasets/segmentation/lits.py @@ -27,7 +27,7 @@ class LiTS(base.ImageSegmentation): _train_ratio: float = 0.7 _val_ratio: float = 0.15 _test_ratio: float = 0.15 - """Index ranges per split.""" + """Ratios for dataset splits.""" _fix_orientation: bool = True """Whether to fix the orientation of the images to match the default for radiologists.""" @@ -116,7 +116,7 @@ def load_image(self, index: int) -> tv_tensors.Image: image_array = io.read_nifti(volume_path, slice_index) if self._fix_orientation: image_array = self._orientation(image_array, sample_index) - return tv_tensors.Image(image_array.transpose(2, 0, 1)) + return 
tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 53cc0c5fd..ff8d48912 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -211,7 +211,7 @@ def load_image(self, index: int) -> tv_tensors.Image: image_path = self._get_image_path(sample_index) image_array = io.read_nifti(image_path, slice_index) image_array = self._fix_orientation(image_array) - return tv_tensors.Image(image_array.copy().transpose(2, 0, 1)) + return tv_tensors.Image(image_array.transpose(2, 0, 1), dtype=torch.float32) # type: ignore[reportCallIssue] @override def load_mask(self, index: int) -> tv_tensors.Mask: diff --git a/src/eva/vision/utils/io/nifti.py b/src/eva/vision/utils/io/nifti.py index 49ca8fdaa..47d9e2362 100644 --- a/src/eva/vision/utils/io/nifti.py +++ b/src/eva/vision/utils/io/nifti.py @@ -11,7 +11,11 @@ def read_nifti( - path: str, slice_index: int | None = None, *, use_storage_dtype: bool = True + path: str, + slice_index: int | None = None, + *, + use_storage_dtype: bool = True, + target_orientation: str | None = None, ) -> npt.NDArray[Any]: """Reads and loads a NIfTI image from a file path. @@ -20,6 +24,7 @@ def read_nifti( slice_index: Whether to read only a slice from the file. use_storage_dtype: Whether to cast the raw image array to the inferred type. + target_orientation: The target orientation to reorient the image. E.g. "LPS". Returns: The image as a numpy array (height, width, channels). 
@@ -30,16 +35,39 @@ def read_nifti( """ _utils.check_file(path) image_data: nib.Nifti1Image = nib.load(path) # type: ignore + if target_orientation is not None: + image_data = reorient(image_data, target_orientation) + if slice_index is not None: - image_data = image_data.slicer[:, :, slice_index : slice_index + 1] + image_array = np.expand_dims(image_data.dataobj[:, :, slice_index], -1) + else: + image_array = image_data.get_fdata() - image_array = image_data.get_fdata() if use_storage_dtype: image_array = image_array.astype(image_data.get_data_dtype()) return image_array +def reorient( + nii: nib.Nifti1Image, + orientation: str | Tuple[str, str, str] = "LPS", +) -> nib.Nifti1Image: + """Reorients a nifti image to specified orientation. + + Args: + nii: The input nifti image. + orientation: The target orientation to reorient the image. E.g. "LPS" or ("L", "P", "S"). + """ + orig_ornt = nib.io_orientation(nii.affine) + targ_ornt = orientations.axcodes2ornt(orientation) + if np.all(orig_ornt == targ_ornt): + return nii + transform = orientations.ornt_transform(orig_ornt, targ_ornt) + reoriented_nii = nii.as_reoriented(transform) + return reoriented_nii + + def save_array_as_nifti( array: npt.ArrayLike, filename: str, diff --git a/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz new file mode 100644 index 000000000..1cceb422a --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00036/master_00036.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c7218db62c033db92b9d6ed29e0925a2c8793d5206fa65fb7b47f92a5699e085 +size 1094157 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz new file mode 100644 index 000000000..90ca9d761 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00036/segmentation.nii.gz @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:ece3061b8c97179b5ddf91a26fdbfc124276f789fe6fca942d5678396810fd2c +size 8328 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz new file mode 100644 index 000000000..f8b162d85 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00240/master_00240.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:092410c27a5c04537b6c359e336cfe97693029a2dbe3c6e5ea0892d52ba3a5d6 +size 1004338 diff --git a/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz b/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz new file mode 100644 index 000000000..90ca9d761 --- /dev/null +++ b/tests/eva/assets/vision/datasets/kits23/case_00240/segmentation.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ece3061b8c97179b5ddf91a26fdbfc124276f789fe6fca942d5678396810fd2c +size 8328 diff --git a/tests/eva/vision/data/datasets/segmentation/test_kits23.py b/tests/eva/vision/data/datasets/segmentation/test_kits23.py new file mode 100644 index 000000000..466f7a415 --- /dev/null +++ b/tests/eva/vision/data/datasets/segmentation/test_kits23.py @@ -0,0 +1,83 @@ +"""KiTS23 dataset tests.""" + +import os +import shutil +from typing import Literal +from unittest.mock import patch + +import pytest +from torchvision import tv_tensors + +from eva.vision.data import datasets + + +@pytest.mark.parametrize( + "split, expected_length", + [(None, 8)], +) +def test_length(kits23_dataset: datasets.KiTS23, expected_length: int) -> None: + """Tests the length of the dataset.""" + assert len(kits23_dataset) == expected_length + + +@pytest.mark.parametrize( + "split, index", + [ + (None, 0), + ], +) +def test_sample(kits23_dataset: datasets.KiTS23, index: int) -> None: + """Tests the format of a dataset sample.""" + # assert data sample is a tuple + sample = 
kits23_dataset[index] + assert isinstance(sample, tuple) + assert len(sample) == 3 + # assert the format of the `image` and `mask` + image, mask, metadata = sample + assert isinstance(image, tv_tensors.Image) + assert image.shape == (1, 512, 512) + assert isinstance(mask, tv_tensors.Mask) + assert mask.shape == (512, 512) + assert isinstance(metadata, dict) + assert "slice_index" in metadata + + +@pytest.mark.parametrize("split", [None]) +def test_processed_dir_exists(kits23_dataset: datasets.KiTS23) -> None: + """Tests the existence of the processed directory.""" + assert os.path.isdir(kits23_dataset._processed_root) + + for index in ["00036", "00240"]: + assert os.path.isfile( + os.path.join(kits23_dataset._processed_root, f"case_{index}/master_{index}.nii") + ) + assert os.path.isfile( + os.path.join(kits23_dataset._processed_root, f"case_{index}/segmentation.nii") + ) + + +@pytest.fixture(scope="function") +def kits23_dataset(split: Literal["train", "val", "test"] | None, assets_path: str): + """KiTS23 dataset fixture.""" + dataset = datasets.KiTS23( + root=os.path.join( + assets_path, + "vision", + "datasets", + "kits23", + ), + split=split, + ) + dataset.prepare_data() + dataset.configure() + yield dataset + + if os.path.isdir(dataset._processed_root): + shutil.rmtree(dataset._processed_root) + + +@pytest.fixture(autouse=True) +def mock_indices(): + """Mocks the split indices so that only the test asset cases (36 and 240) are used.""" + with patch.object(datasets.KiTS23, "_get_split_indices", return_value=[36, 240]): + yield diff --git a/tests/eva/vision/utils/io/test_nifti.py b/tests/eva/vision/utils/io/test_nifti.py new file mode 100644 index 000000000..10c3d2268 --- /dev/null +++ b/tests/eva/vision/utils/io/test_nifti.py @@ -0,0 +1,53 @@ +"""Tests for the nifti IO functions.""" + +import os + +import nibabel as nib +import numpy as np +import pytest +from nibabel import orientations + +from eva.vision.utils.io import nifti + + +@pytest.mark.parametrize( + 
"use_storage_dtype, target_orientation", + [ + [False, None], + [False, "LPS"], + [True, "RAS"], + ], +) +def test_read_nifti(nifti_path: str, use_storage_dtype: bool, target_orientation: str): + """Tests the function to read a nifti file as array (full & slice).""" + image = nifti.read_nifti( + nifti_path, use_storage_dtype=use_storage_dtype, target_orientation=target_orientation + ) + assert image.shape == (512, 512, 4) + + slice_image = nifti.read_nifti( + nifti_path, + slice_index=0, + use_storage_dtype=use_storage_dtype, + target_orientation=target_orientation, + ) + assert slice_image.shape == (512, 512, 1) + + expected_dtype = np.dtype("