Update TotalSegmentator2D dataset to fetch all the slices #416

Merged
6 changes: 3 additions & 3 deletions src/eva/vision/data/datasets/_utils.py
@@ -1,6 +1,6 @@
"""Dataset related function and helper functions."""

from typing import List, Tuple
from typing import List, Sequence, Tuple


def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
return ranges


def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
"""Unpacks a list of ranges to individual indices.

Args:
ranges: The list of ranges to produce the indices from.
ranges: A sequence of ranges to produce the indices from.

Return:
A list of the indices.
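For context, widening the annotation from `List` to `Sequence` lets callers pass tuples as well as lists. A minimal sketch of the helper's behavior, assuming half-open `(start, end)` ranges, which matches how the split ranges are used elsewhere in this PR (the body below is an illustrative reimplementation, not the module's actual code):

    from typing import List, Sequence, Tuple

    def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
        """Unpacks a sequence of half-open (start, end) ranges to indices."""
        return [index for start, end in ranges for index in range(start, end)]

    # Any sequence type is accepted now, not just lists:
    assert ranges_to_indices([(0, 3), (5, 7)]) == [0, 1, 2, 5, 6]
    assert ranges_to_indices(((0, 83),)) == list(range(83))  # e.g. 83 train samples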
109 changes: 62 additions & 47 deletions src/eva/vision/data/datasets/segmentation/total_segmentator.py
@@ -18,14 +18,14 @@
class TotalSegmentator2D(base.ImageSegmentation):
"""TotalSegmentator 2D segmentation dataset."""

_train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
"""Train range indices."""
_expected_dataset_lengths: Dict[str, int] = {
"train_small": 29892,
"val_small": 6480,
}
"""Dataset version and split to the expected size."""

_val_index_ranges: List[Tuple[int, int]] = [(83, 103)]
"""Validation range indices."""

_n_slices_per_image: int = 20
"""The amount of slices to sample per 3D CT scan image."""
_sample_every_n_slices: int | None = None
"""The amount of slices to sub-sample per 3D CT scan image."""

_resources_full: List[structs.DownloadResource] = [
structs.DownloadResource(
@@ -49,7 +49,7 @@ def __init__(
self,
root: str,
split: Literal["train", "val"] | None,
version: Literal["small", "full"] = "small",
version: Literal["small", "full"] | None = "small",
download: bool = False,
as_uint8: bool = True,
transforms: Callable | None = None,
@@ -60,7 +60,8 @@
root: Path to the root directory of the dataset. The dataset will
be downloaded and extracted here, if it does not already exist.
split: Dataset split to use. If `None`, the entire dataset is used.
version: The version of the dataset to initialize.
version: The version of the dataset to initialize. If `None`, it will
use the files located at root as-is and won't perform any checks.
download: Whether to download the data for the specified split.
Note that the download will be executed only by additionally
calling the :meth:`prepare_data` method and if the data does not
@@ -78,7 +79,7 @@
self._as_uint8 = as_uint8

self._samples_dirs: List[str] = []
self._indices: List[int] = []
self._indices: List[Tuple[int, int]] = []

@functools.cached_property
@override
@@ -99,7 +100,8 @@ def class_to_idx(self) -> Dict[str, int]:

@override
def filename(self, index: int) -> str:
sample_dir = self._samples_dirs[self._indices[index]]
sample_idx, _ = self._indices[index]
sample_dir = self._samples_dirs[sample_idx]
return os.path.join(sample_dir, "ct.nii.gz")

@override
Expand All @@ -114,21 +116,24 @@ def configure(self) -> None:

@override
def validate(self) -> None:
if self._version is None:
return

_validators.check_dataset_integrity(
self,
length=1660 if self._split == "train" else 400,
length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
n_classes=117,
first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
)

@override
def __len__(self) -> int:
return len(self._indices) * self._n_slices_per_image
return len(self._indices)

@override
def load_image(self, index: int) -> tv_tensors.Image:
image_path = self._get_image_path(index)
slice_index = self._get_sample_slice_index(index)
sample_index, slice_index = self._indices[index]
image_path = self._get_image_path(sample_index)
image_array = io.read_nifti_slice(image_path, slice_index)
if self._as_uint8:
image_array = convert.to_8bit(image_array)
@@ -137,8 +142,8 @@ def load_image(self, index: int) -> tv_tensors.Image:

@override
def load_mask(self, index: int) -> tv_tensors.Mask:
masks_dir = self._get_masks_dir(index)
slice_index = self._get_sample_slice_index(index)
sample_index, slice_index = self._indices[index]
masks_dir = self._get_masks_dir(sample_index)
mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
one_hot_encoded = np.concatenate(
[io.read_nifti_slice(path, slice_index) for path in mask_paths],
@@ -149,27 +154,20 @@
segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2)
return tv_tensors.Mask(segmentation_label)

def _get_masks_dir(self, index: int) -> str:
"""Returns the directory of the corresponding masks."""
sample_dir = self._get_sample_dir(index)
return os.path.join(self._root, sample_dir, "segmentations")

def _get_image_path(self, index: int) -> str:
def _get_image_path(self, sample_index: int) -> str:
"""Returns the corresponding image path."""
sample_dir = self._get_sample_dir(index)
sample_dir = self._samples_dirs[sample_index]
return os.path.join(self._root, sample_dir, "ct.nii.gz")

def _get_sample_dir(self, index: int) -> str:
"""Returns the corresponding sample directory."""
sample_index = self._indices[index // self._n_slices_per_image]
return self._samples_dirs[sample_index]
def _get_masks_dir(self, sample_index: int) -> str:
"""Returns the directory of the corresponding masks."""
sample_dir = self._samples_dirs[sample_index]
return os.path.join(self._root, sample_dir, "segmentations")

def _get_sample_slice_index(self, index: int) -> int:
"""Returns the corresponding slice index."""
image_path = self._get_image_path(index)
total_slices = io.fetch_total_nifti_slices(image_path)
slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
return slice_indices[index % self._n_slices_per_image]
def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
"""Returns the total amount of slices of a sample."""
image_path = self._get_image_path(sample_index)
return io.fetch_total_nifti_slices(image_path)

def _fetch_samples_dirs(self) -> List[str]:
"""Returns the name of all the samples of all the splits of the dataset."""
@@ -180,29 +178,46 @@ def _fetch_samples_dirs(self) -> List[str]:
]
return sorted(sample_filenames)

def _create_indices(self) -> List[int]:
"""Builds the dataset indices for the specified split."""
split_index_ranges = {
"train": self._train_index_ranges,
"val": self._val_index_ranges,
None: [(0, 103)],
}
index_ranges = split_index_ranges.get(self._split)
if index_ranges is None:
raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
def _get_split_indices(self) -> List[int]:
"""Returns the samples indices that corresponding the dataset split and version."""
key = f"{self._split}_{self._version}"
match key:
case "train_small":
index_ranges = [(0, 83)]
case "val_small":
index_ranges = [(83, 102)]
case _:
index_ranges = [(0, len(self._samples_dirs))]

return _utils.ranges_to_indices(index_ranges)

def _create_indices(self) -> List[Tuple[int, int]]:
"""Builds the dataset indices for the specified split.

Returns:
A list of tuples, where the first value is the sample
index and the second its corresponding slice index.
"""
indices = [
(sample_idx, slice_idx)
for sample_idx in self._get_split_indices()
for slice_idx in range(self._get_number_of_slices_per_sample(sample_idx))
if slice_idx % (self._sample_every_n_slices or 1) == 0
]
return indices

def _download_dataset(self) -> None:
"""Downloads the dataset."""
dataset_resources = {
"small": self._resources_small,
"full": self._resources_full,
}
resources = dataset_resources.get(self._version)
resources = dataset_resources.get(self._version or "")
if resources is None:
raise ValueError("Invalid data version. Use 'small' or 'full'.")
raise ValueError(
f"Can't download data version '{self._version}'. Use 'small' or 'full'."
)

for resource in resources:
if os.path.isdir(self._root):
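Taken together, the changes above replace the fixed 20-slices-per-scan sampling with an exhaustive enumeration: `_create_indices` now emits one `(sample_index, slice_index)` pair per slice, optionally thinned by `_sample_every_n_slices`, and `__len__` is simply the number of pairs. A standalone sketch of that logic (a hypothetical free function mirroring the list comprehension in the diff):

    from typing import Dict, List, Tuple

    def create_indices(
        sample_indices: List[int],
        slices_per_sample: Dict[int, int],
        sample_every_n_slices: int | None = None,
    ) -> List[Tuple[int, int]]:
        """Pairs every sample with each of its (optionally sub-sampled) slices."""
        return [
            (sample_idx, slice_idx)
            for sample_idx in sample_indices
            for slice_idx in range(slices_per_sample[sample_idx])
            if slice_idx % (sample_every_n_slices or 1) == 0
        ]

    # Two scans with 4 and 3 slices yield 7 dataset entries:
    assert create_indices([0, 1], {0: 4, 1: 3}) == [
        (0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2),
    ]
    # Sub-sampling every 2nd slice keeps slice indices 0 and 2:
    assert create_indices([0], {0: 4}, sample_every_n_slices=2) == [(0, 0), (0, 2)]

This also explains the new expected lengths in `validate`: 29892 and 6480 count every slice, whereas the old 1660 and 400 were 83 and 20 scans times 20 sampled slices each.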
@@ -11,7 +11,7 @@

@pytest.mark.parametrize(
"split, expected_length",
[("train", 1660), ("val", 400), (None, 2060)],
[("train", 9), ("val", 9), (None, 9)],
)
def test_length(
total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int
@@ -25,6 +25,7 @@ def test_length(
[
(None, 0),
("train", 0),
("val", 0),
],
)
def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None:
@@ -43,7 +44,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None:

@pytest.fixture(scope="function")
def total_segmentator_dataset(
split: Literal["train", "val"], assets_path: str
split: Literal["train", "val"] | None, assets_path: str
) -> datasets.TotalSegmentator2D:
"""TotalSegmentator2D dataset fixture."""
dataset = datasets.TotalSegmentator2D(
@@ -55,6 +56,7 @@ def total_segmentator_dataset(
"Totalsegmentator_dataset_v201",
),
split=split,
version=None,
)
dataset.prepare_data()
dataset.configure()
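For reference, a hypothetical usage of the updated constructor, mirroring the test fixture above: passing `version=None` skips the download-resource lookup and integrity checks and indexes whatever files are present under `root` (the path here is illustrative):

    from eva.vision.data import datasets

    dataset = datasets.TotalSegmentator2D(
        root="data/total_segmentator",  # assumed local copy of the dataset
        split=None,                     # None -> use every sample
        version=None,                   # None -> no download or validation checks
    )
    dataset.prepare_data()
    dataset.configure()
    print(len(dataset))  # one entry per (sample, slice) pair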