From d852e344815820aabded69220d96f088575ded66 Mon Sep 17 00:00:00 2001
From: Nathan Painchaud <23144457+nathanpainchaud@users.noreply.github.com>
Date: Tue, 6 Sep 2022 01:21:10 +0200
Subject: [PATCH] Standardize difference between testing and predicting
 dataloaders across datasets (#75)

* Refactor `VitalDataModule` internal API + breaking change for `AcdcDataModule`

Removed the `predict_on_test` parameter from `AcdcDataModule` and changed the
default behavior to skip the predict loop, since ACDC does not support a
prediction writer anyway

* Include `test_step` in generic shared steps

Since `predict_step` now covers cases where we want more complex prediction
loops (e.g. saving predictions to disk), `test_step` can be standardized with
the other `fit` steps
---
 vital/config/data/acdc.yaml            |  1 -
 vital/config/experiment/mnist-mlp.yaml |  3 ++
 vital/data/acdc/data_module.py         | 49 ++++++--------------------
 vital/data/camus/data_module.py        | 38 ++++----------------
 vital/data/config.py                   |  8 ++---
 vital/data/data_module.py              | 33 ++++++++---------
 vital/data/mnist/data_module.py        | 27 ++++----------
 vital/tasks/autoencoder.py             |  8 ++---
 vital/tasks/classification.py          |  6 ++--
 vital/tasks/generic.py                 | 20 ++++++-----
 vital/tasks/segmentation.py            |  8 ++---
 11 files changed, 70 insertions(+), 131 deletions(-)

diff --git a/vital/config/data/acdc.yaml b/vital/config/data/acdc.yaml
index 1ab81928..374ff8c2 100644
--- a/vital/config/data/acdc.yaml
+++ b/vital/config/data/acdc.yaml
@@ -5,4 +5,3 @@
 _target_: vital.data.acdc.data_module.AcdcDataModule
 dataset_path: ${oc.env:ACDC_DATA_PATH}
 use_da: True
-predict_on_test: True
diff --git a/vital/config/experiment/mnist-mlp.yaml b/vital/config/experiment/mnist-mlp.yaml
index 72bfd013..53bb2375 100644
--- a/vital/config/experiment/mnist-mlp.yaml
+++ b/vital/config/experiment/mnist-mlp.yaml
@@ -14,6 +14,9 @@ defaults:
   - override /task/model: mlp
   - override /data: mnist
 
+test: True
+predict: False
+
 # Overwrite specific config parameters here. They will be merged with the rest of the config by Hydra.
 trainer:
   max_epochs: 300
diff --git a/vital/data/acdc/data_module.py b/vital/data/acdc/data_module.py
index 34ba7c89..8609c625 100644
--- a/vital/data/acdc/data_module.py
+++ b/vital/data/acdc/data_module.py
@@ -1,7 +1,7 @@
 from pathlib import Path
-from typing import Literal, Union
+from typing import Optional, Union
 
-from torch.utils.data import DataLoader
+from pytorch_lightning.trainer.states import TrainerFn
 
 from vital.data.acdc.config import Label, image_size, in_channels
 from vital.data.acdc.dataset import Acdc
@@ -12,13 +12,12 @@
 class AcdcDataModule(VitalDataModule):
     """Implementation of the ``VitalDataModule`` for the ACDC dataset."""
 
-    def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, predict_on_test: bool = True, **kwargs):
+    def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, **kwargs):
         """Initializes class instance.
 
         Args:
             dataset_path: Path to the HDF5 dataset.
             use_da: Enable use of data augmentation.
-            predict_on_test: If `True`, get full patients at each batch during the test stage.
             **kwargs: Keyword arguments to pass to the parent's constructor.
""" super().__init__( @@ -31,39 +30,11 @@ def __init__(self, dataset_path: Union[str, Path], use_da: bool = True, predict_ ) self._dataset_kwargs = {"path": Path(dataset_path), "use_da": use_da} - self.predict_on_test = predict_on_test - def setup(self, stage: Literal["fit", "test"]) -> None: # noqa: D102 - if stage == "fit": - self._dataset[Subset.TRAIN] = Acdc(image_set=Subset.TRAIN, **self._dataset_kwargs) - self._dataset[Subset.VAL] = Acdc(image_set=Subset.VAL, **self._dataset_kwargs) - if stage == "test": - self._dataset[Subset.TEST] = Acdc( - image_set=Subset.TEST, predict=self.predict_on_test, **self._dataset_kwargs - ) - - def train_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.TRAIN), - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.VAL), - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.TEST), - # batch_size=None returns one full patient at each step. - batch_size=None if self.predict_on_test else self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) + def setup(self, stage: Optional[str] = None) -> None: # noqa: D102 + if stage == TrainerFn.FITTING: + self.datasets[Subset.TRAIN] = Acdc(image_set=Subset.TRAIN, **self._dataset_kwargs) + if stage in [TrainerFn.FITTING, TrainerFn.VALIDATING]: + self.datasets[Subset.VAL] = Acdc(image_set=Subset.VAL, **self._dataset_kwargs) + if stage == TrainerFn.TESTING: + self.datasets[Subset.TEST] = Acdc(image_set=Subset.TEST, **self._dataset_kwargs) diff --git a/vital/data/camus/data_module.py b/vital/data/camus/data_module.py index 6b8ad261..560febc2 100644 --- a/vital/data/camus/data_module.py +++ b/vital/data/camus/data_module.py @@ -68,13 +68,13 @@ def __init__( def setup(self, stage: Optional[str] = None) -> None: # noqa: D102 if stage == TrainerFn.FITTING: - self._dataset[Subset.TRAIN] = Camus(image_set=Subset.TRAIN, **self._dataset_kwargs) + self.datasets[Subset.TRAIN] = Camus(image_set=Subset.TRAIN, **self._dataset_kwargs) if stage in [TrainerFn.FITTING, TrainerFn.VALIDATING]: - self._dataset[Subset.VAL] = Camus(image_set=Subset.VAL, **self._dataset_kwargs) + self.datasets[Subset.VAL] = Camus(image_set=Subset.VAL, **self._dataset_kwargs) if stage == TrainerFn.TESTING: - self._dataset[Subset.TEST] = Camus(image_set=Subset.TEST, **self._dataset_kwargs) + self.datasets[Subset.TEST] = Camus(image_set=Subset.TEST, **self._dataset_kwargs) if stage == TrainerFn.PREDICTING: - self._dataset[Subset.PREDICT] = Camus(image_set=Subset.TEST, predict=True, **self._dataset_kwargs) + self.datasets[Subset.PREDICT] = Camus(image_set=Subset.TEST, predict=True, **self._dataset_kwargs) def group_ids(self, subset: Subset, level: Literal["patient", "view"] = "view") -> List[str]: """Lists the IDs of the different levels of groups/clusters samples in the data can belong to. @@ -87,35 +87,11 @@ def group_ids(self, subset: Subset, level: Literal["patient", "view"] = "view") Returns: IDs of the different levels of groups/clusters samples in the data can belong to. 
""" - subset_data = self.dataset().get(subset, Camus(image_set=subset, **self._dataset_kwargs)) - return subset_data.list_groups(level=level) - - def train_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.TRAIN), - batch_size=self.batch_size, - shuffle=True, - num_workers=self.num_workers, - pin_memory=True, - ) - - def val_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.VAL), - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=True, - ) - - def test_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.TEST), batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True - ) + subset_dataset = self.datasets.get(subset, Camus(image_set=subset, **self._dataset_kwargs)) + return subset_dataset.list_groups(level=level) def predict_dataloader(self) -> DataLoader: # noqa: D102 - return DataLoader( - self.dataset(subset=Subset.PREDICT), batch_size=None, num_workers=self.num_workers, pin_memory=True - ) + return DataLoader(self.datasets[Subset.PREDICT], batch_size=None, num_workers=self.num_workers, pin_memory=True) @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: diff --git a/vital/data/config.py b/vital/data/config.py index 468da7dc..6ad2d3f9 100644 --- a/vital/data/config.py +++ b/vital/data/config.py @@ -95,11 +95,11 @@ class DataParameters: """Class for defining parameters related to the nature of the data. Args: - in_shape: Shape of the input data (e.g. height, width, channels). - out_shape: Shape of the target data (e.g. height, width, channels). + in_shape: Shape of the input data, if constant for all items (e.g. channels, height, width). + out_shape: Shape of the target data, if constant for all items (e.g. classes, height, width). labels: Labels provided with the data, required when using segmentation task APIs. """ - in_shape: Tuple[int, ...] - out_shape: Tuple[int, ...] + in_shape: Optional[Tuple[int, ...]] = None + out_shape: Optional[Tuple[int, ...]] = None labels: Optional[Tuple[LabelEnum, ...]] = None diff --git a/vital/data/data_module.py b/vital/data/data_module.py index 81c11bc8..f59d7df5 100644 --- a/vital/data/data_module.py +++ b/vital/data/data_module.py @@ -1,11 +1,11 @@ import os from abc import ABC from argparse import ArgumentParser -from typing import Dict, Union +from typing import Dict import pytorch_lightning as pl from pytorch_lightning.utilities.argparse import add_argparse_args -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from vital.data.config import DataParameters, Subset @@ -33,25 +33,26 @@ def __init__(self, data_params: DataParameters, batch_size: int, num_workers: in self.data_params = data_params self.batch_size = batch_size self.num_workers = num_workers - self._dataset: Dict[Subset, Dataset] = {} + self.datasets: Dict[Subset, Dataset] = {} self.save_hyperparameters(ignore="data_params") - def dataset(self, subset: Subset = None) -> Union[Dict[Subset, Dataset], Dataset]: - """Returns the subsets of the data (e.g. train) and their torch ``Dataset`` handle. + def _dataloader(self, subset: Subset, shuffle: bool = False) -> DataLoader: + return DataLoader( + self.datasets[subset], + batch_size=self.batch_size, + shuffle=shuffle, + num_workers=self.num_workers, + pin_memory=True, + ) - It should not be called before ``setup``, when the datasets are set. 
+    def train_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.TRAIN, shuffle=True)
 
-        Args:
-            subset: Specific subset for which to get the ``Dataset`` handle.
-
-        Returns:
-            If ``subset`` is provided, returns the handle to a specific dataset. Otherwise, returns the mapping between
-            subsets of the data (e.g. train) and their torch ``Dataset`` handle.
-        """
-        if subset is not None:
-            return self._dataset[subset]
+    def val_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.VAL)
 
-        return self._dataset
+    def test_dataloader(self) -> DataLoader:  # noqa: D102
+        return self._dataloader(Subset.TEST)
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:  # noqa: D102
diff --git a/vital/data/mnist/data_module.py b/vital/data/mnist/data_module.py
index 3671f6e9..ef178211 100644
--- a/vital/data/mnist/data_module.py
+++ b/vital/data/mnist/data_module.py
@@ -2,7 +2,8 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from torch.utils.data import DataLoader, Dataset, random_split
+from pytorch_lightning.trainer.states import TrainerFn
+from torch.utils.data import Dataset, random_split
 from torchvision import transforms as transform_lib
 
 from vital import get_vital_home
@@ -50,15 +51,15 @@ def prepare_data(self):  # noqa: D102
         MNIST(root=self._root, train=False, download=self._download)
 
     def setup(self, stage: Optional[str] = None) -> None:  # noqa: D102
-        if stage == "fit":
+        if stage == TrainerFn.FITTING:
             # Initialize one dataset for train/val split
             transforms = self.default_transforms() if self._transforms is None else self._transforms
             dataset_train = MNIST(root=self._root, transform=transforms, train=True)
             # Split
-            self._dataset[Subset.TRAIN] = self._split_dataset(dataset_train)
-            self._dataset[Subset.VAL] = self._split_dataset(dataset_train, train=False)
-        if stage == "test":
-            self._dataset[Subset.TEST] = MNIST(root=self._root, train=False)
+            self.datasets[Subset.TRAIN] = self._split_dataset(dataset_train)
+            self.datasets[Subset.VAL] = self._split_dataset(dataset_train, train=False)
+        if stage == TrainerFn.TESTING:
+            self.datasets[Subset.TEST] = MNIST(root=self._root, transform=self.default_transforms(), train=False)
 
     def _split_dataset(self, dataset: Dataset, train: bool = True) -> Dataset:
         """Splits the dataset into train and validation set.
@@ -108,17 +109,3 @@ def default_transforms(self) -> Callable:
         else:
             mnist_transforms = transform_lib.Compose([transform_lib.ToTensor()])
         return mnist_transforms
-
-    def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader:
-        return DataLoader(
-            dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=self.num_workers, pin_memory=True
-        )
-
-    def train_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.TRAIN), shuffle=True)
-
-    def val_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.VAL))
-
-    def test_dataloader(self) -> DataLoader:  # noqa: D102
-        return self._data_loader(self.dataset(subset=Subset.TEST))
diff --git a/vital/tasks/autoencoder.py b/vital/tasks/autoencoder.py
index 5240bc9b..f1d9ea9d 100644
--- a/vital/tasks/autoencoder.py
+++ b/vital/tasks/autoencoder.py
@@ -10,11 +10,11 @@
 from vital.data.config import Tags
 from vital.metrics.train.functional import kl_div_zmuv
 from vital.metrics.train.metric import DifferentiableDiceCoefficient
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 
 
-class SegmentationAutoencoderTask(SharedTrainEvalTask):
+class SegmentationAutoencoderTask(SharedStepsTask):
     """Generic segmentation autoencoder training and inference steps.
 
     Implements generic segmentation train/val step and inference, assuming the following conditions:
@@ -104,7 +104,7 @@ def _categorical_to_input(self, x: Tensor) -> Tensor:
         """
         return to_onehot(x, num_classes=len(self.hparams.data_params.labels)).float()
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         # Forward
         out = self.model(self._categorical_to_input(batch[self.hparams.segmentation_data_tag]))
 
@@ -144,7 +144,7 @@ def _compute_loss(self, metrics: Mapping[str, Tensor]) -> Tensor:
 
         Args:
             metrics: Metrics useful for computing the loss (usually a combination of metrics from
-                ``self._shared_train_val_step`` and ``self._compute_latent_space_metrics``).
+                ``self._shared_step`` and ``self._compute_latent_space_metrics``).
 
         Returns:
             Loss for a train/val step.
diff --git a/vital/tasks/classification.py b/vital/tasks/classification.py
index d873f137..76e54c8d 100644
--- a/vital/tasks/classification.py
+++ b/vital/tasks/classification.py
@@ -5,10 +5,10 @@
 from torchmetrics.functional import accuracy
 
 from vital.data.config import Tags
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 
 
-class ClassificationTask(SharedTrainEvalTask):
+class ClassificationTask(SharedStepsTask):
     """Generic classification training and inference steps.
 
     Implements generic classification train/val step and inference, assuming the following conditions:
@@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
     def forward(self, *args, **kwargs):  # noqa: D102
         return self.model(*args, **kwargs)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         x, y = batch[Tags.img], batch[Tags.gt]
 
         # Forward
diff --git a/vital/tasks/generic.py b/vital/tasks/generic.py
index afdec602..d4419c81 100644
--- a/vital/tasks/generic.py
+++ b/vital/tasks/generic.py
@@ -7,18 +7,15 @@
 from vital.utils.format.native import prefix
 
 
-class SharedTrainEvalTask(VitalSystem, ABC):
-    """Abstract task that shares a train/val step.
+class SharedStepsTask(VitalSystem, ABC):
+    """Abstract task that shares a train/val/test step.
 
     Implements useful generic utilities and boilerplate Lighting code:
         - Handling of identical train/val step results (metrics logging and printing)
     """
 
-    def _shared_train_val_step(self, *args, **kwargs) -> Dict[str, Tensor]:
-        """Handles steps for both training and validation loops, assuming the behavior should be the same.
-
-        For models where the behavior in training and validation is different, then override ``training_step`` and
-        ``validation_step`` directly (in which case ``_shared_train_val_step`` doesn't need to be implemented).
+    def _shared_step(self, *args, **kwargs) -> Dict[str, Tensor]:
+        """Handles steps for the train/val/test loops, assuming the behavior should be the same.
 
         Returns:
             Mapping between metric names and their values. It must contain at least a ``'loss'``, as that is the value
@@ -27,13 +24,18 @@ def _shared_train_val_step(self, *args, **kwargs) -> Dict[str, Tensor]:
         raise NotImplementedError
 
     def training_step(self, *args, **kwargs) -> Dict[str, Tensor]:  # noqa: D102
-        result = prefix(self._shared_train_val_step(*args, **kwargs), "train/")
+        result = prefix(self._shared_step(*args, **kwargs), "train/")
         self.log_dict(result, **self.hparams.train_log_kwargs)
         # Add reference to 'train_loss' under 'loss' keyword, requested by PL to know which metric to optimize
         result["loss"] = result["train/loss"]
         return result
 
     def validation_step(self, *args, **kwargs) -> Dict[str, Tensor]:  # noqa: D102
-        result = prefix(self._shared_train_val_step(*args, **kwargs), "val/")
+        result = prefix(self._shared_step(*args, **kwargs), "val/")
+        self.log_dict(result, **self.hparams.val_log_kwargs)
+        return result
+
+    def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:  # noqa: D102
+        result = prefix(self._shared_step(*args, **kwargs), "test/")
         self.log_dict(result, **self.hparams.val_log_kwargs)
         return result
diff --git a/vital/tasks/segmentation.py b/vital/tasks/segmentation.py
index 49bc44a8..8769463f 100644
--- a/vital/tasks/segmentation.py
+++ b/vital/tasks/segmentation.py
@@ -8,12 +8,12 @@
 
 from vital.data.config import Tags
 from vital.metrics.train.metric import DifferentiableDiceCoefficient
-from vital.tasks.generic import SharedTrainEvalTask
+from vital.tasks.generic import SharedStepsTask
 from vital.utils.decorators import auto_move_data
 from vital.utils.image.measure import Measure
 
 
-class SegmentationTask(SharedTrainEvalTask):
+class SegmentationTask(SharedStepsTask):
     """Generic segmentation training and inference steps.
 
     Implements generic segmentation train/val step and inference, assuming the following conditions:
@@ -39,7 +39,7 @@ def __init__(self, ce_weight: float = 0.1, dice_weight: float = 1, *args, **kwar
     def forward(self, *args, **kwargs):  # noqa: D102
         return self.model(*args, **kwargs)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:  # noqa: D102
         x, y = batch[Tags.img], batch[Tags.gt]
 
         # Forward
@@ -112,7 +112,7 @@ def _compute_normalized_bbox(self, y: Tensor) -> Tensor:
             boxes.append(item_box)
         return torch.stack(boxes).to(y.device)
 
-    def _shared_train_val_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
+    def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
         x, y = batch[Tags.img], batch[Tags.gt]
         roi_bbox = self._compute_normalized_bbox(y)  # Compute the target RoI bbox from the groundtruth
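
---

Usage notes (illustrative sketches, not part of the patch):

With the refactored `VitalDataModule`, a downstream datamodule only has to
populate `self.datasets` in a stage-aware `setup`; the inherited
`train_dataloader`/`val_dataloader`/`test_dataloader` methods then build their
`DataLoader`s through the shared `_dataloader` helper. A minimal sketch, where
`DummyDataModule`, its random stand-in data and its parameter values are all
hypothetical:

    from typing import Optional

    import torch
    from pytorch_lightning.trainer.states import TrainerFn
    from torch.utils.data import TensorDataset

    from vital.data.config import DataParameters, Subset
    from vital.data.data_module import VitalDataModule


    class DummyDataModule(VitalDataModule):
        """Toy datamodule that fills ``self.datasets`` with random stand-in data."""

        def __init__(self, **kwargs):
            super().__init__(data_params=DataParameters(in_shape=(1, 28, 28)), batch_size=32, num_workers=0, **kwargs)

        def setup(self, stage: Optional[str] = None) -> None:
            def random_subset(num_items: int) -> TensorDataset:
                # Stand-in for a real dataset; only here to keep the sketch self-contained
                return TensorDataset(torch.randn(num_items, 1, 28, 28), torch.randint(10, (num_items,)))

            # Populating `self.datasets` is all that is needed: the inherited
            # dataloader methods index this mapping by subset
            if stage == TrainerFn.FITTING:
                self.datasets[Subset.TRAIN] = random_subset(512)
            if stage in [TrainerFn.FITTING, TrainerFn.VALIDATING]:
                self.datasets[Subset.VAL] = random_subset(64)
            if stage == TrainerFn.TESTING:
                self.datasets[Subset.TEST] = random_subset(64)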
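
On the task side, with `SharedTrainEvalTask` renamed to `SharedStepsTask`,
implementing `_shared_step` once now yields `training_step`, `validation_step`
and `test_step`, logged under the "train/", "val/" and "test/" prefixes
respectively (the test step reuses `val_log_kwargs`). A minimal sketch, where
`DummyClassificationTask` is hypothetical and the system is assumed to be
configured with a model the same way as the repo's other tasks:

    from typing import Dict

    import torch.nn.functional as F
    from torch import Tensor

    from vital.data.config import Tags
    from vital.tasks.generic import SharedStepsTask


    class DummyClassificationTask(SharedStepsTask):
        def forward(self, *args, **kwargs):
            # Delegate to the model configured on the system
            return self.model(*args, **kwargs)

        def _shared_step(self, batch: Dict[str, Tensor], batch_idx: int) -> Dict[str, Tensor]:
            x, y = batch[Tags.img], batch[Tags.gt]
            y_hat = self(x)
            loss = F.cross_entropy(y_hat, y)
            # The returned mapping must contain at least 'loss'; it is logged
            # as-is by the generic train/val/test steps
            return {"loss": loss}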
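
Finally, the `test: True` / `predict: False` flags added to mnist-mlp.yaml make
the split between the two loops explicit: `trainer.test` runs the standardized,
batched `test_step`, while `trainer.predict` is reserved for datamodules that
define a `predict_dataloader` (e.g. CAMUS, which serves full patients with
`batch_size=None`) paired with a prediction writer. A sketch of how such flags
could drive a runner, where the `run` helper and its arguments are hypothetical:

    import pytorch_lightning as pl


    def run(task: pl.LightningModule, datamodule: pl.LightningDataModule, test: bool, predict: bool) -> None:
        trainer = pl.Trainer(max_epochs=300)
        trainer.fit(task, datamodule=datamodule)
        if test:  # e.g. `test: True` for the MNIST MLP experiment
            trainer.test(task, datamodule=datamodule)
        if predict:  # e.g. `predict: False`, since MNIST/ACDC lack a prediction writer
            trainer.predict(task, datamodule=datamodule)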