
Commit

Merge branch '415-update-totalsegmentator2d-to-fetch-all-the-slides' into 383-add-totalsegmentator2d-segmentation-downstream-task
ioangatop committed May 6, 2024
2 parents 45329fe + 1f530bf commit 23190fb
Showing 4 changed files with 72 additions and 64 deletions.
17 changes: 6 additions & 11 deletions src/eva/core/models/modules/head.py
@@ -54,9 +54,14 @@ def __init__(
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler

+    @override
+    def configure_model(self) -> Any:
+        if self.backbone is not None:
+            grad.deactivate_requires_grad(self.backbone)

@override
def configure_optimizers(self) -> Any:
-        parameters = list(self.head.parameters())
+        parameters = self.head.parameters()
optimizer = self.optimizer(parameters)
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
@@ -66,11 +71,6 @@ def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
features = tensor if self.backbone is None else self.backbone(tensor)
return self.head(features).squeeze(-1)

-    @override
-    def on_fit_start(self) -> None:
-        if self.backbone is not None:
-            grad.deactivate_requires_grad(self.backbone)

@override
def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
return self._batch_step(batch)
@@ -88,11 +88,6 @@ def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
tensor = INPUT_BATCH(*batch).data
return tensor if self.backbone is None else self.backbone(tensor)

-    @override
-    def on_fit_end(self) -> None:
-        if self.backbone is not None:
-            grad.activate_requires_grad(self.backbone)

def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
"""Performs a model forward step and calculates the loss.
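Taken together, the head.py changes move backbone freezing out of the paired `on_fit_start`/`on_fit_end` hooks and into `configure_model`, so the backbone is frozen once at setup rather than toggled around every fit, and `configure_optimizers` now passes the head's parameter iterator straight through instead of wrapping it in a list. A minimal sketch of the resulting pattern, using plain `torch.nn` in place of eva's actual `HeadModule` and `grad` helpers (all names here are illustrative):

```python
# Illustrative sketch only: mirrors the new head.py lifecycle with plain
# torch instead of eva's actual HeadModule and grad helpers.
import torch
from torch import nn


class FrozenBackboneHead(nn.Module):
    def __init__(self, backbone: nn.Module | None, head: nn.Module) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = head

    def configure_model(self) -> None:
        # Freeze once during setup; replaces the old on_fit_start/on_fit_end
        # pair that froze the backbone at fit start and unfroze it at fit end.
        if self.backbone is not None:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def configure_optimizers(self) -> torch.optim.Optimizer:
        # Torch optimizers accept any iterable of parameters, so the
        # previous list(...) wrapper was unnecessary.
        return torch.optim.SGD(self.head.parameters(), lr=1e-3)
```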
6 changes: 3 additions & 3 deletions src/eva/vision/data/datasets/_utils.py
@@ -1,6 +1,6 @@
"""Dataset related function and helper functions."""

-from typing import List, Tuple
+from typing import List, Sequence, Tuple


def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
return ranges


-def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
+def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
"""Unpacks a list of ranges to individual indices.
Args:
-        ranges: The list of ranges to produce the indices from.
+        ranges: A sequence of ranges to produce the indices from.
Return:
A list of the indices.
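Widening `ranges_to_indices` to accept any `Sequence` is what lets the dataset below pass literal lists of range tuples. The function body is collapsed in this diff; the following is a sketch consistent with its signature, docstring, and the half-open splits TotalSegmentator2D uses ((0, 83) for train, (83, 102) for val):

```python
from typing import List, Sequence, Tuple


def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
    """Sketch: expands half-open (start, end) ranges into a flat index list."""
    return [index for start, end in ranges for index in range(start, end)]


# Matches the splits used below: 83 train samples, 19 val samples.
assert ranges_to_indices([(0, 3), (5, 7)]) == [0, 1, 2, 5, 6]
assert len(ranges_to_indices([(0, 83)])) == 83
assert len(ranges_to_indices([(83, 102)])) == 19
```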
107 changes: 59 additions & 48 deletions src/eva/vision/data/datasets/segmentation/total_segmentator.py
@@ -18,14 +18,11 @@
class TotalSegmentator2D(base.ImageSegmentation):
"""TotalSegmentator 2D segmentation dataset."""

-    _train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
-    """Train range indices."""
-
-    _val_index_ranges: List[Tuple[int, int]] = [(83, 103)]
-    """Validation range indices."""
-
-    _n_slices_per_image: int = 20
-    """The amount of slices to sample per 3D CT scan image."""
+    _expected_dataset_lengths: Dict[str, int] = {
+        "train_small": 29892,
+        "val_small": 6480,
+    }
+    """Maps dataset version and split to the expected dataset size."""

_resources_full: List[structs.DownloadResource] = [
structs.DownloadResource(
@@ -49,7 +46,7 @@ def __init__(
self,
root: str,
split: Literal["train", "val"] | None,
-        version: Literal["small", "full"] = "small",
+        version: Literal["small", "full"] | None = "small",
download: bool = False,
as_uint8: bool = True,
transforms: Callable | None = None,
@@ -60,7 +57,8 @@
root: Path to the root directory of the dataset. The dataset will
be downloaded and extracted here, if it does not already exist.
split: Dataset split to use. If `None`, the entire dataset is used.
-            version: The version of the dataset to initialize.
+            version: The version of the dataset to initialize. If `None`, it will
+                use the files located at root as is and won't perform any checks.
download: Whether to download the data for the specified split.
Note that the download will be executed only by additionally
calling the :meth:`prepare_data` method and if the data does not
@@ -78,7 +76,7 @@ def __init__(
self._as_uint8 = as_uint8

self._samples_dirs: List[str] = []
-        self._indices: List[int] = []
+        self._indices: List[Tuple[int, int]] = []

@functools.cached_property
@override
@@ -99,7 +97,8 @@ def class_to_idx(self) -> Dict[str, int]:

@override
def filename(self, index: int) -> str:
-        sample_dir = self._samples_dirs[self._indices[index]]
+        sample_idx, _ = self._indices[index]
+        sample_dir = self._samples_dirs[sample_idx]
return os.path.join(sample_dir, "ct.nii.gz")

@override
Expand All @@ -114,21 +113,24 @@ def configure(self) -> None:

@override
def validate(self) -> None:
+        if self._version is None:
+            return

_validators.check_dataset_integrity(
self,
-            length=1660 if self._split == "train" else 400,
+            length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
n_classes=117,
first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
)
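`validate` now returns early when `version is None` (user-supplied files, no checks) and otherwise resolves the expected length from a `"{split}_{version}"` key. A tiny sketch of why the `.get(..., 0)` default matters: an unknown split/version combination yields 0, which can never equal a real dataset length, so validation fails loudly instead of passing by accident. Names below are illustrative:

```python
# Illustrative names; mirrors the keyed lookup in validate().
_expected_dataset_lengths = {"train_small": 29892, "val_small": 6480}


def expected_length(split: str, version: str) -> int:
    # Unknown split/version combinations map to 0, which can never match
    # an actual dataset length, so the integrity check fails rather than
    # silently passing.
    return _expected_dataset_lengths.get(f"{split}_{version}", 0)


assert expected_length("train", "small") == 29892
assert expected_length("train", "full") == 0
```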

@override
def __len__(self) -> int:
-        return len(self._indices) * self._n_slices_per_image
+        return len(self._indices)

@override
def load_image(self, index: int) -> tv_tensors.Image:
-        image_path = self._get_image_path(index)
-        slice_index = self._get_sample_slice_index(index)
+        sample_index, slice_index = self._indices[index]
+        image_path = self._get_image_path(sample_index)
image_array = io.read_nifti_slice(image_path, slice_index)
if self._as_uint8:
image_array = convert.to_8bit(image_array)
@@ -137,8 +139,8 @@ def load_image(self, index: int) -> tv_tensors.Image:

@override
def load_mask(self, index: int) -> tv_tensors.Mask:
-        masks_dir = self._get_masks_dir(index)
-        slice_index = self._get_sample_slice_index(index)
+        sample_index, slice_index = self._indices[index]
+        masks_dir = self._get_masks_dir(sample_index)
mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
one_hot_encoded = np.concatenate(
[io.read_nifti_slice(path, slice_index) for path in mask_paths],
@@ -149,27 +151,20 @@
segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2)
return tv_tensors.Mask(segmentation_label)
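`load_mask` reads one binary mask per class for the requested slice, stacks them on the channel axis, prepends a background channel, and collapses the stack to an integer label map with `argmax`. The background construction itself is collapsed in this diff; a standalone sketch of the reduction under one common construction (background = no class set):

```python
import numpy as np

# Sketch of the one-hot -> label-map reduction, assuming three binary class
# masks of shape (H, W, 1) stacked on the last axis.
h, w, n_classes = 4, 4, 3
one_hot = np.zeros((h, w, n_classes), dtype=np.uint8)
one_hot[0, 0, 1] = 1  # pixel (0, 0) belongs to class 1

# Background = "no class is set"; prepending it means argmax index 0 is
# background and class k lands at label k + 1.
background = 1 - one_hot.max(axis=2, keepdims=True)
one_hot_with_bg = np.concatenate([background, one_hot], axis=2)
label_map = np.argmax(one_hot_with_bg, axis=2)

assert label_map[0, 0] == 2  # class 1 -> label 2
assert label_map[1, 1] == 0  # unlabeled pixel -> background
```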

-    def _get_masks_dir(self, index: int) -> str:
-        """Returns the directory of the corresponding masks."""
-        sample_dir = self._get_sample_dir(index)
-        return os.path.join(self._root, sample_dir, "segmentations")
-
-    def _get_image_path(self, index: int) -> str:
+    def _get_image_path(self, sample_index: int) -> str:
        """Returns the corresponding image path."""
-        sample_dir = self._get_sample_dir(index)
+        sample_dir = self._samples_dirs[sample_index]
        return os.path.join(self._root, sample_dir, "ct.nii.gz")

-    def _get_sample_dir(self, index: int) -> str:
-        """Returns the corresponding sample directory."""
-        sample_index = self._indices[index // self._n_slices_per_image]
-        return self._samples_dirs[sample_index]
+    def _get_masks_dir(self, sample_index: int) -> str:
+        """Returns the directory of the corresponding masks."""
+        sample_dir = self._samples_dirs[sample_index]
+        return os.path.join(self._root, sample_dir, "segmentations")

-    def _get_sample_slice_index(self, index: int) -> int:
-        """Returns the corresponding slice index."""
-        image_path = self._get_image_path(index)
-        total_slices = io.fetch_total_nifti_slices(image_path)
-        slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
-        return slice_indices[index % self._n_slices_per_image]
+    def _get_sample_total_slices(self, sample_index: int) -> int:
+        """Returns the total amount of slices of a sample."""
+        image_path = self._get_image_path(sample_index)
+        return io.fetch_total_nifti_slices(image_path)

def _fetch_samples_dirs(self) -> List[str]:
"""Returns the name of all the samples of all the splits of the dataset."""
Expand All @@ -180,29 +175,45 @@ def _fetch_samples_dirs(self) -> List[str]:
]
return sorted(sample_filenames)

-    def _create_indices(self) -> List[int]:
-        """Builds the dataset indices for the specified split."""
-        split_index_ranges = {
-            "train": self._train_index_ranges,
-            "val": self._val_index_ranges,
-            None: [(0, 103)],
-        }
-        index_ranges = split_index_ranges.get(self._split)
-        if index_ranges is None:
-            raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
+    def _get_split_indices(self) -> List[int]:
+        """Returns the sample indices corresponding to the dataset split and version."""
+        key = f"{self._split}_{self._version}"
+        match key:
+            case "train_small":
+                index_ranges = [(0, 83)]
+            case "val_small":
+                index_ranges = [(83, 102)]
+            case _:
+                index_ranges = [(0, len(self._samples_dirs))]

        return _utils.ranges_to_indices(index_ranges)

+    def _create_indices(self) -> List[Tuple[int, int]]:
+        """Builds the dataset indices for the specified split.
+
+        Returns:
+            A list of tuples, where the first value is the sample index
+            and the second its corresponding slice index.
+        """
+        indices = [
+            (sample_idx, slide_idx)
+            for sample_idx in self._get_split_indices()
+            for slide_idx in range(self._get_sample_total_slices(sample_idx))
+        ]
+        return indices
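This pair of methods is the heart of the change: rather than sampling a fixed `_n_slices_per_image = 20` slices per CT scan, the dataset enumerates every slice of every scan as its own item, which is why `__len__` is now simply `len(self._indices)`. A toy sketch of the flattening with made-up slice counts:

```python
# Toy version of the new indexing scheme, with hypothetical slice counts
# standing in for io.fetch_total_nifti_slices.
slices_per_sample = {0: 3, 1: 2}  # sample index -> number of slices

indices = [
    (sample_idx, slice_idx)
    for sample_idx in sorted(slices_per_sample)
    for slice_idx in range(slices_per_sample[sample_idx])
]

assert indices == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
assert len(indices) == 5  # exactly what __len__ returns after the change
```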

def _download_dataset(self) -> None:
"""Downloads the dataset."""
dataset_resources = {
"small": self._resources_small,
"full": self._resources_full,
        }
-        resources = dataset_resources.get(self._version)
+        resources = dataset_resources.get(self._version or "")
        if resources is None:
-            raise ValueError("Invalid data version. Use 'small' or 'full'.")
+            raise ValueError(
+                f"Can't download data version '{self._version}'. Use 'small' or 'full'."
+            )

for resource in resources:
if os.path.isdir(self._root):
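One small idiom worth flagging: with `version=None` now legal, `dataset_resources.get(self._version or "")` coerces `None` to the empty string, which is never a key, so `None` falls through to the download error instead of silently matching. A self-contained sketch:

```python
# Sketch of the `version or ""` lookup: None coerces to "", which is never
# a key, so it falls through to the error branch instead of matching.
resources = {"small": ["small-archive"], "full": ["full-archive"]}

for version in ("small", "full", None):
    found = resources.get(version or "")
    if found is None:
        print(f"Can't download data version '{version}'.")  # only for None
    else:
        print(version, "->", found)
```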
@@ -11,7 +11,7 @@

@pytest.mark.parametrize(
"split, expected_length",
[("train", 1660), ("val", 400), (None, 2060)],
[("train", 9), ("val", 9), (None, 9)],
)
def test_length(
total_segmentator_dataset: datasets.TotalSegmentator2D, expected_length: int
@@ -25,6 +25,7 @@ def test_length(
[
(None, 0),
("train", 0),
("val", 0),
],
)
def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None:
@@ -43,7 +44,7 @@ def test_sample(total_segmentator_dataset: datasets.TotalSegmentator2D, index: int) -> None:

@pytest.fixture(scope="function")
def total_segmentator_dataset(
-    split: Literal["train", "val"], assets_path: str
+    split: Literal["train", "val"] | None, assets_path: str
) -> datasets.TotalSegmentator2D:
"""TotalSegmentator2D dataset fixture."""
dataset = datasets.TotalSegmentator2D(
@@ -55,6 +56,7 @@
"Totalsegmentator_dataset_v201",
),
split=split,
+        version=None,
)
dataset.prepare_data()
dataset.configure()
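A pytest subtlety these tests rely on: the `split` values declared in each test's `@pytest.mark.parametrize` are injected into the `total_segmentator_dataset` fixture, because pytest exposes parametrized test arguments as fixtures that other fixtures can request. A minimal sketch of the same pattern with illustrative names:

```python
import pytest


@pytest.fixture(scope="function")
def greeting(name: str) -> str:
    # `name` is not defined as a fixture anywhere: pytest injects it from
    # the test's parametrize below, exactly how `split` reaches the
    # total_segmentator_dataset fixture above.
    return f"hello {name}"


@pytest.mark.parametrize("name, expected", [("eva", "hello eva")])
def test_greeting(greeting: str, expected: str) -> None:
    assert greeting == expected
```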
