Standardize difference between testing and predicting dataloaders across datasets (#75)

* Refactor `VitalDataModule` internal API + breaking change for `AcdcDataModule`

Removed the `predict_on_test` parameter from `AcdcDataModule` and changed the default behavior to run no predict loop, since ACDC does not support a prediction writer anyway

* Include `test_step` in generic shared steps

Since `predict_step` now covers the cases that call for more complex prediction loops (e.g. saving predictions to disk), `test_step` can be standardized with the other `fit` steps
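For context, a minimal sketch of the kind of "complex prediction loop" that `predict_step` is now reserved for, using Lightning's `BasePredictionWriter` callback (the writer class and output paths here are hypothetical, not this repo's code):

from pathlib import Path

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class DiskPredictionWriter(BasePredictionWriter):  # hypothetical name, not from this repo
    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # Persist each prediction batch to disk; `test_step` never needs this
        # machinery, which is why it can share the generic train/val step.
        torch.save(prediction, self.output_dir / f"prediction_{batch_idx}.pt")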
nathanpainchaud authored Sep 5, 2022
1 parent 9797673 commit d852e34
Showing 11 changed files with 70 additions and 131 deletions.
1 change: 0 additions & 1 deletion vital/config/data/acdc.yaml
@@ -5,4 +5,3 @@ _target_: vital.data.acdc.data_module.AcdcDataModule
 
 dataset_path: ${oc.env:ACDC_DATA_PATH}
 use_da: True
-predict_on_test: True
3 changes: 3 additions & 0 deletions vital/config/experiment/mnist-mlp.yaml
@@ -14,6 +14,9 @@ defaults:
   - override /task/model: mlp
   - override /data: mnist
 
+test: True
+predict: False
+
 # Overwrite specific config parameters here. They will be merged with the rest of the config by Hydra.
 trainer:
   max_epochs: 300
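The new top-level `test`/`predict` keys presumably gate which Trainer entry points the experiment runner invokes; a hedged sketch of how such flags are typically consumed (the actual vital runner is not part of this diff, and `cfg`, `trainer`, `model`, `datamodule` are assumed names):

from omegaconf import DictConfig
import pytorch_lightning as pl


def run(cfg: DictConfig, trainer: pl.Trainer, model: pl.LightningModule, datamodule: pl.LightningDataModule) -> None:
    # Hedged sketch -- not this repo's runner. The `test`/`predict` flags above
    # decide which Trainer loops run after fitting.
    trainer.fit(model, datamodule=datamodule)
    if cfg.get("test", False):
        trainer.test(model, datamodule=datamodule)
    if cfg.get("predict", False):
        trainer.predict(model, datamodule=datamodule)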
49 changes: 10 additions & 39 deletions vital/data/acdc/data_module.py
@@ -1,7 +1,7 @@
 from pathlib import Path
-from typing import Literal, Union
+from typing import Optional, Union
 
-from torch.utils.data import DataLoader
+from pytorch_lightning.trainer.states import TrainerFn
 
 from vital.data.acdc.config import Label, image_size, in_channels
 from vital.data.acdc.dataset import Acdc
@@ -12,13 +12,12 @@
 class AcdcDataModule(VitalDataModule):
     """Implementation of the ``VitalDataModule`` for the ACDC dataset."""
 
-    def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, predict_on_test: bool = True, **kwargs):
+    def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, **kwargs):
         """Initializes class instance.
 
         Args:
             dataset_path: Path to the HDF5 dataset.
             use_da: Enable use of data augmentation.
-            predict_on_test: If `True`, get full patients at each batch during the test stage.
             **kwargs: Keyword arguments to pass to the parent's constructor.
         """
         super().__init__(
@@ -31,39 +30,11 @@ def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, predict_
         )
 
         self._dataset_kwargs = {"path": Path(dataset_path), "use_da": use_da}
-        self.predict_on_test = predict_on_test
 
-    def setup(self, stage: Literal["fit", "test"]) -> None:  # noqa: D102
-        if stage == "fit":
-            self._dataset[Subset.TRAIN] = Acdc(image_set=Subset.TRAIN, **self._dataset_kwargs)
-            self._dataset[Subset.VAL] = Acdc(image_set=Subset.VAL, **self._dataset_kwargs)
-        if stage == "test":
-            self._dataset[Subset.TEST] = Acdc(
-                image_set=Subset.TEST, predict=self.predict_on_test, **self._dataset_kwargs
-            )
-
-    def train_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.TRAIN),
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=self.num_workers,
-            pin_memory=True,
-        )
-
-    def val_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.VAL),
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=True,
-        )
-
-    def test_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.TEST),
-            # batch_size=None returns one full patient at each step.
-            batch_size=None if self.predict_on_test else self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=True,
-        )
+    def setup(self, stage: Optional[str] = None) -> None:  # noqa: D102
+        if stage == TrainerFn.FITTING:
+            self.datasets[Subset.TRAIN] = Acdc(image_set=Subset.TRAIN, **self._dataset_kwargs)
+        if stage in [TrainerFn.FITTING, TrainerFn.VALIDATING]:
+            self.datasets[Subset.VAL] = Acdc(image_set=Subset.VAL, **self._dataset_kwargs)
+        if stage == TrainerFn.TESTING:
+            self.datasets[Subset.TEST] = Acdc(image_set=Subset.TEST, **self._dataset_kwargs)
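A note on why comparing Lightning's string `stage` argument against `TrainerFn` members works (a sketch of the behavior at the time of this commit, assuming `TrainerFn` remains a str-based `LightningEnum`):

from pytorch_lightning.trainer.states import TrainerFn

# `TrainerFn` subclasses `str`, so its members compare equal to their values.
assert TrainerFn.FITTING == "fit"
assert "test" == TrainerFn.TESTING
assert "fit" in [TrainerFn.FITTING, TrainerFn.VALIDATING]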
38 changes: 7 additions & 31 deletions vital/data/camus/data_module.py
@@ -68,13 +68,13 @@ def __init__(
 
     def setup(self, stage: Optional[str] = None) -> None:  # noqa: D102
         if stage == TrainerFn.FITTING:
-            self._dataset[Subset.TRAIN] = Camus(image_set=Subset.TRAIN, **self._dataset_kwargs)
+            self.datasets[Subset.TRAIN] = Camus(image_set=Subset.TRAIN, **self._dataset_kwargs)
         if stage in [TrainerFn.FITTING, TrainerFn.VALIDATING]:
-            self._dataset[Subset.VAL] = Camus(image_set=Subset.VAL, **self._dataset_kwargs)
+            self.datasets[Subset.VAL] = Camus(image_set=Subset.VAL, **self._dataset_kwargs)
         if stage == TrainerFn.TESTING:
-            self._dataset[Subset.TEST] = Camus(image_set=Subset.TEST, **self._dataset_kwargs)
+            self.datasets[Subset.TEST] = Camus(image_set=Subset.TEST, **self._dataset_kwargs)
         if stage == TrainerFn.PREDICTING:
-            self._dataset[Subset.PREDICT] = Camus(image_set=Subset.TEST, predict=True, **self._dataset_kwargs)
+            self.datasets[Subset.PREDICT] = Camus(image_set=Subset.TEST, predict=True, **self._dataset_kwargs)
 
     def group_ids(self, subset: Subset, level: Literal["patient", "view"] = "view") -> List[str]:
         """Lists the IDs of the different levels of groups/clusters samples in the data can belong to.
@@ -87,35 +87,11 @@ def group_ids(self, subset: Subset, level: Literal["patient", "view"] = "view")
         Returns:
             IDs of the different levels of groups/clusters samples in the data can belong to.
         """
-        subset_data = self.dataset().get(subset, Camus(image_set=subset, **self._dataset_kwargs))
-        return subset_data.list_groups(level=level)
-
-    def train_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.TRAIN),
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=self.num_workers,
-            pin_memory=True,
-        )
-
-    def val_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.VAL),
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=True,
-        )
-
-    def test_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.TEST), batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True
-        )
+        subset_dataset = self.datasets.get(subset, Camus(image_set=subset, **self._dataset_kwargs))
+        return subset_dataset.list_groups(level=level)
 
     def predict_dataloader(self) -> DataLoader:  # noqa: D102
-        return DataLoader(
-            self.dataset(subset=Subset.PREDICT), batch_size=None, num_workers=self.num_workers, pin_memory=True
-        )
+        return DataLoader(self.datasets[Subset.PREDICT], batch_size=None, num_workers=self.num_workers, pin_memory=True)
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
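Why `predict_dataloader` passes `batch_size=None`: it disables PyTorch's automatic batching, so each dataset item — here a full patient — is yielded as-is, one per predict step. A sketch of that documented behavior, with a hypothetical stand-in for `Camus` in predict mode:

from torch.utils.data import DataLoader, Dataset


class PatientDataset(Dataset):  # hypothetical stand-in, not repo code
    def __init__(self, patient_ids):
        self.patient_ids = patient_ids

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        # Stands in for a whole patient's pre-collated data
        return self.patient_ids[idx]


loader = DataLoader(PatientDataset(["patient0001", "patient0002"]), batch_size=None)
for patient in loader:
    pass  # one full patient per iteration, no batch dimension added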
8 changes: 4 additions & 4 deletions vital/data/config.py
@@ -95,11 +95,11 @@ class DataParameters:
     """Class for defining parameters related to the nature of the data.
 
     Args:
-        in_shape: Shape of the input data (e.g. height, width, channels).
-        out_shape: Shape of the target data (e.g. height, width, channels).
+        in_shape: Shape of the input data, if constant for all items (e.g. channels, height, width).
+        out_shape: Shape of the target data, if constant for all items (e.g. classes, height, width).
         labels: Labels provided with the data, required when using segmentation task APIs.
     """
 
-    in_shape: Tuple[int, ...]
-    out_shape: Tuple[int, ...]
+    in_shape: Optional[Tuple[int, ...]] = None
+    out_shape: Optional[Tuple[int, ...]] = None
     labels: Optional[Tuple[LabelEnum, ...]] = None
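A quick illustration of what the relaxed `DataParameters` allows (a sketch; the shape values are made up):

from vital.data.config import DataParameters

# Datasets with constant-size items can still declare their shapes...
fixed = DataParameters(in_shape=(1, 256, 256), out_shape=(4, 256, 256))
# ...while datasets with variable-size items can now omit them entirely.
variable = DataParameters()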
33 changes: 17 additions & 16 deletions vital/data/data_module.py
@@ -1,11 +1,11 @@
 import os
 from abc import ABC
 from argparse import ArgumentParser
-from typing import Dict, Union
+from typing import Dict
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.argparse import add_argparse_args
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 
 from vital.data.config import DataParameters, Subset
 
@@ -33,25 +33,26 @@ def __init__(self, data_params: DataParameters, batch_size: int, num_workers: in
         self.data_params = data_params
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self._dataset: Dict[Subset, Dataset] = {}
+        self.datasets: Dict[Subset, Dataset] = {}
         self.save_hyperparameters(ignore="data_params")
 
-    def dataset(self, subset: Subset = None) -> Union[Dict[Subset, Dataset], Dataset]:
-        """Returns the subsets of the data (e.g. train) and their torch ``Dataset`` handle.
+    def _dataloader(self, subset: Subset, shuffle: bool = False) -> DataLoader:
+        return DataLoader(
+            self.datasets[subset],
+            batch_size=self.batch_size,
+            shuffle=shuffle,
+            num_workers=self.num_workers,
+            pin_memory=True,
+        )
 
-        It should not be called before ``setup``, when the datasets are set.
+    def train_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.TRAIN, shuffle=True)
 
-        Args:
-            subset: Specific subset for which to get the ``Dataset`` handle.
-
-        Returns:
-            If ``subset`` is provided, returns the handle to a specific dataset. Otherwise, returns the mapping between
-            subsets of the data (e.g. train) and their torch ``Dataset`` handle.
-        """
-        if subset is not None:
-            return self._dataset[subset]
+    def val_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.VAL)
 
-        return self._dataset
+    def test_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.TEST)
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:  # noqa: D102
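A sketch of the pattern this refactor enables (hypothetical `ToyDataModule`, not from this commit; assumes the remaining `VitalDataModule.__init__` arguments have defaults):

from typing import Optional

import torch
from torch.utils.data import TensorDataset

from vital.data.config import Subset
from vital.data.data_module import VitalDataModule


class ToyDataModule(VitalDataModule):  # hypothetical
    def setup(self, stage: Optional[str] = None) -> None:
        # Subclasses now only populate `self.datasets`; `train/val/test_dataloader`
        # come from the base class via the shared `_dataloader` factory.
        self.datasets[Subset.TRAIN] = TensorDataset(torch.randn(32, 3), torch.randint(2, (32,)))
        self.datasets[Subset.VAL] = TensorDataset(torch.randn(8, 3), torch.randint(2, (8,)))
        self.datasets[Subset.TEST] = TensorDataset(torch.randn(8, 3), torch.randint(2, (8,)))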
27 changes: 7 additions & 20 deletions vital/data/mnist/data_module.py
@@ -2,7 +2,8 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from torch.utils.data import DataLoader, Dataset, random_split
+from pytorch_lightning.trainer.states import TrainerFn
+from torch.utils.data import Dataset, random_split
 from torchvision import transforms as transform_lib
 
 from vital import get_vital_home
@@ -50,15 +51,15 @@ def prepare_data(self):  # noqa: D102
         MNIST(root=self._root, train=False, download=self._download)
 
     def setup(self, stage: Optional[str] = None) -> None:  # noqa: D102
-        if stage == "fit":
+        if stage == TrainerFn.FITTING:
             # Initialize one dataset for train/val split
             transforms = self.default_transforms() if self._transforms is None else self._transforms
             dataset_train = MNIST(root=self._root, transform=transforms, train=True)
             # Split
-            self._dataset[Subset.TRAIN] = self._split_dataset(dataset_train)
-            self._dataset[Subset.VAL] = self._split_dataset(dataset_train, train=False)
-        if stage == "test":
-            self._dataset[Subset.TEST] = MNIST(root=self._root, train=False)
+            self.datasets[Subset.TRAIN] = self._split_dataset(dataset_train)
+            self.datasets[Subset.VAL] = self._split_dataset(dataset_train, train=False)
+        if stage == TrainerFn.TESTING:
+            self.datasets[Subset.TEST] = MNIST(root=self._root, transform=self.default_transforms(), train=False)
 
     def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
         """Splits the dataset into train and validation set.
@@ -108,17 +109,3 @@ def default_transforms(self) -> Callable:
         else:
             mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
         return mnist_transforms
-
-    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
-        return DataLoader(
-            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=True
-        )
-
-    def train_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.TRAIN), shuffle=True)
-
-    def val_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.VAL))
-
-    def test_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.TEST))
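Worth noting in the `setup` hunk above: the test-stage `MNIST` dataset now also receives `self.default_transforms()`. Previously it was created without any transform (and the removed `_data_loader` path added none), so test images would not have been converted to tensors.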
8 changes: 4 additions & 4 deletions vital/tasks/autoencoder.py
@@ -10,11 +10,11 @@
 from vital.data.config import Tags
 from vital.metrics.train.functional import kl_div_zmuv
 from vital.metrics.train.metric import DifferentiableDiceCoefficient
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 
 
-class SegmentationAutoencoderTask(SharedTrainEvalTask):
+class SegmentationAutoencoderTask(SharedStepsTask):
     """Generic segmentation autoencoder training and inference steps.
 
     Implements generic segmentation train/val step and inference, assuming the following conditions:
@@ -104,7 +104,7 @@ def _categorical_to_input(self, x: Tensor) -> Tensor:
         """
         return to_onehot(x, num_classes=len(self.hparams.data_params.labels)).float()
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         # Forward
         out = self.model(self._categorical_to_input(batch[self.hparams.segmentation_data_tag]))
 
@@ -144,7 +144,7 @@ def _compute_loss(self, metrics: Mapping[str, Tensor]) -> Tensor:
         Args:
             metrics: Metrics useful for computing the loss (usually a combination of metrics from
-                ``self._shared_train_val_step`` and ``self._compute_latent_space_metrics``).
+                ``self._shared_step`` and ``self._compute_latent_space_metrics``).
 
         Returns:
             Loss for a train/val step.
6 changes: 3 additions & 3 deletions vital/tasks/classification.py
@@ -5,10 +5,10 @@
 from torchmetrics.functional import accuracy
 
 from vital.data.config import Tags
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 
 
-class ClassificationTask(SharedTrainEvalTask):
+class ClassificationTask(SharedStepsTask):
     """Generic classification training and inference steps.
 
     Implements generic classification train/val step and inference, assuming the following conditions:
@@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
     def forward(self, *args, **kwargs):  # noqa: D102
         return self.model(*args, **kwargs)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         x, y = batch[Tags.img], batch[Tags.gt]
 
         # Forward
20 changes: 11 additions & 9 deletions vital/tasks/generic.py
@@ -7,18 +7,15 @@
 from vital.utils.format.native import prefix
 
 
-class SharedTrainEvalTask(VitalSystem, ABC):
-    """Abstract task that shares a train/val step.
+class SharedStepsTask(VitalSystem, ABC):
+    """Abstract task that shares a train/val/test step.
 
     Implements useful generic utilities and boilerplate Lighting code:
     - Handling of identical train/val step results (metrics logging and printing)
     """
 
-    def _shared_train_val_step(self, *args, **kwargs) -> Dict[str, Tensor]:
-        """Handles steps for both training and validation loops, assuming the behavior should be the same.
-
-        For models where the behavior in training and validation is different, then override ``training_step`` and
-        ``validation_step`` directly (in which case ``_shared_train_val_step`` doesn't need to be implemented).
+    def _shared_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+        """Handles steps for the train/val/test loops, assuming the behavior should be the same.
 
         Returns:
             Mapping between metric names and their values. It must contain at least a ``'loss'``, as that is the value
Expand All @@ -27,13 +24,18 @@ def _shared_train_val_step(self, *args, **kwargs) -> Dict[str, Tensor]:
raise NotImplementedError

def training_step(self, *args, **kwargs) -> Dict[str, Tensor]: # noqa: D102
result = prefix(self._shared_train_val_step(*args, **kwargs), "train/")
result = prefix(self._shared_step(*args, **kwargs), "train/")
self.log_dict(result, **self.hparams.train_log_kwargs)
# Add reference to 'train_loss' under 'loss' keyword, requested by PL to know which metric to optimize
result["loss"] = result["train/loss"]
return result

def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]: # noqa: D102
result = prefix(self._shared_train_val_step(*args, **kwargs), "val/")
result = prefix(self._shared_step(*args, **kwargs), "val/")
self.log_dict(result, **self.hparams.val_log_kwargs)
return result

def test_step(self, *args, **kwargs) -> Dict[str, Tensor]: # noqa: D102
result = prefix(self._shared_step(*args, **kwargs), "test/")
self.log_dict(result, **self.hparams.val_log_kwargs)
return result
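What `prefix` does in these steps (a sketch inferred from its usage here, not the actual implementation in `vital.utils.format.native`): prepend a string to every metric name so the same `_shared_step` output lands under `train/`, `val/`, or `test/` in the logs.

from typing import Dict


def prefix(metrics: Dict[str, float], pre: str) -> Dict[str, float]:
    # Prepend `pre` to every key, leaving values untouched.
    return {f"{pre}{name}": value for name, value in metrics.items()}


assert prefix({"loss": 0.3, "accuracy": 0.9}, "test/") == {"test/loss": 0.3, "test/accuracy": 0.9}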
8 changes: 4 additions & 4 deletions vital/tasks/segmentation.py
@@ -8,12 +8,12 @@
 
 from vital.data.config import Tags
 from vital.metrics.train.metric import DifferentiableDiceCoefficient
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 from vital.utils.image.measure import Measure
 
 
-class SegmentationTask(SharedTrainEvalTask):
+class SegmentationTask(SharedStepsTask):
     """Generic segmentation training and inference steps.
 
     Implements generic segmentation train/val step and inference, assuming the following conditions:
@@ -39,7 +39,7 @@ def __init__(self, ce_weight: float = 0.1, dice_weight: float = 1, *args, **kwar
     def forward(self, *args, **kwargs):  # noqa: D102
         return self.model(*args, **kwargs)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         x, y = batch[Tags.img], batch[Tags.gt]
 
         # Forward
@@ -112,7 +112,7 @@ def _compute_normalized_bbox(self, y: Tensor) -> Tensor:
             boxes.append(item_box)
         return torch.stack(boxes).to(y.device)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
         x, y = batch[Tags.img], batch[Tags.gt]
         roi_bbox = self._compute_normalized_bbox(y)  # Compute the target RoI bbox from the groundtruth
 
