Add TotalSegmentator2D segmentation downstream task (#413)
ioangatop authored May 7, 2024
1 parent ad392e8 commit 4ef3afb
Showing 8 changed files with 159 additions and 68 deletions.
75 changes: 75 additions & 0 deletions configs/vision/dino_vit/online/total_segmentator_2d.yaml
@@ -0,0 +1,75 @@
trainer:
class_path: eva.Trainer
init_args:
n_runs: &N_RUNS ${oc.env:N_RUNS, 1}
default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/total_segmentator_2d/${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}}
max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500}
logger:
- class_path: lightning.pytorch.loggers.TensorBoardLogger
init_args:
save_dir: *OUTPUT_ROOT
name: ""
model:
class_path: eva.vision.models.modules.SemanticSegmentationModule
init_args:
encoder:
class_path: eva.vision.models.networks.encoders.TimmEncoder
init_args:
model_name: ${oc.env:TIMM_MODEL_NAME, vit_small_patch16_224}
pretrained: true
out_indices: 1
model_arguments:
dynamic_img_size: true
decoder:
class_path: eva.vision.models.networks.decoders.segmentation.ConvDecoder
init_args:
layers:
class_path: torch.nn.Conv2d
init_args:
in_channels: ${oc.env:IN_FEATURES, 384}
out_channels: &NUM_CLASSES 117
kernel_size: [1, 1]
criterion: torch.nn.CrossEntropyLoss
lr_multiplier_encoder: 0.0
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 0.0001
weight_decay: 0.05
lr_scheduler:
class_path: torch.optim.lr_scheduler.PolynomialLR
init_args:
total_iters: *MAX_STEPS
power: 0.9
metrics:
common:
- class_path: eva.metrics.AverageLoss
data:
class_path: eva.DataModule
init_args:
datasets:
train:
class_path: eva.vision.datasets.TotalSegmentator2D
init_args: &DATASET_ARGS
root: ${oc.env:DATA_ROOT, ./data}/total_segmentator
split: train
download: false
# Set `download: true` to download the dataset from https://zenodo.org/records/10047292
# The TotalSegmentator dataset is distributed under the following license:
# "Creative Commons Attribution 4.0 International"
# (see: https://creativecommons.org/licenses/by/4.0/deed.en)
transforms:
class_path: eva.vision.data.transforms.common.ResizeAndCrop
val:
class_path: eva.vision.datasets.TotalSegmentator2D
init_args:
<<: *DATASET_ARGS
split: val
dataloaders:
train:
batch_size: &BATCH_SIZE ${oc.env:BATCH_SIZE, 16}
shuffle: true
val:
batch_size: *BATCH_SIZE
predict:
batch_size: &PREDICT_BATCH_SIZE ${oc.env:PREDICT_BATCH_SIZE, 16}
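
Taken together, the config freezes the encoder (lr_multiplier_encoder: 0.0) and trains only the decoder, which is a single 1x1 convolution over the encoder's patch features. A minimal sketch of that head in plain PyTorch, assuming the config's defaults of 384 input channels (vit_small_patch16_224) and 117 classes; the feature-map shape below is illustrative only:

import torch

# 1x1 convolution mapping per-patch features to per-class logits.
head = torch.nn.Conv2d(in_channels=384, out_channels=117, kernel_size=1)
features = torch.randn(2, 384, 14, 14)  # hypothetical batch of ViT patch features
logits = head(features)                 # -> torch.Size([2, 117, 14, 14])
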
5 changes: 1 addition & 4 deletions docs/DEVELOPER_GUIDE.md
@@ -17,10 +17,7 @@ Add a new dependency to the `core` submodule:<br>
`pdm add <package_name>`

Add a new dependency to the `vision` submodule:<br>
`pdm add -G vision <package_name>`

After adding a new dependency, you also need to update the `pdm.lock` file:<br>
`pdm update`
`pdm add -G vision -G all <package_name>`

For more information about managing dependencies please look [here](https://pdm-project.org/latest/usage/dependency/#manage-dependencies).

Empty file removed main.py
Empty file.
17 changes: 6 additions & 11 deletions src/eva/core/models/modules/head.py
@@ -54,9 +54,14 @@ def __init__(
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler

@override
def configure_model(self) -> Any:
if self.backbone is not None:
grad.deactivate_requires_grad(self.backbone)

@override
def configure_optimizers(self) -> Any:
parameters = list(self.head.parameters())
parameters = self.head.parameters()
optimizer = self.optimizer(parameters)
lr_scheduler = self.lr_scheduler(optimizer)
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
@@ -66,11 +71,6 @@ def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
features = tensor if self.backbone is None else self.backbone(tensor)
return self.head(features).squeeze(-1)

@override
def on_fit_start(self) -> None:
if self.backbone is not None:
grad.deactivate_requires_grad(self.backbone)

@override
def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
return self._batch_step(batch)
@@ -88,11 +88,6 @@ def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor:
tensor = INPUT_BATCH(*batch).data
return tensor if self.backbone is None else self.backbone(tensor)

@override
def on_fit_end(self) -> None:
if self.backbone is not None:
grad.activate_requires_grad(self.backbone)

def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT:
"""Performs a model forward step and calculates the loss.
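
This hunk moves backbone freezing out of the on_fit_start/on_fit_end hooks and into configure_model, so the encoder is frozen once before fitting instead of being toggled around each run. A hedged sketch of what a freezing helper like grad.deactivate_requires_grad presumably does (assumption: it simply disables gradients on every parameter):

import torch.nn as nn

def deactivate_requires_grad(module: nn.Module) -> None:
    # Freeze the module: its parameters receive no gradients during training.
    for param in module.parameters():
        param.requires_grad = False
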
6 changes: 3 additions & 3 deletions src/eva/vision/data/datasets/_utils.py
@@ -1,6 +1,6 @@
"""Dataset related function and helper functions."""

from typing import List, Tuple
from typing import List, Sequence, Tuple


def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
@@ -33,11 +33,11 @@ def indices_to_ranges(indices: List[int]) -> List[Tuple[int, int]]:
return ranges


def ranges_to_indices(ranges: List[Tuple[int, int]]) -> List[int]:
def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
"""Unpacks a list of ranges to individual indices.
Args:
ranges: The list of ranges to produce the indices from.
ranges: A sequence of ranges to produce the indices from.
Return:
A list of the indices.
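
For reference, a minimal sketch of ranges_to_indices' semantics, assuming half-open (start, stop) ranges as the train/val split sizes imply:

from typing import List, Sequence, Tuple

def ranges_to_indices(ranges: Sequence[Tuple[int, int]]) -> List[int]:
    # Expand each half-open (start, stop) range into its individual indices.
    return [index for start, stop in ranges for index in range(start, stop)]

assert ranges_to_indices([(0, 3), (83, 85)]) == [0, 1, 2, 83, 84]
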
109 changes: 62 additions & 47 deletions src/eva/vision/data/datasets/segmentation/total_segmentator.py
@@ -18,14 +18,14 @@
class TotalSegmentator2D(base.ImageSegmentation):
"""TotalSegmentator 2D segmentation dataset."""

_train_index_ranges: List[Tuple[int, int]] = [(0, 83)]
"""Train range indices."""
_expected_dataset_lengths: Dict[str, int] = {
"train_small": 29892,
"val_small": 6480,
}
"""Dataset version and split to the expected size."""

_val_index_ranges: List[Tuple[int, int]] = [(83, 103)]
"""Validation range indices."""

_n_slices_per_image: int = 20
"""The amount of slices to sample per 3D CT scan image."""
_sample_every_n_slices: int | None = None
"""The amount of slices to sub-sample per 3D CT scan image."""

_resources_full: List[structs.DownloadResource] = [
structs.DownloadResource(
@@ -49,7 +49,7 @@ def __init__(
self,
root: str,
split: Literal["train", "val"] | None,
version: Literal["small", "full"] = "small",
version: Literal["small", "full"] | None = "small",
download: bool = False,
as_uint8: bool = True,
transforms: Callable | None = None,
@@ -60,7 +60,8 @@
root: Path to the root directory of the dataset. The dataset will
be downloaded and extracted here, if it does not already exist.
split: Dataset split to use. If `None`, the entire dataset is used.
version: The version of the dataset to initialize.
            version: The version of the dataset to initialize. If `None`, it will
                use the files located at root as-is and won't perform any checks.
download: Whether to download the data for the specified split.
Note that the download will be executed only by additionally
calling the :meth:`prepare_data` method and if the data does not
@@ -78,7 +79,7 @@
self._as_uint8 = as_uint8

self._samples_dirs: List[str] = []
self._indices: List[int] = []
self._indices: List[Tuple[int, int]] = []

@functools.cached_property
@override
@@ -99,7 +100,8 @@ def class_to_idx(self) -> Dict[str, int]:

@override
def filename(self, index: int) -> str:
sample_dir = self._samples_dirs[self._indices[index]]
sample_idx, _ = self._indices[index]
sample_dir = self._samples_dirs[sample_idx]
return os.path.join(sample_dir, "ct.nii.gz")

@override
@@ -114,21 +116,24 @@ def configure(self) -> None:

@override
def validate(self) -> None:
if self._version is None:
return

_validators.check_dataset_integrity(
self,
length=1660 if self._split == "train" else 400,
length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
n_classes=117,
first_and_last_labels=("adrenal_gland_left", "vertebrae_T9"),
)
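
The expected length is now looked up per split and version rather than hard-coded. For illustration, the lookup reduces to (values copied from the class attribute above):

expected_lengths = {"train_small": 29892, "val_small": 6480}
expected = expected_lengths.get("train_small", 0)  # -> 29892; unknown keys fall back to 0
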

@override
def __len__(self) -> int:
return len(self._indices) * self._n_slices_per_image
return len(self._indices)

@override
def load_image(self, index: int) -> tv_tensors.Image:
image_path = self._get_image_path(index)
slice_index = self._get_sample_slice_index(index)
sample_index, slice_index = self._indices[index]
image_path = self._get_image_path(sample_index)
image_array = io.read_nifti_slice(image_path, slice_index)
if self._as_uint8:
image_array = convert.to_8bit(image_array)
Expand All @@ -137,8 +142,8 @@ def load_image(self, index: int) -> tv_tensors.Image:

@override
def load_mask(self, index: int) -> tv_tensors.Mask:
masks_dir = self._get_masks_dir(index)
slice_index = self._get_sample_slice_index(index)
sample_index, slice_index = self._indices[index]
masks_dir = self._get_masks_dir(sample_index)
mask_paths = (os.path.join(masks_dir, label + ".nii.gz") for label in self.classes)
one_hot_encoded = np.concatenate(
[io.read_nifti_slice(path, slice_index) for path in mask_paths],
@@ -149,27 +154,20 @@
segmentation_label = np.argmax(one_hot_encoded_with_bg, axis=2)
return tv_tensors.Mask(segmentation_label)
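
The mask loader stacks one binary NIfTI mask per class and collapses the stack into a single label map; the lines elided above presumably prepend a background channel before the argmax. A small self-contained illustration of that decoding step, with hypothetical shapes and an assumed background construction:

import numpy as np

# H x W x C stack of binary per-class masks (hypothetical 4x4 image, 3 classes).
one_hot = np.zeros((4, 4, 3), dtype=np.uint8)
one_hot[0, 0, 1] = 1  # pixel (0, 0) belongs to class 1
# Assumption: background is active wherever no class mask fires.
background = (one_hot.sum(axis=2, keepdims=True) == 0).astype(np.uint8)
with_bg = np.concatenate([background, one_hot], axis=2)
labels = np.argmax(with_bg, axis=2)  # 0 = background, k + 1 = class k
assert labels[0, 0] == 2 and labels[1, 1] == 0
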

def _get_masks_dir(self, index: int) -> str:
"""Returns the directory of the corresponding masks."""
sample_dir = self._get_sample_dir(index)
return os.path.join(self._root, sample_dir, "segmentations")

def _get_image_path(self, index: int) -> str:
def _get_image_path(self, sample_index: int) -> str:
"""Returns the corresponding image path."""
sample_dir = self._get_sample_dir(index)
sample_dir = self._samples_dirs[sample_index]
return os.path.join(self._root, sample_dir, "ct.nii.gz")

def _get_sample_dir(self, index: int) -> str:
"""Returns the corresponding sample directory."""
sample_index = self._indices[index // self._n_slices_per_image]
return self._samples_dirs[sample_index]
def _get_masks_dir(self, sample_index: int) -> str:
"""Returns the directory of the corresponding masks."""
sample_dir = self._samples_dirs[sample_index]
return os.path.join(self._root, sample_dir, "segmentations")

def _get_sample_slice_index(self, index: int) -> int:
"""Returns the corresponding slice index."""
image_path = self._get_image_path(index)
total_slices = io.fetch_total_nifti_slices(image_path)
slice_indices = np.linspace(0, total_slices - 1, num=self._n_slices_per_image, dtype=int)
return slice_indices[index % self._n_slices_per_image]
def _get_number_of_slices_per_sample(self, sample_index: int) -> int:
"""Returns the total amount of slices of a sample."""
image_path = self._get_image_path(sample_index)
return io.fetch_total_nifti_slices(image_path)

def _fetch_samples_dirs(self) -> List[str]:
"""Returns the name of all the samples of all the splits of the dataset."""
@@ -180,29 +178,46 @@ def _fetch_samples_dirs(self) -> List[str]:
]
return sorted(sample_filenames)

def _create_indices(self) -> List[int]:
"""Builds the dataset indices for the specified split."""
split_index_ranges = {
"train": self._train_index_ranges,
"val": self._val_index_ranges,
None: [(0, 103)],
}
index_ranges = split_index_ranges.get(self._split)
if index_ranges is None:
raise ValueError("Invalid data split. Use 'train', 'val' or `None`.")
def _get_split_indices(self) -> List[int]:
"""Returns the samples indices that corresponding the dataset split and version."""
key = f"{self._split}_{self._version}"
match key:
case "train_small":
index_ranges = [(0, 83)]
case "val_small":
index_ranges = [(83, 102)]
case _:
index_ranges = [(0, len(self._samples_dirs))]

return _utils.ranges_to_indices(index_ranges)

def _create_indices(self) -> List[Tuple[int, int]]:
"""Builds the dataset indices for the specified split.
Returns:
            A list of tuples, where the first value indicates the
            sample index and the second its corresponding slice index.
"""
        indices = [
            (sample_idx, slice_idx)
            for sample_idx in self._get_split_indices()
            for slice_idx in range(self._get_number_of_slices_per_sample(sample_idx))
            if slice_idx % (self._sample_every_n_slices or 1) == 0
        ]
return indices
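
For intuition, the same comprehension run standalone with hypothetical slice counts:

# Hypothetical: two samples with 3 and 2 slices respectively, no sub-sampling.
slices_per_sample = {0: 3, 1: 2}
sample_every_n_slices = None
indices = [
    (sample_idx, slice_idx)
    for sample_idx in (0, 1)
    for slice_idx in range(slices_per_sample[sample_idx])
    if slice_idx % (sample_every_n_slices or 1) == 0
]
assert indices == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]
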

def _download_dataset(self) -> None:
"""Downloads the dataset."""
dataset_resources = {
"small": self._resources_small,
"full": self._resources_full,
None: (0, 103),
}
resources = dataset_resources.get(self._version)
resources = dataset_resources.get(self._version or "")
if resources is None:
raise ValueError("Invalid data version. Use 'small' or 'full'.")
raise ValueError(
f"Can't download data version '{self._version}'. Use 'small' or 'full'."
)

for resource in resources:
if os.path.isdir(self._root):
9 changes: 8 additions & 1 deletion src/eva/vision/data/transforms/common/resize_and_crop.py
@@ -4,6 +4,7 @@

import torch
import torchvision.transforms.v2 as torch_transforms
from torchvision import tv_tensors


class ResizeAndCrop(torch_transforms.Compose):
@@ -35,7 +36,13 @@ def _build_transforms(self) -> Sequence[Callable]:
torch_transforms.ToImage(),
torch_transforms.Resize(size=self._size),
torch_transforms.CenterCrop(size=self._size),
torch_transforms.ToDtype(torch.float32, scale=True),
torch_transforms.ToDtype(
{
tv_tensors.Image: torch.float32,
tv_tensors.Mask: torch.float32,
},
scale=True,
),
torch_transforms.Normalize(
mean=self._mean,
std=self._std,
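
The change replaces the single-dtype ToDtype with torchvision v2's per-type mapping, so paired images and masks are converted in one pass. A short usage sketch of that mapping on a paired sample (tensor contents hypothetical):

import torch
import torchvision.transforms.v2 as torch_transforms
from torchvision import tv_tensors

to_dtype = torch_transforms.ToDtype(
    {tv_tensors.Image: torch.float32, tv_tensors.Mask: torch.float32},
    scale=True,  # value rescaling applies to images; masks are only cast
)
image = tv_tensors.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
mask = tv_tensors.Mask(torch.randint(0, 117, (8, 8), dtype=torch.uint8))
out_image, out_mask = to_dtype(image, mask)  # both returned as float32
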