From 1f530bf77a7c016615e837bf0fa3ed44ed31fc65 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 13:04:07 +0200 Subject: [PATCH 1/5] Update dataset to fetch all the slices --- src/eva/vision/data/datasets/_utils.py | 6 +- .../segmentation/total_segmentator.py | 107 ++++++++++-------- .../segmentation/test_total_segmentator.py | 6 +- 3 files changed, 66 insertions(+), 53 deletions(-) diff --git a/src/eva/vision/data/datasets/_utils.py b/src/eva/vision/data/datasets/_utils.py index 2d2fe30b..1a17d7e9 100644 --- a/src/eva/vision/data/datasets/_utils.py +++ b/src/eva/vision/data/datasets/_utils.py @@ -1,6 +1,6 @@ """Dataset related function and helper functions.""" -from typing import List, Tuple +from typing import List, Sequence, Tuple def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]: @@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]: return ranges -def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]: +def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]: """Unpacks a list of ranges to individual indices. Args: - ranges: The list of ranges to produce the indices from. + ranges: A sequence of ranges to produce the indices from. Return: A list of the indices. diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 4892e6b6..64b1beba 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -18,14 +18,11 @@ class TotalSegmentator2D(base.ImageSegmentation): """TotalSegmentator 2D segmentation dataset.""" - _train_index_ranges: List[Tuple[int, int]] = [(0, 83)] - """Train range indices.""" - - _val_index_ranges: List[Tuple[int, int]] = [(83, 103)] - """Validation range indices.""" - - _n_slices_per_image: int = 20 - """The amount of slices to sample per 3D CT scan image.""" + _expected_dataset_lengths: Dict[str, int] = { + "train_small": 29892, + "val_small": 6480, + } + """Dataset version and split to the expected size.""" _resources_full: List[structs.DownloadResource] = [ structs.DownloadResource( @@ -49,7 +46,7 @@ def __init__( self, root: str, split: Literal["train", "val"] | None, - version: Literal["small", "full"] = "small", + version: Literal["small", "full"] | None = "small", download: bool = False, as_uint8: bool = True, transforms: Callable | None = None, @@ -60,7 +57,8 @@ def __init__( root: Path to the root directory of the dataset. The dataset will be downloaded and extracted here, if it does not already exist. split: Dataset split to use. If `None`, the entire dataset is used. - version: The version of the dataset to initialize. + version: The version of the dataset to initialize. If `None`, it will + use the files located at root as is and wont perform any checks. download: Whether to download the data for the specified split. 
Note that the download will be executed only by additionally calling the :meth:`prepare_data` method and if the data does not @@ -78,7 +76,7 @@ def __init__( self._as_uint8 = as_uint8 self._samples_dirs: List[str] = [] - self._indices: List[int] = [] + self._indices: List[Tuple[int, int]] = [] @functools.cached_property @override @@ -99,7 +97,8 @@ def class_to_idx(self) -> Dict[str, int]: @override def filename(self, index: int) -> str: - sample_dir = self._samples_dirs[self._indices[index]] + sample_idx, _ = self._indices[index] + sample_dir = self._samples_dirs[sample_idx] return os.path.join(sample_dir, "ct.nii.gz") @override @@ -114,21 +113,24 @@ def configure(self) -> None: @override def validate(self) -> None: + if self._version is None: + return + _validators.check_dataset_integrity( self, - length=1660 if self._split == "train" else 400, + length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0), n_classes=117, first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"), ) @override def __len__(self) -> int: - return len(self._indices) * self._n_slices_per_image + return len(self._indices) @override def load_image(self, index: int) -> tv_tensors.Image: - image_path = self._get_image_path(index) - slice_index = self._get_sample_slice_index(index) + sample_index, slice_index = self._indices[index] + image_path = self._get_image_path(sample_index) image_array = io.read_nifti_slice(image_path, slice_index) if self._as_uint8: image_array = convert.to_8bit(image_array) @@ -137,8 +139,8 @@ def load_image(self, index: int) -> tv_tensors.Image: @override def load_mask(self, index: int) -> tv_tensors.Mask: - masks_dir = self._get_masks_dir(index) - slice_index = self._get_sample_slice_index(index) + sample_index, slice_index = self._indices[index] + masks_dir = self._get_masks_dir(sample_index) mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes) one_hot_encoded = np.concatenate( [io.read_nifti_slice(path, slice_index) for path in mask_paths], @@ -149,27 +151,20 @@ def load_mask(self, index: int) -> tv_tensors.Mask: segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2) return tv_tensors.Mask(segmentation_label) - def _get_masks_dir(self, index: int) -> str: - """Returns the directory of the corresponding masks.""" - sample_dir = self._get_sample_dir(index) - return os.path.join(self._root, sample_dir, "segmentations") - - def _get_image_path(self, index: int) -> str: + def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" - sample_dir = self._get_sample_dir(index) + sample_dir = self._samples_dirs[sample_index] return os.path.join(self._root, sample_dir, "ct.nii.gz") - def _get_sample_dir(self, index: int) -> str: - """Returns the corresponding sample directory.""" - sample_index = self._indices[index // self._n_slices_per_image] - return self._samples_dirs[sample_index] + def _get_masks_dir(self, sample_index: int) -> str: + """Returns the directory of the corresponding masks.""" + sample_dir = self._samples_dirs[sample_index] + return os.path.join(self._root, sample_dir, "segmentations") - def _get_sample_slice_index(self, index: int) -> int: - """Returns the corresponding slice index.""" - image_path = self._get_image_path(index) - total_slices = io.fetch_total_nifti_slices(image_path) - slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int) - return slice_indices[index % self._n_slices_per_image] + def _get_sample_total_slices(self, sample_index: 
int) -> int: + """Returns the total amount of slices of a sample.""" + image_path = self._get_image_path(sample_index) + return io.fetch_total_nifti_slices(image_path) def _fetch_samples_dirs(self) -> List[str]: """Returns the name of all the samples of all the splits of the dataset.""" @@ -180,29 +175,45 @@ def _fetch_samples_dirs(self) -> List[str]: ] return sorted(sample_filenames) - def _create_indices(self) -> List[int]: - """Builds the dataset indices for the specified split.""" - split_index_ranges = { - "train": self._train_index_ranges, - "val": self._val_index_ranges, - None: [(0, 103)], - } - index_ranges = split_index_ranges.get(self._split) - if index_ranges is None: - raise ValueError("Invalid data split. Use 'train', 'val' or `None`.") + def _get_split_indices(self) -> List[int]: + """Returns the samples indices that corresponding the dataset split and version.""" + key = f"{self._split}_{self._version}" + match key: + case "train_small": + index_ranges = [(0, 83)] + case "val_small": + index_ranges = [(83, 102)] + case _: + index_ranges = [(0, len(self._samples_dirs))] return _utils.ranges_to_indices(index_ranges) + def _create_indices(self) -> List[Tuple[int, int]]: + """Builds the dataset indices for the specified split. + + Returns: + A list of tuples, where the first value indicates the + sample index which the second its corresponding slice + index. + """ + indices = [ + (sample_idx, slide_idx) + for sample_idx in self._get_split_indices() + for slide_idx in range(self._get_sample_total_slices(sample_idx)) + ] + return indices + def _download_dataset(self) -> None: """Downloads the dataset.""" dataset_resources = { "small": self._resources_small, "full": self._resources_full, - None: (0, 103), } - resources = dataset_resources.get(self._version) + resources = dataset_resources.get(self._version or "") if resources is None: - raise ValueError("Invalid data version. Use 'small' or 'full'.") + raise ValueError( + f"Can't download data version '{self._version}'. Use 'small' or 'full'." 
+ ) for resource in resources: if os.path.isdir(self._root): diff --git a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py index 3e7f09e6..9607a2a8 100644 --- a/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py +++ b/tests/eva/vision/data/datasets/segmentation/test_total_segmentator.py @@ -11,7 +11,7 @@ @pytest.mark.parametrize( "split, expected_length", - [("train", 1660), ("val", 400), (None, 2060)], + [("train", 9), ("val", 9), (None, 9)], ) def test_length( total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int @@ -25,6 +25,7 @@ def test_length( [ (None, 0), ("train", 0), + ("val", 0), ], ) def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None: @@ -43,7 +44,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: i @pytest.fixture(scope="function") def total_segmentator_dataset( - split: Literal["train", "val"], assets_path: str + split: Literal["train", "val"] | None, assets_path: str ) -> datasets.TotalSegmentator2D: """TotalSegmentator2D dataset fixture.""" dataset = datasets.TotalSegmentator2D( @@ -55,6 +56,7 @@ def total_segmentator_dataset( "Totalsegmentator_dataset_v201", ), split=split, + version=None, ) dataset.prepare_data() dataset.configure() From bab039be14f63669c0a7ec1dc138464011ae337d Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 17:52:05 +0200 Subject: [PATCH 2/5] rename method --- .../vision/data/datasets/segmentation/total_segmentator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 64b1beba..661aa33a 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -161,7 +161,7 @@ def _get_masks_dir(self, sample_index: int) -> str: sample_dir = self._samples_dirs[sample_index] return os.path.join(self._root, sample_dir, "segmentations") - def _get_sample_total_slices(self, sample_index: int) -> int: + def _get_number_of_slices_per_sample(self, sample_index: int) -> int: """Returns the total amount of slices of a sample.""" image_path = self._get_image_path(sample_index) return io.fetch_total_nifti_slices(image_path) @@ -199,7 +199,7 @@ def _create_indices(self) -> List[Tuple[int, int]]: indices = [ (sample_idx, slide_idx) for sample_idx in self._get_split_indices() - for slide_idx in range(self._get_sample_total_slices(sample_idx)) + for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx)) ] return indices From a6683394bd48651af20fa6131cdea33e6e157b20 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 18:04:54 +0200 Subject: [PATCH 3/5] add `_n_slices_per_image` option --- .../vision/data/datasets/segmentation/total_segmentator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index 661aa33a..ae4e1f4e 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -24,6 +24,9 @@ class TotalSegmentator2D(base.ImageSegmentation): } """Dataset version and split to the expected size.""" + _n_slices_per_image: int | None = None + """The amount of slices to sample per 3D CT scan 
image.""" + _resources_full: List[structs.DownloadResource] = [ structs.DownloadResource( filename="Totalsegmentator_dataset_v201.zip", @@ -200,6 +203,7 @@ def _create_indices(self) -> List[Tuple[int, int]]: (sample_idx, slide_idx) for sample_idx in self._get_split_indices() for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx)) + if slide_idx % self._n_slices_per_image or 1 == 0 ] return indices From 5725e8f912965365609419ed51fb2ea1e7e0167f Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 18:07:00 +0200 Subject: [PATCH 4/5] fix typing --- src/eva/vision/data/datasets/segmentation/total_segmentator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index ae4e1f4e..dc083f48 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -203,7 +203,7 @@ def _create_indices(self) -> List[Tuple[int, int]]: (sample_idx, slide_idx) for sample_idx in self._get_split_indices() for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx)) - if slide_idx % self._n_slices_per_image or 1 == 0 + if slide_idx % (self._n_slices_per_image or 1) == 0 ] return indices From 983a361b0ed549af38574a54d84ec7e4dc89b535 Mon Sep 17 00:00:00 2001 From: ioangatop Date: Mon, 6 May 2024 18:08:57 +0200 Subject: [PATCH 5/5] rename to `_sample_every_n_slices` --- .../vision/data/datasets/segmentation/total_segmentator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator.py b/src/eva/vision/data/datasets/segmentation/total_segmentator.py index dc083f48..92bb8992 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator.py @@ -24,8 +24,8 @@ class TotalSegmentator2D(base.ImageSegmentation): } """Dataset version and split to the expected size.""" - _n_slices_per_image: int | None = None - """The amount of slices to sample per 3D CT scan image.""" + _sample_every_n_slices: int | None = None + """The amount of slices to sub-sample per 3D CT scan image.""" _resources_full: List[structs.DownloadResource] = [ structs.DownloadResource( @@ -203,7 +203,7 @@ def _create_indices(self) -> List[Tuple[int, int]]: (sample_idx, slide_idx) for sample_idx in self._get_split_indices() for slide_idx in range(self._get_number_of_slices_per_sample(sample_idx)) - if slide_idx % (self._n_slices_per_image or 1) == 0 + if slide_idx % (self._sample_every_n_slices or 1) == 0 ] return indices