Skip to content

Commit

Permalink
2220 Decollate batch into list after iteration computation (29/June) (P…
Browse files Browse the repository at this point in the history
…roject-MONAI#2315)

* [WIP] 2220 Decollate batch into list of tensors after model forward (Project-MONAI#2244)

* [DLMED] add support for Activation transform

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] change all the array level post transforms to channel-first

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update all the IO, utility, post transforms

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update engines for list of dict

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update all the event handlers

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] support non-batch data

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update based on the latest APIs

Signed-off-by: Nic Ma <[email protected]>

Co-authored-by: monai-bot <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix all the unit tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix unit tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issues

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] remove unused import

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix integration test

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix integration tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix integration tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add support to copy scalar

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix wrong unit tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix doc issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix broken tests

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix unit tests

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] simplify CSV saver

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add copy_scalar_to_batch util

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] change to preprocessing and postprocessing

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] change file name to postprocessing

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix typo in doc-string

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] add Decollated back

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update to use ignite v0.4.5 metrics API

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix conflicts

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* fixes Project-MONAI#2452

Signed-off-by: Wenqi Li <[email protected]>

Co-authored-by: monai-bot <[email protected]>
  • Loading branch information
Nic-Ma and monai-bot authored Jun 29, 2021
1 parent f831c7f commit 2085a49
Show file tree
Hide file tree
Showing 67 changed files with 966 additions and 1,293 deletions.
File renamed without changes
4 changes: 2 additions & 2 deletions docs/source/highlights.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ MONAI also provides post-processing transforms for handling the model outputs. C
- Removing segmentation noise based on Connected Component Analysis, as below figure (c).
- Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e).

After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms.
![post-processing transforms](../images/post_transforms.png)
After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing.
![post-processing transforms](../images/postprocessing_transforms.png)

### 9. Integrate third-party transforms
The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`.
Expand Down
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ Generic Interfaces
.. autoclass:: BatchInverseTransform
:members:

`Decollated`
^^^^^^^^^^^^
.. autoclass:: Decollated
:members:


Vanilla Transforms
------------------
Expand Down
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
partition_dataset_classes,
pickle_hashing,
rectify_header_sform_qform,
rep_scalar_to_batch,
select_cross_validation_folds,
set_rnd,
sorted_dict,
Expand Down
5 changes: 2 additions & 3 deletions monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
"""
save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
data_: np.ndarray
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
if not isinstance(data, np.ndarray):
raise AssertionError
self._cache_dict[save_key] = data.astype(np.float32)
self._cache_dict[save_key] = np.asarray(data, dtype=float)

def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
"""Save a batch of data into the cache dictionary.
Expand Down
49 changes: 49 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pickle
import warnings
from collections import defaultdict
from copy import deepcopy
from functools import reduce
from itertools import product, starmap
from pathlib import PurePath
Expand All @@ -35,6 +36,7 @@
ensure_tuple_size,
fall_back_tuple,
first,
issequenceiterable,
optional_import,
)
from monai.utils.enums import Method
Expand Down Expand Up @@ -68,6 +70,7 @@
"pickle_hashing",
"sorted_dict",
"decollate_batch",
"rep_scalar_to_batch",
"pad_list_data_collate",
"no_collation",
"convert_tables_to_dicts",
Expand Down Expand Up @@ -367,6 +370,52 @@ def decollate_batch(batch, detach: bool = True):
raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.")


def rep_scalar_to_batch(batch_data: Union[List, Dict]) -> Union[List, Dict]:
"""
Utility tp replicate the scalar items of a list or dictionary to ensure all the items have batch dimension.
It leverages `decollate_batch(detach=False)` to filter out the scalar items.
"""

def _detect_batch_size(batch_data: Sequence):
"""
Detect the batch size from a list of data, some items in the list have batch dim, some not.
"""
for v in batch_data:
if isinstance(v, torch.Tensor) and v.ndim > 0:
return v.shape[0]
for v in batch_data:
if issequenceiterable(v):
warnings.warn("batch_data doesn't contain batched Tensor data, use the length of first sequence data.")
return len(v)
raise RuntimeError("failed to automatically detect the batch size.")

if isinstance(batch_data, dict):
batch_size = _detect_batch_size(list(batch_data.values()))
dict_batch = {}
for k, v in batch_data.items():
if decollate_batch(v, detach=False) == v and not isinstance(v, list):
# if decollating a list, the result may be the same list, so should skip this case
dict_batch[k] = [deepcopy(decollate_batch(v, detach=True)) for _ in range(batch_size)]
else:
dict_batch[k] = v

return dict_batch
elif isinstance(batch_data, list):
batch_size = _detect_batch_size(batch_data)
list_batch = []
for b in batch_data:
if decollate_batch(b, detach=False) == b and not isinstance(b, list):
list_batch.append([deepcopy(decollate_batch(b, detach=True)) for _ in range(batch_size)])
else:
list_batch.append(b)

return list_batch
# if not dict or list, just return the original data
return batch_data


def pad_list_data_collate(
batch: Sequence,
method: Union[Method, str] = Method.SYMMETRIC,
Expand Down
35 changes: 26 additions & 9 deletions monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Evaluator(Workflow):
prepare_batch: function to parse image and label for current iteration.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
post_transform: execute additional transformation for the model output data.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric: compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
Expand All @@ -62,6 +62,9 @@ class Evaluator(Workflow):
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160
decollate: whether to decollate the batch-first data to a list of data after model computation,
default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module
assumes channel-first data.
"""

Expand All @@ -73,14 +76,15 @@ def __init__(
non_blocking: bool = False,
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
post_transform: Optional[Transform] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
) -> None:
super().__init__(
device=device,
Expand All @@ -90,13 +94,14 @@ def __init__(
non_blocking=non_blocking,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
postprocessing=postprocessing,
key_metric=key_val_metric,
additional_metrics=additional_metrics,
handlers=val_handlers,
amp=amp,
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
)
mode = ForwardMode(mode)
if mode == ForwardMode.EVAL:
Expand Down Expand Up @@ -139,7 +144,7 @@ class SupervisedEvaluator(Evaluator):
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
post_transform: execute additional transformation for the model output data.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric: compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
Expand All @@ -154,6 +159,9 @@ class SupervisedEvaluator(Evaluator):
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160
decollate: whether to decollate the batch-first data to a list of data after model computation,
default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module
assumes channel-first data.
"""

Expand All @@ -167,14 +175,15 @@ def __init__(
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
inferer: Optional[Inferer] = None,
post_transform: Optional[Transform] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
) -> None:
super().__init__(
device=device,
Expand All @@ -183,14 +192,15 @@ def __init__(
non_blocking=non_blocking,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
postprocessing=postprocessing,
key_val_metric=key_val_metric,
additional_metrics=additional_metrics,
val_handlers=val_handlers,
amp=amp,
mode=mode,
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
)

self.network = network
Expand Down Expand Up @@ -224,6 +234,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

# execute forward computation
with self.mode(self.network):
if self.amp:
Expand Down Expand Up @@ -255,7 +266,7 @@ class EnsembleEvaluator(Evaluator):
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
post_transform: execute additional transformation for the model output data.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_val_metric: compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
Expand All @@ -270,6 +281,9 @@ class EnsembleEvaluator(Evaluator):
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160
decollate: whether to decollate the batch-first data to a list of data after model computation,
default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module
assumes channel-first data.
"""

Expand All @@ -284,14 +298,15 @@ def __init__(
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
inferer: Optional[Inferer] = None,
post_transform: Optional[Transform] = None,
postprocessing: Optional[Transform] = None,
key_val_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
val_handlers: Optional[Sequence] = None,
amp: bool = False,
mode: Union[ForwardMode, str] = ForwardMode.EVAL,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
) -> None:
super().__init__(
device=device,
Expand All @@ -300,14 +315,15 @@ def __init__(
non_blocking=non_blocking,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
postprocessing=postprocessing,
key_val_metric=key_val_metric,
additional_metrics=additional_metrics,
val_handlers=val_handlers,
amp=amp,
mode=mode,
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
)

self.networks = ensure_tuple(networks)
Expand Down Expand Up @@ -345,6 +361,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

for idx, network in enumerate(self.networks):
with self.mode(network):
if self.amp:
Expand Down
22 changes: 16 additions & 6 deletions monai/engines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class SupervisedTrainer(Trainer):
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
post_transform: execute additional transformation for the model output data.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_train_metric: compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
Expand All @@ -86,6 +86,9 @@ class SupervisedTrainer(Trainer):
new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160
decollate: whether to decollate the batch-first data to a list of data after model computation,
default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module
assumes channel-first data.
"""

Expand All @@ -102,13 +105,14 @@ def __init__(
prepare_batch: Callable = default_prepare_batch,
iteration_update: Optional[Callable] = None,
inferer: Optional[Inferer] = None,
post_transform: Optional[Transform] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
train_handlers: Optional[Sequence] = None,
amp: bool = False,
event_names: Optional[List[Union[str, EventEnum]]] = None,
event_to_attr: Optional[dict] = None,
decollate: bool = True,
) -> None:
super().__init__(
device=device,
Expand All @@ -118,13 +122,14 @@ def __init__(
non_blocking=non_blocking,
prepare_batch=prepare_batch,
iteration_update=iteration_update,
post_transform=post_transform,
postprocessing=postprocessing,
key_metric=key_train_metric,
additional_metrics=additional_metrics,
handlers=train_handlers,
amp=amp,
event_names=event_names,
event_to_attr=event_to_attr,
decollate=decollate,
)

self.network = network
Expand Down Expand Up @@ -221,14 +226,17 @@ class GanTrainer(Trainer):
g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``.
iteration_update: the callable function for every iteration, expect to accept `engine`
and `batchdata` as input parameters. if not provided, use `self._iteration()` instead.
post_transform: execute additional transformation for the model output data.
postprocessing: execute additional transformation for the model output data.
Typically, several Tensor based transforms composed by `Compose`.
key_train_metric: compute metric when every iteration completed, and save average value to
engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the
checkpoint into files.
additional_metrics: more Ignite metrics that also attach to Ignite Engine.
train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
CheckpointHandler, StatsHandler, SegmentationSaver, etc.
decollate: whether to decollate the batch-first data to a list of data after model computation,
default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module
assumes channel-first data.
"""

Expand All @@ -253,10 +261,11 @@ def __init__(
g_prepare_batch: Callable = default_make_latent,
g_update_latents: bool = True,
iteration_update: Optional[Callable] = None,
post_transform: Optional[Transform] = None,
postprocessing: Optional[Transform] = None,
key_train_metric: Optional[Dict[str, Metric]] = None,
additional_metrics: Optional[Dict[str, Metric]] = None,
train_handlers: Optional[Sequence] = None,
decollate: bool = True,
):
if not isinstance(train_data_loader, DataLoader):
raise ValueError("train_data_loader must be PyTorch DataLoader.")
Expand All @@ -273,7 +282,8 @@ def __init__(
key_metric=key_train_metric,
additional_metrics=additional_metrics,
handlers=train_handlers,
post_transform=post_transform,
postprocessing=postprocessing,
decollate=decollate,
)
self.g_network = g_network
self.g_optimizer = g_optimizer
Expand Down
2 changes: 1 addition & 1 deletion monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def default_make_latent(

def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]):
"""
Apply transform for the engine.state.batch and engine.state.output.
Apply transform on `batch` and `output`.
If `batch` and `output` are dictionaries, temporarily combine them for the transform,
otherwise, apply the transform for `output` data only.
Expand Down
Loading

0 comments on commit 2085a49

Please sign in to comment.