diff --git a/docs/images/post_transforms.png b/docs/images/postprocessing_transforms.png similarity index 100% rename from docs/images/post_transforms.png rename to docs/images/postprocessing_transforms.png diff --git a/docs/source/highlights.md b/docs/source/highlights.md index 3b1160a6e5..c86eabb0ab 100644 --- a/docs/source/highlights.md +++ b/docs/source/highlights.md @@ -112,8 +112,8 @@ MONAI also provides post-processing transforms for handling the model outputs. C - Removing segmentation noise based on Connected Component Analysis, as below figure (c). - Extracting contour of segmentation result, which can be used to map to original image and evaluate the model, as below figure (d) and (e). -After applying the post-processing transforms, it's easier to compute metrics, save model output into files or visualize data in the TensorBoard. [Post transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/post_transforms.ipynb) shows an example with several main post transforms. -![post-processing transforms](../images/post_transforms.png) +After applying the post-processing transforms, it's easier to compute metrics, save model output into files, or visualize data in TensorBoard. [Postprocessing transforms tutorial](https://github.com/Project-MONAI/tutorials/blob/master/modules/postprocessing_transforms.ipynb) shows an example with several main transforms for post-processing. +![post-processing transforms](../images/postprocessing_transforms.png) ### 9. Integrate third-party transforms The design of MONAI transforms emphasis code readability and usability. It works for array data or dictionary-based data. MONAI also provides `Adaptor` tools to accommodate different data format for 3rd party transforms. To convert the data shapes or types, utility transforms such as `ToTensor`, `ToNumpy`, `SqueezeDim` are also provided. So it's easy to enhance the transform chain by seamlessly integrating transforms from external packages, including: `ITK`, `BatchGenerator`, `TorchIO` and `Rising`. diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index fb02627635..0fcdb5f1ab 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -48,6 +48,11 @@ Generic Interfaces .. autoclass:: BatchInverseTransform :members: +`Decollated` +^^^^^^^^^^^^ +.. autoclass:: Decollated + :members: + Vanilla Transforms ------------------ diff --git a/monai/data/__init__.py b/monai/data/__init__.py index a82f80213a..af42627f5f 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -56,6 +56,7 @@ partition_dataset_classes, pickle_hashing, rectify_header_sform_qform, + rep_scalar_to_batch, select_cross_validation_folds, set_rnd, sorted_dict, diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index 3701c094cd..ab2ae1d568 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -87,11 +87,9 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] """ save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() - if not isinstance(data, np.ndarray): - raise AssertionError - self._cache_dict[save_key] = data.astype(np.float32) + self._cache_dict[save_key] = np.asarray(data, dtype=float) def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: """Save a batch of data into the cache dictionary.
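A quick sketch of how the revised `CSVSaver.save` path behaves after this change — a tensor input is detached, moved to CPU, and cached via `np.asarray` instead of being asserted to already be an ndarray. The output directory, file name, and prediction values below are made up for illustration:

```python
import torch

from monai.data import CSVSaver

saver = CSVSaver(output_dir="./out", filename="predictions.csv")

# a hypothetical per-item classification output (channel-first, no batch dim)
pred = torch.tensor([0.1, 0.9])

# torch.Tensor inputs are detached, moved to CPU, and cached as a float ndarray;
# the meta data supplies the row key for the CSV report
saver.save(pred, meta_data={"filename_or_obj": "case_001.nii.gz"})
saver.finalize()  # writes the cached rows to ./out/predictions.csv
```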
diff --git a/monai/data/utils.py b/monai/data/utils.py index b6b8466c32..601a5b0c92 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -16,6 +16,7 @@ import pickle import warnings from collections import defaultdict +from copy import deepcopy from functools import reduce from itertools import product, starmap from pathlib import PurePath @@ -35,6 +36,7 @@ ensure_tuple_size, fall_back_tuple, first, + issequenceiterable, optional_import, ) from monai.utils.enums import Method @@ -68,6 +70,7 @@ "pickle_hashing", "sorted_dict", "decollate_batch", + "rep_scalar_to_batch", "pad_list_data_collate", "no_collation", "convert_tables_to_dicts", @@ -367,6 +370,52 @@ def decollate_batch(batch, detach: bool = True): raise NotImplementedError(f"Unable to de-collate: {batch}, type: {type(batch)}.") +def rep_scalar_to_batch(batch_data: Union[List, Dict]) -> Union[List, Dict]: + """ + Utility to replicate the scalar items of a list or dictionary to ensure all the items have a batch dimension. + It leverages `decollate_batch(detach=False)` to filter out the scalar items. + + """ + + def _detect_batch_size(batch_data: Sequence): + """ + Detect the batch size from a list of data, where some items have a batch dim and some don't. + + """ + for v in batch_data: + if isinstance(v, torch.Tensor) and v.ndim > 0: + return v.shape[0] + for v in batch_data: + if issequenceiterable(v): + warnings.warn("batch_data doesn't contain batched Tensor data, using the length of the first sequence instead.") + return len(v) + raise RuntimeError("failed to automatically detect the batch size.") + + if isinstance(batch_data, dict): + batch_size = _detect_batch_size(list(batch_data.values())) + dict_batch = {} + for k, v in batch_data.items(): + if decollate_batch(v, detach=False) == v and not isinstance(v, list): + # if decollating a list, the result may be the same list, so should skip this case + dict_batch[k] = [deepcopy(decollate_batch(v, detach=True)) for _ in range(batch_size)] + else: + dict_batch[k] = v + + return dict_batch + elif isinstance(batch_data, list): + batch_size = _detect_batch_size(batch_data) + list_batch = [] + for b in batch_data: + if decollate_batch(b, detach=False) == b and not isinstance(b, list): + list_batch.append([deepcopy(decollate_batch(b, detach=True)) for _ in range(batch_size)]) + else: + list_batch.append(b) + + return list_batch + # if not dict or list, just return the original data + return batch_data + + def pad_list_data_collate( batch: Sequence, method: Union[Method, str] = Method.SYMMETRIC, diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py index 5a2000022a..6da75cb951 100644 --- a/monai/engines/evaluator.py +++ b/monai/engines/evaluator.py @@ -47,7 +47,7 @@ class Evaluator(Workflow): prepare_batch: function to parse image and label for current iteration. iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the @@ -62,6 +62,9 @@ class Evaluator(Workflow): new events can be a list of str or `ignite.engine.events.EventEnum`.
event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. """ @@ -73,7 +76,7 @@ def __init__( non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, iteration_update: Optional[Callable] = None, - post_transform: Optional[Transform] = None, + postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, val_handlers: Optional[Sequence] = None, @@ -81,6 +84,7 @@ def __init__( mode: Union[ForwardMode, str] = ForwardMode.EVAL, event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, + decollate: bool = True, ) -> None: super().__init__( device=device, @@ -90,13 +94,14 @@ def __init__( non_blocking=non_blocking, prepare_batch=prepare_batch, iteration_update=iteration_update, - post_transform=post_transform, + postprocessing=postprocessing, key_metric=key_val_metric, additional_metrics=additional_metrics, handlers=val_handlers, amp=amp, event_names=event_names, event_to_attr=event_to_attr, + decollate=decollate, ) mode = ForwardMode(mode) if mode == ForwardMode.EVAL: @@ -139,7 +144,7 @@ class SupervisedEvaluator(Evaluator): iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the @@ -154,6 +159,9 @@ class SupervisedEvaluator(Evaluator): new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. 
""" @@ -167,7 +175,7 @@ def __init__( prepare_batch: Callable = default_prepare_batch, iteration_update: Optional[Callable] = None, inferer: Optional[Inferer] = None, - post_transform: Optional[Transform] = None, + postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, val_handlers: Optional[Sequence] = None, @@ -175,6 +183,7 @@ def __init__( mode: Union[ForwardMode, str] = ForwardMode.EVAL, event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, + decollate: bool = True, ) -> None: super().__init__( device=device, @@ -183,7 +192,7 @@ def __init__( non_blocking=non_blocking, prepare_batch=prepare_batch, iteration_update=iteration_update, - post_transform=post_transform, + postprocessing=postprocessing, key_val_metric=key_val_metric, additional_metrics=additional_metrics, val_handlers=val_handlers, @@ -191,6 +200,7 @@ def __init__( mode=mode, event_names=event_names, event_to_attr=event_to_attr, + decollate=decollate, ) self.network = network @@ -224,6 +234,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + # execute forward computation with self.mode(self.network): if self.amp: @@ -255,7 +266,7 @@ class EnsembleEvaluator(Evaluator): iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_val_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the @@ -270,6 +281,9 @@ class EnsembleEvaluator(Evaluator): new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. 
""" @@ -284,7 +298,7 @@ def __init__( prepare_batch: Callable = default_prepare_batch, iteration_update: Optional[Callable] = None, inferer: Optional[Inferer] = None, - post_transform: Optional[Transform] = None, + postprocessing: Optional[Transform] = None, key_val_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, val_handlers: Optional[Sequence] = None, @@ -292,6 +306,7 @@ def __init__( mode: Union[ForwardMode, str] = ForwardMode.EVAL, event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, + decollate: bool = True, ) -> None: super().__init__( device=device, @@ -300,7 +315,7 @@ def __init__( non_blocking=non_blocking, prepare_batch=prepare_batch, iteration_update=iteration_update, - post_transform=post_transform, + postprocessing=postprocessing, key_val_metric=key_val_metric, additional_metrics=additional_metrics, val_handlers=val_handlers, @@ -308,6 +323,7 @@ def __init__( mode=mode, event_names=event_names, event_to_attr=event_to_attr, + decollate=decollate, ) self.networks = ensure_tuple(networks) @@ -345,6 +361,7 @@ def _iteration(self, engine: Engine, batchdata: Dict[str, torch.Tensor]) -> Dict # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets} + for idx, network in enumerate(self.networks): with self.mode(network): if self.amp: diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py index cce5028d54..d8b4ec9a26 100644 --- a/monai/engines/trainer.py +++ b/monai/engines/trainer.py @@ -73,7 +73,7 @@ class SupervisedTrainer(Trainer): iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the @@ -86,6 +86,9 @@ class SupervisedTrainer(Trainer): new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. 
""" @@ -102,13 +105,14 @@ def __init__( prepare_batch: Callable = default_prepare_batch, iteration_update: Optional[Callable] = None, inferer: Optional[Inferer] = None, - post_transform: Optional[Transform] = None, + postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, train_handlers: Optional[Sequence] = None, amp: bool = False, event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, + decollate: bool = True, ) -> None: super().__init__( device=device, @@ -118,13 +122,14 @@ def __init__( non_blocking=non_blocking, prepare_batch=prepare_batch, iteration_update=iteration_update, - post_transform=post_transform, + postprocessing=postprocessing, key_metric=key_train_metric, additional_metrics=additional_metrics, handlers=train_handlers, amp=amp, event_names=event_names, event_to_attr=event_to_attr, + decollate=decollate, ) self.network = network @@ -221,7 +226,7 @@ class GanTrainer(Trainer): g_update_latents: Calculate G loss with new latent codes. Defaults to ``True``. iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_train_metric is the main metric to compare and save the @@ -229,6 +234,9 @@ class GanTrainer(Trainer): additional_metrics: more Ignite metrics that also attach to Ignite Engine. train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: CheckpointHandler, StatsHandler, SegmentationSaver, etc. + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. """ @@ -253,10 +261,11 @@ def __init__( g_prepare_batch: Callable = default_make_latent, g_update_latents: bool = True, iteration_update: Optional[Callable] = None, - post_transform: Optional[Transform] = None, + postprocessing: Optional[Transform] = None, key_train_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, train_handlers: Optional[Sequence] = None, + decollate: bool = True, ): if not isinstance(train_data_loader, DataLoader): raise ValueError("train_data_loader must be PyTorch DataLoader.") @@ -273,7 +282,8 @@ def __init__( key_metric=key_train_metric, additional_metrics=additional_metrics, handlers=train_handlers, - post_transform=post_transform, + postprocessing=postprocessing, + decollate=decollate, ) self.g_network = g_network self.g_optimizer = g_optimizer diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 6c19b74a9a..26038ab8a5 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -134,7 +134,7 @@ def default_make_latent( def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]): """ - Apply transform for the engine.state.batch and engine.state.output. + Apply transform on `batch` and `output`. 
If `batch` and `output` are dictionaries, temporarily combine them for the transform, otherwise, apply the transform for `output` data only. diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 4cd70ee214..f39c720bcc 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Sequence, Union import torch @@ -17,6 +18,7 @@ from torch.utils.data.distributed import DistributedSampler from monai.config import IgniteInfo +from monai.data import decollate_batch, rep_scalar_to_batch from monai.engines.utils import IterationEvents, default_prepare_batch from monai.utils import ensure_tuple, min_version, optional_import @@ -55,7 +57,7 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona prepare_batch: function to parse image and label for every iteration. iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input parameters. if not provided, use `self._iteration()` instead. - post_transform: execute additional transformation for the model output data. + postprocessing: execute additional transformation for the model output data. Typically, several Tensor based transforms composed by `Compose`. key_metric: compute metric when every iteration completed, and save average value to engine.state.metrics when epoch completed. key_metric is the main metric to compare and save the @@ -68,6 +70,9 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona new events can be a list of str or `ignite.engine.events.EventEnum`. event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. for more details, check: https://github.com/pytorch/ignite/blob/v0.4.4.post1/ignite/engine/engine.py#L160 + decollate: whether to decollate the batch-first data to a list of data after model computation, + default to `True`. if `False`, postprocessing will be ignored as the `monai.transforms` module + assumes channel-first data. Raises: TypeError: When ``device`` is not a ``torch.Device``. 
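To make the new `decollate=True` behavior concrete, here is a rough sketch of what the engine now does at `MODEL_COMPLETED`: scalars such as the loss are replicated to the detected batch size, then the batch-first dict is split into a list of per-item dicts. The shapes and values below are made up:

```python
import torch

from monai.data import decollate_batch, rep_scalar_to_batch

# a collated iteration output: batch-first tensors plus a scalar loss
output = {
    "pred": torch.rand(2, 1, 8, 8),
    "label": torch.randint(0, 2, (2, 1, 8, 8)),
    "loss": 0.85,  # scalar, no batch dimension
}

# the scalar loss is replicated to batch size 2, then the dict is split
# into a list of 2 per-item dicts holding channel-first tensors
decollated = decollate_batch(rep_scalar_to_batch(output), detach=True)
print(len(decollated), decollated[0]["pred"].shape, decollated[0]["loss"])
# 2 torch.Size([1, 8, 8]) 0.85
```

Each item in the resulting list is channel-first with no batch dimension, which is why the postprocessing transforms below can drop their batched-data code paths.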
@@ -86,13 +91,14 @@ def __init__( non_blocking: bool = False, prepare_batch: Callable = default_prepare_batch, iteration_update: Optional[Callable] = None, - post_transform: Optional[Callable] = None, + postprocessing: Optional[Callable] = None, key_metric: Optional[Dict[str, Metric]] = None, additional_metrics: Optional[Dict[str, Metric]] = None, handlers: Optional[Sequence] = None, amp: bool = False, event_names: Optional[List[Union[str, EventEnum]]] = None, event_to_attr: Optional[dict] = None, + decollate: bool = True, ) -> None: if iteration_update is not None: super().__init__(iteration_update) @@ -138,7 +144,12 @@ def set_sampler_epoch(engine: Engine): self.prepare_batch = prepare_batch self.amp = amp - event_names = [IterationEvents] if event_names is None else event_names + [IterationEvents] + if event_names is None: + event_names = [IterationEvents] + else: + if not isinstance(event_names, list): + raise ValueError("event_names must be a list or string or EventEnum.") + event_names += [IterationEvents] for name in event_names: if isinstance(name, str): self.register_events(name, event_to_attr=event_to_attr) @@ -147,26 +158,41 @@ def set_sampler_epoch(engine: Engine): else: raise ValueError("event_names must be a list or string or EventEnum.") - if post_transform is not None: - self._register_post_transforms(post_transform) + if decollate: + self._register_decollate() + # postprocessing can only work if `decollate=True` + if postprocessing is not None: + self._register_postprocessing(postprocessing) if key_metric is not None: self._register_metrics(key_metric, additional_metrics) if handlers is not None: self._register_handlers(handlers) - def _register_post_transforms(self, posttrans: Callable): + def _register_decollate(self): """ - Register the post transforms to the engine, will execute them as a chain when iteration completed. + Register the decollate operation for batch data; it will execute after the model forward and loss computation. """ @self.on(IterationEvents.MODEL_COMPLETED) - def run_post_transform(engine: Engine) -> None: - engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=posttrans, - ) + def _decollate_data(engine: Engine) -> None: + # replicate the scalar values to make sure all the items have a batch dimension, then decollate + engine.state.batch = decollate_batch(rep_scalar_to_batch(engine.state.batch), detach=True) + engine.state.output = decollate_batch(rep_scalar_to_batch(engine.state.output), detach=True) + + def _register_postprocessing(self, posttrans: Callable): + """ + Register the postprocessing logic to the engine, to be executed as a chain when every iteration completes.
+ + """ + + @self.on(IterationEvents.MODEL_COMPLETED) + def _run_postprocessing(engine: Engine) -> None: + if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): + warnings.warn("postprocessing requires `engine.state.batch` and `engine.state.outout` to be lists.") + else: + for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): + engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, posttrans) def _register_metrics(self, k_metric: Dict, add_metrics: Optional[Dict] = None): """ diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 937ee3f9b4..39d75064c2 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -22,7 +22,7 @@ from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver from .parameter_scheduler import ParamSchedulerHandler -from .post_processing import PostProcessing +from .postprocessing import PostProcessing from .regression_metrics import MeanAbsoluteError, MeanSquaredError, PeakSignalToNoiseRatio, RootMeanSquaredError from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 0fb6027ca5..96df3c1523 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -16,15 +16,9 @@ import torch from monai.config import IgniteInfo -from monai.data import CSVSaver +from monai.data import CSVSaver, decollate_batch from monai.utils import ImageMetaKey as Key -from monai.utils import ( - evenly_divisible_all_gather, - issequenceiterable, - min_version, - optional_import, - string_list_all_gather, -) +from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") @@ -109,14 +103,16 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. 
""" - filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ) - if issequenceiterable(filenames): - self._filenames.extend(filenames) - outputs = self.output_transform(engine.state.output) - if outputs is not None: - if isinstance(outputs, torch.Tensor): - outputs = outputs.detach() - self._outputs.append(outputs) + meta_data = self.batch_transform(engine.state.batch) + if isinstance(meta_data, dict): + # decollate the `dictionary of list` to `list of dictionaries` + meta_data = decollate_batch(meta_data) + engine_output = self.output_transform(engine.state.output) + for m, o in zip(meta_data, engine_output): + self._filenames.append(f"{m.get(Key.FILENAME_OR_OBJ)}") + if isinstance(o, torch.Tensor): + o = o.detach() + self._outputs.append(o) def _finalize(self, engine: Engine) -> None: """ @@ -129,7 +125,7 @@ def _finalize(self, engine: Engine) -> None: if self.save_rank >= ws: raise ValueError("target save rank is greater than the distributed group size.") - outputs = torch.cat(self._outputs, dim=0) + outputs = torch.stack(self._outputs, dim=0) filenames = self._filenames if ws > 1: outputs = evenly_divisible_all_gather(outputs, concat=True) diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py index 854bb609eb..64553955b7 100644 --- a/monai/handlers/metric_logger.py +++ b/monai/handlers/metric_logger.py @@ -26,7 +26,7 @@ def _get_loss_from_output(output, loss_key: str = CommonKeys.LOSS): - return output[loss_key].item() + return output[0][loss_key] class MetricLoggerKeys(Enum): diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index fc895bee75..acfd2eb94e 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -12,9 +12,10 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union from monai.config import IgniteInfo +from monai.data import decollate_batch from monai.handlers.utils import write_metrics_reports from monai.utils import ImageMetaKey as Key -from monai.utils import ensure_tuple, issequenceiterable, min_version, optional_import, string_list_all_gather +from monai.utils import ensure_tuple, min_version, optional_import, string_list_all_gather Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed") @@ -106,9 +107,12 @@ def _started(self, engine: Engine) -> None: def _get_filenames(self, engine: Engine) -> None: if self.metric_details is not None: - filenames = self.batch_transform(engine.state.batch).get(Key.FILENAME_OR_OBJ) - if issequenceiterable(filenames): - self._filenames.extend(filenames) + meta_data = self.batch_transform(engine.state.batch) + if isinstance(meta_data, dict): + # decollate the `dictionary of list` to `list of dictionaries` + meta_data = decollate_batch(meta_data) + for m in meta_data: + self._filenames.append(f"{m.get(Key.FILENAME_OR_OBJ)}") def __call__(self, engine: Engine) -> None: """ diff --git a/monai/handlers/post_processing.py b/monai/handlers/postprocessing.py similarity index 74% rename from monai/handlers/post_processing.py rename to monai/handlers/postprocessing.py index a0fe0a041e..8732c4ad80 100644 --- a/monai/handlers/post_processing.py +++ b/monai/handlers/postprocessing.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from typing import TYPE_CHECKING, Callable from monai.config import IgniteInfo @@ -24,8 +25,8 @@ class PostProcessing: """ - Ignite handler to execute additional post processing after the post transforms in engines. - So users can insert other handlers between post transforms and this post processing handler. + Ignite handler to execute additional postprocessing after the engine's built-in postprocessing. + So users can insert other handlers between the engine postprocessing and this handler. """ @@ -33,7 +34,7 @@ def __init__(self, transform: Callable) -> None: """ Args: transform: callable function to execute on the `engine.state.batch` and `engine.state.output`. - can also be composed post transforms. + can also be composed transforms. """ self.transform = transform @@ -50,8 +51,8 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - engine.state.batch, engine.state.output = engine_apply_transform( - batch=engine.state.batch, - output=engine.state.output, - transform=self.transform, - ) + if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): + warnings.warn("postprocessing requires `engine.state.batch` and `engine.state.output` to be lists.") + else: + for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): + engine.state.batch[i], engine.state.output[i] = engine_apply_transform(b, o, self.transform) diff --git a/monai/handlers/segmentation_saver.py b/monai/handlers/segmentation_saver.py index c61cdd710e..57b773d73a 100644 --- a/monai/handlers/segmentation_saver.py +++ b/monai/handlers/segmentation_saver.py @@ -15,6 +15,7 @@ import numpy as np from monai.config import DtypeLike, IgniteInfo +from monai.data import decollate_batch from monai.transforms import SaveImage from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, min_version, optional_import @@ -120,7 +121,6 @@ def __init__( scale=scale, dtype=dtype, output_dtype=output_dtype, - save_batch=True, squeeze_end_dims=squeeze_end_dims, data_root_dir=data_root_dir, ) @@ -149,6 +149,10 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ meta_data = self.batch_transform(engine.state.batch) + if isinstance(meta_data, dict): + # decollate the `dictionary of lists` to a `list of dictionaries` + meta_data = decollate_batch(meta_data) engine_output = self.output_transform(engine.state.output) - self._saver(engine_output, meta_data) + for m, o in zip(meta_data, engine_output): + self._saver(o, m) self.logger.info("model outputs saved into files.") diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py index cf73f1ca63..99ed91c714 100644 --- a/monai/handlers/stats_handler.py +++ b/monai/handlers/stats_handler.py @@ -45,7 +45,7 @@ def __init__( self, epoch_print_logger: Optional[Callable[[Engine], Any]] = None, iteration_print_logger: Optional[Callable[[Engine], Any]] = None, - output_transform: Callable = lambda x: x, + output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, name: Optional[str] = None, tag_name: str = DEFAULT_TAG, @@ -63,6 +63,8 @@ def __init__( ``ignite.engine.state.output`` into a scalar to print, or a dictionary of {key: scalar}. In the latter case, the output string will be formatted as key: value. By default this value logging happens when every iteration completed.
+ The default behavior is to print the loss from output[0], as the output is a decollated list + and the loss value is replicated for every item of that list. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to print synced epoch number with the trainer engine. @@ -175,7 +177,8 @@ def _default_iteration_print(self, engine: Engine) -> None: """ Execute iteration log operation based on Ignite engine.state data. Print the values from Ignite state.logs dict. - Default behavior is to print loss from output[1], skip if output[1] is not loss. + The default behavior is to print the loss from output[0], as the output is a decollated list and the loss + value is replicated for every item of that list. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index 5ecaa02ca0..a3a0bf76b8 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -83,7 +83,7 @@ def __init__( epoch_interval: int = 1, iteration_event_writer: Optional[Callable[[Engine, SummaryWriter], Any]] = None, iteration_interval: int = 1, - output_transform: Callable = lambda x: x, + output_transform: Callable = lambda x: x[0], global_epoch_transform: Callable = lambda x: x, tag_name: str = DEFAULT_TAG, ) -> None: @@ -102,6 +102,8 @@ def __init__( ``ignite.engine.state.output`` into a scalar to plot, or a dictionary of {key: scalar}. In the latter case, the output string will be formatted as key: value. By default this value plotting happens when every iteration completed. + The default behavior is to plot the loss from output[0], as the output is a decollated list + and the loss value is replicated for every item of that list. global_epoch_transform: a callable that is used to customize global epoch number. For example, in evaluation, the evaluator engine might want to use trainer engines epoch number when plotting epoch vs metric curves. @@ -178,7 +180,8 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> None: """ Execute iteration level event write operation based on Ignite engine.state data. - Default is to write the loss value of current iteration. + The default behavior is to write the loss from output[0], as the output is a decollated list and the loss + value is replicated for every item of that list. Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. @@ -306,7 +309,7 @@ def __call__(self, engine: Engine) -> None: """ step = self.global_iter_transform(engine.state.epoch if self.epoch_level else engine.state.iteration) - show_images = self.batch_transform(engine.state.batch)[0] + show_images = self.batch_transform(engine.state.batch)[0][self.index] if isinstance(show_images, torch.Tensor): show_images = show_images.detach().cpu().numpy() if show_images is not None: @@ -316,10 +319,17 @@ def __call__(self, engine: Engine) -> None: f"(numpy.ndarray, torch.Tensor) but is {type(show_images).__name__}."
) plot_2d_or_3d_image( - show_images, step, self._writer, self.index, self.max_channels, self.max_frames, "input_0" + # add batch dim and plot the first item + show_images[None], + step, + self._writer, + 0, + self.max_channels, + self.max_frames, + "input_0", ) - show_labels = self.batch_transform(engine.state.batch)[1] + show_labels = self.batch_transform(engine.state.batch)[1][self.index] if isinstance(show_labels, torch.Tensor): show_labels = show_labels.detach().cpu().numpy() if show_labels is not None: @@ -328,11 +338,9 @@ def __call__(self, engine: Engine) -> None: "batch_transform(engine.state.batch)[1] must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}." ) - plot_2d_or_3d_image( - show_labels, step, self._writer, self.index, self.max_channels, self.max_frames, "input_1" - ) + plot_2d_or_3d_image(show_labels[None], step, self._writer, 0, self.max_channels, self.max_frames, "input_1") - show_outputs = self.output_transform(engine.state.output) + show_outputs = self.output_transform(engine.state.output)[self.index] if isinstance(show_outputs, torch.Tensor): show_outputs = show_outputs.detach().cpu().numpy() if show_outputs is not None: @@ -341,8 +349,6 @@ def __call__(self, engine: Engine) -> None: "output_transform(engine.state.output) must be None or one of " f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}." ) - plot_2d_or_3d_image( - show_outputs, step, self._writer, self.index, self.max_channels, self.max_frames, "output" - ) + plot_2d_or_3d_image(show_outputs[None], step, self._writer, 0, self.max_channels, self.max_frames, "output") self._writer.flush() diff --git a/monai/handlers/transform_inverter.py b/monai/handlers/transform_inverter.py index 0d4146af03..05e74bdbf2 100644 --- a/monai/handlers/transform_inverter.py +++ b/monai/handlers/transform_inverter.py @@ -9,13 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union import torch -from torch.utils.data import DataLoader as TorchDataLoader from monai.config import IgniteInfo, KeysCollection -from monai.data.utils import no_collation from monai.engines.utils import CommonKeys, IterationEvents from monai.transforms import Invertd, InvertibleTransform from monai.utils import ensure_tuple, ensure_tuple_rep, min_version, optional_import @@ -31,7 +30,7 @@ class TransformInverter: """ Ignite handler to automatically invert `transforms`. It takes `engine.state.output` as the input data and uses the transforms information from `engine.state.batch`. - Expect both `engine.state.output` and `engine.state.batch` to be dictionary data. + Expect both `engine.state.output` and `engine.state.batch` to be lists of dictionaries. The inverted data is in-place saved back to `engine.state.output` with key: "{output_key}". And the inverted meta dict will be stored in `engine.state.batch` with key: "{meta_keys}" or "{key}_{meta_key_postfix}".
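The image writer above now picks one decollated item (channel-first, no batch dim) and restores a singleton batch dimension with `[None]` before plotting. A hedged sketch of that pattern with a synthetic tensor:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

from monai.visualize import plot_2d_or_3d_image

writer = SummaryWriter(log_dir="./runs")

# one decollated item, e.g. a single-channel 2D image of shape (1, 64, 64)
show_images = torch.rand(1, 64, 64)

# [None] adds the batch dim back, giving (1, 1, 64, 64); the plot utility
# then always selects index 0 of that singleton batch
plot_2d_or_3d_image(show_images[None], step=0, writer=writer, index=0, tag="input_0")
writer.flush()
```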
@@ -41,13 +40,11 @@ class TransformInverter: def __init__( self, transform: InvertibleTransform, - loader: TorchDataLoader, output_keys: KeysCollection = CommonKeys.PRED, batch_keys: KeysCollection = CommonKeys.IMAGE, meta_keys: Optional[KeysCollection] = None, batch_meta_keys: Optional[KeysCollection] = None, meta_key_postfix: str = "meta_dict", - collate_fn: Optional[Callable] = no_collation, nearest_interp: Union[bool, Sequence[bool]] = True, to_tensor: Union[bool, Sequence[bool]] = True, device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", @@ -57,7 +54,6 @@ def __init__( """ Args: transform: a callable data transform on input data. - loader: data loader used to run transforms and generate the batch of data. output_keys: the key of expected data in `ignite.engine.output`, invert transforms on it. it also can be a list of keys, will invert transform for each of them. Default to "pred". it's in-place operation. @@ -80,8 +76,6 @@ def __init__( For example, to handle orig_key `image`, read/write `affine` matrices from the metadata `image_meta_dict` dictionary's `affine` field. the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". - collate_fn: how to collate data after inverse transformations. default won't do any collation, - so the output will be a list of PyTorch Tensor or numpy array without batch dim. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. it also can be a list of bool, each matches to the `output_keys` data. @@ -100,12 +94,10 @@ def __init__( self.inverter = Invertd( keys=output_keys, transform=transform, - loader=loader, orig_keys=batch_keys, meta_keys=meta_keys, orig_meta_keys=batch_meta_keys, meta_key_postfix=meta_key_postfix, - collate_fn=collate_fn, nearest_interp=nearest_interp, to_tensor=to_tensor, device=device, @@ -130,17 +122,23 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. 
""" - # combine `batch` and `output` to temporarily act as 1 dict for post transform - data = dict(engine.state.batch) - data.update(engine.state.output) - ret = self.inverter(data) + if not isinstance(engine.state.batch, list) or not isinstance(engine.state.output, list): + warnings.warn("inverter requires `engine.state.batch` and `engine.state.outout` to be lists.") + else: + for i, (b, o) in enumerate(zip(engine.state.batch, engine.state.output)): + # combine `batch` and `output` to temporarily act as 1 dict for postprocessing + data = dict(b) + data.update(o) + ret = self.inverter(data) - for output_key, meta_key, meta_key_postfix in zip(self.output_keys, self.meta_keys, self.meta_key_postfix): - # save the inverted data into state.output - engine.state.output[output_key] = ret.get(output_key) - # save the inverted meta dict into state.batch - meta_key = meta_key or f"{output_key}_{meta_key_postfix}" - if meta_key in ret: - # FIXME: we save inverted meta dict into `batch` to be compatible with `SegmentationSaver` - # will deprecate both handlers soon - engine.state.batch[meta_key] = ret.get(meta_key) + for output_key, meta_key, meta_key_postfix in zip( + self.output_keys, self.meta_keys, self.meta_key_postfix + ): + # save the inverted data into state.output + engine.state.output[i][output_key] = ret.get(output_key) + # save the inverted meta dict into state.batch + meta_key = meta_key or f"{output_key}_{meta_key_postfix}" + if meta_key in ret: + # FIXME: we save inverted meta dict into `batch` to be compatible with `SegmentationSaver` + # will deprecate both handlers soon + engine.state.batch[i][meta_key] = ret.get(meta_key) diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 8f53e366a9..74b8c74b28 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -235,7 +235,7 @@ def _compute_op(op: str, d: np.ndarray): def from_engine(keys: KeysCollection): """ Utility function to simplify the `batch_transform` or `output_transform` args of ignite components - when handling dictionary data(for example: `engine.state.batch` or `engine.state.output`). + when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`). Users only need to set the expected keys, then it will return a callable function to extract data from dictionary and construct a tuple respectively. It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward. @@ -250,8 +250,14 @@ def from_engine(keys: KeysCollection): ) """ - - def _wrapper(output: Dict): - return tuple(output[k] for k in ensure_tuple(keys)) + keys = ensure_tuple(keys) + + def _wrapper(data): + if isinstance(data, dict): + return tuple(data[k] for k in keys) + elif isinstance(data, list) and isinstance(data[0], dict): + # if data is a list of dictionaries, extract expected keys and construct lists + ret = [[i[k] for i in data] for k in keys] + return tuple(ret) if len(ret) > 1 else ret[0] return _wrapper diff --git a/monai/networks/utils.py b/monai/networks/utils.py index 54e6776926..bb02f78de9 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -37,7 +37,7 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: """ - For a tensor `labels` of dimensions B1[spatial_dims], return a tensor of dimensions `BN[spatial_dims]` + For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]` for `num_classes` N number of classes. 
Example: For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. Note that this will include the background label, thus a binary mask should be treated as having 2 classes. """ - if labels.dim() <= 0: - raise AssertionError("labels should have dim of 1 or more.") + if labels.dim() == 0: + # if no channel dim, add it + labels = labels.unsqueeze(0) # if `dim` is bigger, add singleton dim at the end if labels.ndim < dim + 1: diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 76b2c626e7..fb1ff25765 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -180,7 +180,7 @@ ThresholdIntensityDict, ) from .inverse import InvertibleTransform -from .inverse_batch_transform import BatchInverseTransform +from .inverse_batch_transform import BatchInverseTransform, Decollated from .io.array import LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .post.array import ( @@ -199,9 +199,6 @@ AsDiscreted, AsDiscreteD, AsDiscreteDict, - Decollated, - DecollateD, - DecollateDict, Ensembled, Invertd, InvertD, diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index ac1ff2a944..c6dad2fcd0 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -19,10 +19,10 @@ from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate from monai.transforms.croppad.batch import PadListDataCollate from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Transform +from monai.transforms.transform import MapTransform, Transform from monai.utils import first -__all__ = ["BatchInverseTransform"] +__all__ = ["BatchInverseTransform", "Decollated"] class _BatchInverseDataset(Dataset): @@ -87,7 +86,6 @@ def __init__( self.pad_collation_used = loader.collate_fn.__doc__ == pad_list_data_collate.__doc__ def __call__(self, data: Dict[str, Any]) -> Any: - decollated_data = decollate_batch(data, detach=self.detach) inv_ds = _BatchInverseDataset(decollated_data, self.transform, self.pad_collation_used) inv_loader = DataLoader( @@ -100,3 +99,22 @@ def __call__(self, data: Dict[str, Any]) -> Any: if "equal size" in re_str: re_str += "\nMONAI hint: try creating `BatchInverseTransform` with `collate_fn=lambda x: x`." raise RuntimeError(re_str) + + +class Decollated(MapTransform): + """ + Decollate a batch of data. + Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. + + Args: + detach: whether to detach the tensors. Scalar tensors will be detached into number types + instead of torch tensors. + + """ + + def __init__(self, keys="", detach: bool = True) -> None: + super().__init__(keys=keys) + self.detach = detach + + def __call__(self, data: dict): + return decollate_batch(data, detach=self.detach) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 8a9c8d10a0..d902300ecb 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -176,14 +176,14 @@ def __call__( class SaveImage(Transform): """ Save transformed data into files, support NIfTI and PNG formats. - It can work for both numpy array and PyTorch Tensor in both pre-transform chain - and post transform chain.
+ It can work for both numpy array and PyTorch Tensor in both preprocessing transform + chain and postprocessing transform chain. The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, where the input image name is extracted from the provided meta data dictionary. If no meta data provided, use index from 0 as the filename prefix. It can also save a list of PyTorch Tensor or numpy array without `batch dim`. - Note: image should include channel dimension: [B],C,H,W,[D]. + Note: image should be channel-first shape: [C,H,W,[D]]. Args: output_dir: output image directory. @@ -218,8 +218,6 @@ class SaveImage(Transform): it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. - save_batch: whether the import image is a batch data, default to `False`. - usually pre-transforms run for channel first data, while post-transforms run for batch data. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, @@ -250,7 +248,6 @@ def __init__( scale: Optional[int] = None, dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, - save_batch: bool = False, squeeze_end_dims: bool = True, data_root_dir: str = "", print_log: bool = True, @@ -284,8 +281,6 @@ def __init__( else: raise ValueError(f"unsupported output extension: {output_ext}.") - self.save_batch = save_batch - def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: @@ -293,19 +288,4 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic meta_data: key-value pairs of meta_data corresponding to the data. """ - if isinstance(img, (tuple, list)): - # if a list of data in shape: [channel, H, W, [D]], save every item separately - meta_: Optional[Dict] = None - for i, d in enumerate(img): - if isinstance(meta_data, dict): - meta_ = {k: meta_data[k][i] for k in meta_data} - elif isinstance(meta_data, (list, tuple)): - meta_ = meta_data[i] - else: - meta_ = meta_data - self.saver.save(d, meta_) - else: - if self.save_batch: - self.saver.save_batch(img, meta_data) - else: - self.saver.save(img, meta_data) + self.saver.save(img, meta_data) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 39649e3858..d2257e7c7d 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -133,7 +133,7 @@ class SaveImaged(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`. Note: - Image should include channel dimension: [B],C,H,W,[D]. + Image should be channel-first shape: [C,H,W,[D]]. If the data is a patch of big image, will append the patch index to filename. Args: @@ -181,8 +181,6 @@ class SaveImaged(MapTransform): it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. - save_batch: whether the import image is a batch data, default to `False`. - usually pre-transforms run for channel first data, while post-transforms run for batch data. allow_missing_keys: don't raise exception if key is missing. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). 
So if input is (C,H,W,D), this will be altered to (H,W,D,C), and @@ -217,7 +215,6 @@ def __init__( scale: Optional[int] = None, dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, - save_batch: bool = False, allow_missing_keys: bool = False, squeeze_end_dims: bool = True, data_root_dir: str = "", @@ -236,7 +233,6 @@ def __init__( scale=scale, dtype=dtype, output_dtype=output_dtype, - save_batch=save_batch, squeeze_end_dims=squeeze_end_dims, data_root_dir=data_root_dir, print_log=print_log, diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index dd4f7afd9d..8913a1a041 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -93,10 +93,7 @@ def __call__( if sigmoid or self.sigmoid: img = torch.sigmoid(img) if softmax or self.softmax: - # add channel dim if not existing - if img.ndimension() == 1: - img = img.unsqueeze(-1) - img = torch.softmax(img, dim=1) + img = torch.softmax(img, dim=0) act_func = self.other if other is None else other if act_func is not None: @@ -166,13 +163,13 @@ def __call__( """ if argmax or self.argmax: - img = torch.argmax(img, dim=1, keepdim=True) + img = torch.argmax(img, dim=0, keepdim=True) if to_onehot or self.to_onehot: _nclasses = self.n_classes if n_classes is None else n_classes if not isinstance(_nclasses, int): raise AssertionError("One of self.n_classes or n_classes must be an integer") - img = one_hot(img, _nclasses) + img = one_hot(img, num_classes=_nclasses, dim=0) if threshold_values or self.threshold_values: img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh) @@ -185,9 +182,9 @@ class KeepLargestConnectedComponent(Transform): Keeps only the largest connected component in the image. This transform can be used as a post-processing step to clean up over-segment areas in model output. - The input is assumed to be a PyTorch Tensor: - 1) With shape (batch_size, 1, spatial_dim1[, spatial_dim2, ...]) and the values correspond to expected labels. - 2) With shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]) and the values should be 0, 1 on each labels. + The input is assumed to be a channel-first PyTorch Tensor: + 1) With shape (1, spatial_dim1[, spatial_dim2, ...]) and the values correspond to expected labels. + 2) With shape (C, spatial_dim1[, spatial_dim2, ...]) and the values should be 0 or 1 for each label. Note: For single channel data, 0 will be treated as background and the over-segment pixels will be set to 0. @@ -249,15 +246,13 @@ def __init__( def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: - img: shape must be (batch_size, C, spatial_dim1[, spatial_dim2, ...]). + img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Returns: - A PyTorch Tensor with shape (batch_size, C, spatial_dim1[, spatial_dim2, ...]). + A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]).
""" - channel_dim = 1 - if img.shape[channel_dim] == 1: - - img = torch.squeeze(img, dim=channel_dim) + if img.shape[0] == 1: + img = torch.squeeze(img, dim=0) if self.independent: for i in self.applied_labels: @@ -270,22 +265,23 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: foreground += (img == i).type(torch.uint8) mask = get_largest_connected_component_mask(foreground, self.connectivity) img[foreground != mask] = 0 - output = torch.unsqueeze(img, dim=channel_dim) + + output = torch.unsqueeze(img, dim=0) else: # one-hot data is assumed to have binary value in each channel if self.independent: for i in self.applied_labels: - foreground = img[:, i, ...].type(torch.uint8) + foreground = img[i, ...].type(torch.uint8) mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[:, i, ...][foreground != mask] = 0 + img[i, ...][foreground != mask] = 0 else: - applied_img = img[:, self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=channel_dim) + applied_img = img[self.applied_labels, ...].type(torch.uint8) + foreground = torch.any(applied_img, dim=0) mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=channel_dim) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=channel_dim) + background_mask = torch.unsqueeze(foreground != mask, dim=0) + background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) applied_img[background_mask] = 0 - img[:, self.applied_labels, ...] = applied_img.type(img.type()) + img[self.applied_labels, ...] = applied_img.type(img.type()) output = img return output @@ -312,10 +308,10 @@ def __init__(self, kernel_type: str = "Laplace") -> None: def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: - img: torch tensor data to extract the contour, with shape: [batch_size, channels, height, width[, depth]] + img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]] Raises: - ValueError: When ``image`` ndim is not one of [4, 5]. + ValueError: When ``image`` ndim is not one of [3, 4]. Returns: A torch tensor with the same shape as img, note: @@ -325,43 +321,44 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: ideally the edge should be thin enough, but now it has a thickness. """ - channels = img.shape[1] - if img.ndimension() == 4: + channels = img.shape[0] + img_ = img.unsqueeze(0) + if img.ndimension() == 3: kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32, device=img.device) kernel = kernel.repeat(channels, 1, 1, 1) - contour_img = F.conv2d(img, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) - elif img.ndimension() == 5: + contour_img = F.conv2d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) + elif img.ndimension() == 4: kernel = -1 * torch.ones(3, 3, 3, dtype=torch.float32, device=img.device) kernel[1, 1, 1] = 26 kernel = kernel.repeat(channels, 1, 1, 1, 1) - contour_img = F.conv3d(img, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) + contour_img = F.conv3d(img_, kernel, bias=None, stride=1, padding=1, dilation=1, groups=channels) else: raise ValueError(f"Unsupported img dimension: {img.ndimension()}, available options are [4, 5].") contour_img.clamp_(min=0.0, max=1.0) - return contour_img + return contour_img.squeeze(0) class MeanEnsemble(Transform): """ Execute mean ensemble on the input data. 
-    The input data can be a list or tuple of PyTorch Tensor with shape: [B, C[, H, W, D]],
-    Or a single PyTorch Tensor with shape: [E, B, C[, H, W, D]], the `E` dimension represents
+    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
+    or a single PyTorch Tensor with shape: [E, C[, H, W, D]], where the `E` dimension represents
     the output data from different models.
     Typically, the input data is model output of segmentation task or classification task.
-    And it also can support to add `weights` for the input data.
+    It also supports adding `weights` for the input data.

     Args:
-        weights: can be a list or tuple of numbers for input data with shape: [E, B, C, H, W[, D]].
+        weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]].
             or a Numpy ndarray or a PyTorch Tensor data.
             the `weights` will be added to input data from highest dimension, for example:
             1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.
-            2. if the `weights` has 3 dimensions, it will be added to `E`, `B` and `C` dimensions.
+            2. if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions.
             it's a typical practice to add weights for different classes:
-            to ensemble 3 segmentation model outputs, every output has 4 channels(classes),
-            so the input data shape can be: [3, B, 4, H, W, D].
-            and add different `weights` for different classes, so the `weights` shape can be: [3, 1, 4].
-            for example: `weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]`.
+            to ensemble 3 segmentation model outputs, every output has 4 channels (classes),
+            so the input data shape can be: [3, 4, H, W, D].
+            and add different `weights` for different classes, so the `weights` shape can be: [3, 4].
+            for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`.

     """
@@ -385,8 +382,8 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te
 class VoteEnsemble(Transform):
     """
     Execute vote ensemble on the input data.
-    The input data can be a list or tuple of PyTorch Tensor with shape: [B[, C, H, W, D]],
-    Or a single PyTorch Tensor with shape: [E, B[, C, H, W, D]], the `E` dimension represents
+    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
+    or a single PyTorch Tensor with shape: [E[, C, H, W, D]], where the `E` dimension represents
     the output data from different models.
     Typically, the input data is model output of segmentation task or classification task.
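
As a minimal sketch of the channel-first convention these hunks introduce (not part of the patch; shapes are illustrative), the updated transforms operate on single items of shape [C, H, W, D] produced by `decollate_batch`, and the ensemble weights drop the batch dimension accordingly:

    import torch
    from monai.data import decollate_batch
    from monai.transforms import Activations, AsDiscrete, MeanEnsemble

    batch_pred = torch.randn(2, 4, 8, 8, 8)     # [B, C, H, W, D] model output
    items = decollate_batch(batch_pred)         # list of B tensors, each [C, H, W, D]
    post = Activations(softmax=True)            # softmax now runs over dim 0 (channel)
    discrete = AsDiscrete(argmax=True)          # argmax over dim 0, keepdim -> [1, H, W, D]
    preds = [discrete(post(p)) for p in items]

    # ensemble three channel-first outputs; per-class weights have shape [E, C] = [3, 4]
    ensemble = MeanEnsemble(weights=[[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]])
    averaged = ensemble([torch.randn(4, 8, 8, 8) for _ in range(3)])  # -> [4, 8, 8, 8]
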
@@ -409,19 +406,19 @@ def __call__(self, img: Union[Sequence[torch.Tensor], torch.Tensor]) -> torch.Te img_ = torch.stack(img) if isinstance(img, (tuple, list)) else torch.as_tensor(img) if self.num_classes is not None: has_ch_dim = True - if img_.ndimension() > 2 and img_.shape[2] > 1: + if img_.ndimension() > 1 and img_.shape[1] > 1: warnings.warn("no need to specify num_classes for One-Hot format data.") else: - if img_.ndimension() == 2: + if img_.ndimension() == 1: # if no channel dim, need to remove channel dim after voting has_ch_dim = False - img_ = one_hot(img_, self.num_classes, dim=2) + img_ = one_hot(img_, self.num_classes, dim=1) img_ = torch.mean(img_.float(), dim=0) if self.num_classes is not None: # if not One-Hot, use "argmax" to vote the most common class - return torch.argmax(img_, dim=1, keepdim=has_ch_dim) + return torch.argmax(img_, dim=0, keepdim=has_ch_dim) # for One-Hot data, round the float number to 0 or 1 return torch.round(img_) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index c989ea67f9..f456e2ace3 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -21,13 +21,10 @@ import numpy as np import torch -from torch.utils.data import DataLoader as TorchDataLoader from monai.config import KeysCollection from monai.data.csv_saver import CSVSaver -from monai.data.utils import decollate_batch, no_collation from monai.transforms.inverse import InvertibleTransform -from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.post.array import ( Activations, AsDiscrete, @@ -66,9 +63,6 @@ "MeanEnsembleDict", "VoteEnsembleD", "VoteEnsembleDict", - "DecollateD", - "DecollateDict", - "Decollated", "ProbNMSd", "ProbNMSD", "ProbNMSDict", @@ -302,16 +296,16 @@ def __init__( if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`. output_key: the key to store ensemble result in the dictionary. if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default. - weights: can be a list or tuple of numbers for input data with shape: [E, B, C, H, W[, D]]. + weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]]. or a Numpy ndarray or a PyTorch Tensor data. the `weights` will be added to input data from highest dimension, for example: 1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data. - 2. if the `weights` has 3 dimensions, it will be added to `E`, `B` and `C` dimensions. + 2. if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions. it's a typical practice to add weights for different classes: to ensemble 3 segmentation model outputs, every output has 4 channels(classes), - so the input data shape can be: [3, B, 4, H, W, D]. - and add different `weights` for different classes, so the `weights` shape can be: [3, 1, 4]. - for example: `weights = [[[1, 2, 3, 4]], [[4, 3, 2, 1]], [[1, 1, 1, 1]]]`. + so the input data shape can be: [3, 4, H, W, D]. + and add different `weights` for different classes, so the `weights` shape can be: [3, 4]. + for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`. """ ensemble = MeanEnsemble(weights=weights) @@ -340,26 +334,6 @@ def __init__( super().__init__(keys, ensemble, output_key) -class Decollated(MapTransform): - """ - Decollate a batch of data. - - Note that unlike most MapTransforms, this will decollate all data, so keys are not needed. 
-
-    Args:
-        detach: whether to detach the tensors. Scalars tensors will be detached into number types
-            instead of torch tensors.
-
-    """
-
-    def __init__(self, keys="", detach: bool = True) -> None:
-        super().__init__(keys=keys)
-        self.detach = detach
-
-    def __call__(self, data: dict):
-        return decollate_batch(data, detach=self.detach)
-
-
 class ProbNMSd(MapTransform):
     """
     Performs probability based non-maximum suppression (NMS) on the probabilities map via
@@ -418,11 +392,11 @@ def __call__(self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]):
 class Invertd(MapTransform):
     """
     Utility transform to automatically invert the previously applied transforms.
-    When applying pre-transforms on a orig_key(like: `image`, `label`, etc.), we record the context
+    When applying preprocessing transforms on an orig_key (like `image`, `label`, etc.), we record the context
     information of applied transforms in a dictionary in the input data dictionary with the key
-    "{orig_key}_transforms". This post transform will extract the transform context information of `orig_keys`
-    then invert the transforms(got from this context information) on the `keys` data.
-    Typical usage is to invert the pre-transforms(applied on input `image`) on the model `pred` data.
+    "{orig_key}_transforms". This transform will extract the transform context information of `orig_keys`,
+    then invert the transforms (got from this context information) on the `keys` data.
+    Typical usage is to invert the preprocessing transforms (applied on the input `image`) on the model `pred` data.
     The output of the inverted data and metadata will be stored at `keys` and `meta_keys` respectively.

     To correctly invert the transforms, the information of the previously applied transforms should be
@@ -434,7 +408,7 @@ class Invertd(MapTransform):

     Note:
         According to the `collate_fn`, this transform may return a list of Tensor without batch dim,
-        thus some following post transforms may not support a list of Tensor, and users can leverage the
+        thus some following transforms may not support a list of Tensor, and users can leverage the
         `post_func` arg for basic processing logic.

         This transform needs to extract the context information of applied transforms and the meta data
@@ -447,12 +421,10 @@ def __init__(
         self,
         keys: KeysCollection,
         transform: InvertibleTransform,
-        loader: TorchDataLoader,
         orig_keys: KeysCollection,
         meta_keys: Optional[KeysCollection] = None,
         orig_meta_keys: Optional[KeysCollection] = None,
         meta_key_postfix: str = "meta_dict",
-        collate_fn: Optional[Callable] = no_collation,
         nearest_interp: Union[bool, Sequence[bool]] = True,
         to_tensor: Union[bool, Sequence[bool]] = True,
         device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
@@ -465,7 +437,6 @@ def __init__(
             keys: the key of expected data in the dict, invert transforms on it, in-place operation.
                 it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"].
             transform: the previous callable transform that applied on input data.
-            loader: data loader used to run transforms and generate the batch of data.
             orig_keys: the key of the original input data in the dict.
                 will get the applied transform information for this input data,
                 then invert them for the expected data with `keys`.
                 It can also be a list of keys, each matches to the `keys` data.
@@ -484,8 +455,6 @@ def __init__(
                 For example, to handle orig_key `image`, read/write `affine` matrices from the metadata
                 `image_meta_dict` dictionary's `affine` field.
                 the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
-            collate_fn: how to collate data after inverse transformations. default won't do any collation,
-                so the output will be a list of PyTorch Tensor or numpy array without batch dim.
             nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                 default to `True`. If `False`, use the same interpolation mode as the original transform.
                 it also can be a list of bool, each matches to the `keys` data.
@@ -503,13 +472,9 @@ def __init__(
         """
         super().__init__(keys, allow_missing_keys)
+        if not isinstance(transform, InvertibleTransform):
+            raise ValueError("transform is not invertible, can't invert transform for the data.")
         self.transform = transform
-        self.inverter = BatchInverseTransform(
-            transform=transform,
-            loader=loader,
-            collate_fn=collate_fn,
-            num_workers=num_workers,
-        )
         self.orig_keys = ensure_tuple_rep(orig_keys, len(self.keys))
         self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
         if len(self.keys) != len(self.meta_keys):
@@ -572,21 +537,14 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
                 input_dict[orig_meta_key] = d[orig_meta_key]

             with allow_missing_keys_mode(self.transform):  # type: ignore
-                inverted = self.inverter(input_dict)
+                inverted = self.transform.inverse(input_dict)

             # save the inverted data
-            if isinstance(inverted, (tuple, list)):
-                d[key] = [
-                    post_func(self._totensor(i[orig_key]).to(device) if to_tensor else i[orig_key]) for i in inverted
-                ]
-                # save the inverted meta dict
-                if orig_meta_key in d:
-                    d[meta_key] = [i.get(orig_meta_key) for i in inverted]
-            else:
-                d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key])
-                # save the inverted meta dict
-                if orig_meta_key in d:
-                    d[meta_key] = inverted.get(orig_meta_key)
+            d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key])
+            # save the inverted meta dict
+            if orig_meta_key in d:
+                d[meta_key] = inverted.get(orig_meta_key)
+
         return d
@@ -624,14 +582,14 @@ def __init__(
             the meta data is a dictionary object which contains: filename, original_shape, etc.
             this arg only works when `meta_keys=None`. if no corresponding metadata, set to `None`.
             saver: the saver instance to save classification results, if None, create a CSVSaver internally.
-                the saver must provide `save_batch(batch_data, meta_data)` APIs.
+                the saver must provide `save(data, meta_data)` and `finalize()` APIs.
             output_dir: if `saver=None`, specify the directory to save the CSV file.
             filename: if `saver=None`, specify the name of the saved CSV file.
-            overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content,
-                if True, will clear the file before saving. otherwise, will apend new content to the CSV file.
+            overwrite: if `saver=None`, indicate whether to overwrite existing CSV file content,
+                if True, will clear the file before saving. otherwise, will append new content to the CSV file.
             flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately
                 in this transform and clear the cache. default to True.
-                If False, may need user to call `saver.finalize()` manually then.
+                If False, the user may need to call `saver.finalize()` manually or use the `ClassificationSaver` handler.
             allow_missing_keys: don't raise exception if key is missing.
""" @@ -639,6 +597,7 @@ def __init__( if len(self.keys) != 1: raise ValueError("only 1 key is allowed when saving the classification result.") self.saver = saver or CSVSaver(output_dir, filename, overwrite, flush) + self.flush = flush self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) @@ -648,7 +607,9 @@ def __call__(self, data): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None - self.saver.save_batch(batch_data=d[key], meta_data=meta_data) + self.saver.save(data=d[key], meta_data=meta_data) + if self.flush: + self.saver.finalize() return d @@ -668,6 +629,5 @@ def get_saver(self): MeanEnsembleD = MeanEnsembleDict = MeanEnsembled ProbNMSD = ProbNMSDict = ProbNMSd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled -DecollateD = DecollateDict = Decollated InvertD = InvertDict = Invertd SaveClassificationD = SaveClassificationDict = SaveClassificationd diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index aad415b37f..e2fc9241e3 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -241,36 +241,24 @@ class SplitChannel(Transform): """ Split Numpy array or PyTorch Tensor data according to the channel dim. It can help applying different following transforms to different channels. - Channel number must be greater than 1. Args: - channel_dim: which dimension of input image is the channel, default to None - to automatically select: if data is numpy array, channel_dim is 0 as - `numpy array` is used in the pre transforms, if PyTorch Tensor, channel_dim - is 1 as in most of the cases `Tensor` is uses in the post transforms. + channel_dim: which dimension of input image is the channel, default to 0. + """ - def __init__(self, channel_dim: Optional[int] = None) -> None: + def __init__(self, channel_dim: int = 0) -> None: self.channel_dim = channel_dim def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]: - if self.channel_dim is None: - # automatically select the default channel dim based on data type - if isinstance(img, torch.Tensor): - channel_dim = 1 - else: - channel_dim = 0 - else: - channel_dim = self.channel_dim - - n_classes = img.shape[channel_dim] + n_classes = img.shape[self.channel_dim] if n_classes <= 1: raise RuntimeError("input image does not contain multiple channels.") outputs = [] slices = [slice(None)] * len(img.shape) for i in range(n_classes): - slices[channel_dim] = slice(i, i + 1) + slices[self.channel_dim] = slice(i, i + 1) outputs.append(img[tuple(slices)]) return outputs diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index eea371c689..0db63042a2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -338,7 +338,7 @@ def __init__( self, keys: KeysCollection, output_postfixes: Optional[Sequence[str]] = None, - channel_dim: Optional[int] = None, + channel_dim: int = 0, allow_missing_keys: bool = False, ) -> None: """ @@ -349,10 +349,7 @@ def __init__( for example: if the key of input data is `pred` and split 2 classes, the output data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1]) if None, using the index number: `pred_0`, `pred_1`, ... `pred_N`. 
- channel_dim: which dimension of input image is the channel, default to None - to automatically select: if data is numpy array, channel_dim is 0 as - `numpy array` is used in the pre transforms, if PyTorch Tensor, channel_dim - is 1 as in most of the cases `Tensor` is uses in the post transforms. + channel_dim: which dimension of input image is the channel, default to 0. allow_missing_keys: don't raise exception if key is missing. """ diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 516c2faa99..5c47240e42 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -589,17 +589,17 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option Gets the largest connected component mask of an image. Args: - img: Image to get largest connected component from. Shape is (batch_size, spatial_dim1 [, spatial_dim2, ...]) + img: Image to get largest connected component from. Shape is (spatial_dim1 [, spatial_dim2, ...]) connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. """ img_arr = img.detach().cpu().numpy() largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) - for i, item in enumerate(img_arr): - item = measure.label(item, connectivity=connectivity) - if item.max() != 0: - largest_cc[i, ...] = item == (np.argmax(np.bincount(item.flat)[1:]) + 1) + img_arr = measure.label(img_arr, connectivity=connectivity) + if img_arr.max() != 0: + largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) + return torch.as_tensor(largest_cc, device=img.device) diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py index fdf3d735a2..78136f6404 100644 --- a/monai/utils/jupyter_utils.py +++ b/monai/utils/jupyter_utils.py @@ -143,13 +143,13 @@ def tensor_to_images(name: str, tensor: torch.Tensor): color channels RGB or RGBA. This allows multiple images to be created from a single tensor, ie. to show each channel separately. 
""" + if tensor.ndim == 3 and tensor.shape[1] > 2 and tensor.shape[2] > 2: + return tensor.cpu().data.numpy() if tensor.ndim == 4 and tensor.shape[2] > 2 and tensor.shape[3] > 2: - return tuple(tensor[0].cpu().data.numpy()) - if tensor.ndim == 5 and tensor.shape[3] > 2 and tensor.shape[4] > 2: - dmid = tensor.shape[2] // 2 - return tuple(tensor[0, :, dmid].cpu().data.numpy()) + dmid = tensor.shape[1] // 2 + return tensor[:, dmid].cpu().data.numpy() - return () + return None def plot_engine_status( @@ -190,22 +190,22 @@ def plot_engine_status( graphmap.update(logger.metrics) imagemap = {} - if image_fn is not None and engine.state is not None and engine.state.batch is not None: for src in (engine.state.batch, engine.state.output): - if isinstance(src, dict): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - images = image_fn(k, v) - - for i, im in enumerate(images): - imagemap[f"{k}_{i}"] = im - elif isinstance(src, torch.Tensor): - label = "Batch" if src is engine.state.batch else "Output" - images = image_fn(label, src) - - for i, im in enumerate(images): - imagemap[f"{label}_{i}"] = im + if isinstance(src, list): + for i, s in enumerate(src): + if isinstance(s, dict): + for k, v in s.items(): + if isinstance(v, torch.Tensor): + image = image_fn(k, v) + if image is not None: + imagemap[f"{k}_{i}"] = image + else: + label = "Batch" if src is engine.state.batch else "Output" + if isinstance(s, torch.Tensor): + image = image_fn(label, s) + if image is not None: + imagemap[f"{label}_{i}"] = image axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction) @@ -215,11 +215,18 @@ def plot_engine_status( return fig, axes -def _get_loss_from_output(output: Union[Dict[str, torch.Tensor], torch.Tensor]) -> float: +def _get_loss_from_output(output: Union[Dict[str, torch.Tensor], torch.Tensor]): """Returns a single value from the network output, which is a dict or tensor.""" - if isinstance(output, dict): - return output["loss"].item() - return output.item() + + def _get_loss(data): + if isinstance(data, dict): + return data["loss"] + return data + + if isinstance(output, list): + return _get_loss(output[0]) + else: + return _get_loss(output) class StatusMembers(Enum): diff --git a/tests/test_activations.py b/tests/test_activations.py index 5ed9ec2046..7d8b3e4c38 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -19,50 +19,50 @@ TEST_CASE_1 = [ {"sigmoid": True, "softmax": False, "other": None}, - torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), - torch.tensor([[[[0.5000, 0.7311], [0.8808, 0.9526]]]]), - (1, 1, 2, 2), + torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), + torch.tensor([[[0.5000, 0.7311], [0.8808, 0.9526]]]), + (1, 2, 2), ] TEST_CASE_2 = [ {"sigmoid": False, "softmax": True, "other": None}, - torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), - torch.tensor([[[[0.1192, 0.1192]], [[0.8808, 0.8808]]]]), - (1, 2, 1, 2), + torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), + torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + (2, 1, 2), ] TEST_CASE_3 = [ {"sigmoid": False, "softmax": False, "other": torch.tanh}, - torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), - torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]]), - (1, 1, 2, 2), + torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), + torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + (1, 2, 2), ] TEST_CASE_4 = [ "swish", - torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32), + torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), 
torch.tensor( - [[[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]] + [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] ), - (1, 1, 2, 5), + (1, 2, 5), ] TEST_CASE_5 = [ "memswish", - torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32), + torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), torch.tensor( - [[[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]]] + [[[-4.54e-04, -2.68e-03, -1.48e-02, -7.19e-02, -2.38e-01], [0.00e00, 1.76e00, 3.93e00, 5.99e00, 8.00e00]]] ), - (1, 1, 2, 5), + (1, 2, 5), ] TEST_CASE_6 = [ "mish", - torch.tensor([[[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]]], dtype=torch.float32), + torch.tensor([[[-10, -8, -6, -4, -2], [0, 2, 4, 6, 8]]], dtype=torch.float32), torch.tensor( - [[[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]]] + [[[-4.54e-04, -2.68e-03, -1.49e-02, -7.26e-02, -2.53e-01], [0.00e00, 1.94e00, 4.00e00, 6.00e00, 8.00e00]]] ), - (1, 1, 2, 5), + (1, 2, 5), ] @@ -70,8 +70,16 @@ class TestActivations(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_value_shape(self, input_param, img, out, expected_shape): result = Activations(**input_param)(img) - torch.testing.assert_allclose(result, out) - self.assertTupleEqual(result.shape, expected_shape) + + def _compare(ret, out, shape): + torch.testing.assert_allclose(ret, out) + self.assertTupleEqual(ret.shape, shape) + + if isinstance(result, (list, tuple)): + for r, e in zip(result, out): + _compare(r, e, expected_shape) + else: + _compare(result, out, expected_shape) @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_monai_activations_value_shape(self, input_param, img, out, expected_shape): diff --git a/tests/test_activationsd.py b/tests/test_activationsd.py index f186c17716..355c50f389 100644 --- a/tests/test_activationsd.py +++ b/tests/test_activationsd.py @@ -18,29 +18,29 @@ TEST_CASE_1 = [ {"keys": ["pred", "label"], "sigmoid": False, "softmax": [True, False], "other": None}, - {"pred": torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), "label": torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]])}, + {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, { - "pred": torch.tensor([[[[0.1192, 0.1192]], [[0.8808, 0.8808]]]]), - "label": torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), + "pred": torch.tensor([[[0.1192, 0.1192]], [[0.8808, 0.8808]]]), + "label": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), }, - (1, 2, 1, 2), + (2, 1, 2), ] TEST_CASE_2 = [ {"keys": ["pred", "label"], "sigmoid": False, "softmax": False, "other": [torch.tanh, None]}, - {"pred": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), "label": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])}, + {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, { - "pred": torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]]), - "label": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), + "pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]]), + "label": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), }, - (1, 1, 2, 2), + (1, 2, 2), ] TEST_CASE_3 = [ {"keys": "pred", "sigmoid": False, "softmax": False, "other": torch.tanh}, - {"pred": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])}, - {"pred": torch.tensor([[[[0.0000, 0.7616], [0.9640, 0.9951]]]])}, - (1, 
1, 2, 2), + {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]])}, + {"pred": torch.tensor([[[0.0000, 0.7616], [0.9640, 0.9951]]])}, + (1, 2, 2), ] diff --git a/tests/test_as_discrete.py b/tests/test_as_discrete.py index 7e3b586cc9..658a21efd6 100644 --- a/tests/test_as_discrete.py +++ b/tests/test_as_discrete.py @@ -18,23 +18,23 @@ TEST_CASE_1 = [ {"argmax": True, "to_onehot": False, "n_classes": None, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), - torch.tensor([[[[1.0, 1.0]]]]), - (1, 1, 1, 2), + torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), + torch.tensor([[[1.0, 1.0]]]), + (1, 1, 2), ] TEST_CASE_2 = [ {"argmax": True, "to_onehot": True, "n_classes": 2, "threshold_values": False, "logit_thresh": 0.5}, - torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), - torch.tensor([[[[0.0, 0.0]], [[1.0, 1.0]]]]), - (1, 2, 1, 2), + torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), + torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), + (2, 1, 2), ] TEST_CASE_3 = [ {"argmax": False, "to_onehot": False, "n_classes": None, "threshold_values": True, "logit_thresh": 0.6}, - torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), - torch.tensor([[[[0.0, 1.0], [1.0, 1.0]]]]), - (1, 1, 2, 2), + torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), + torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), + (1, 2, 2), ] diff --git a/tests/test_as_discreted.py b/tests/test_as_discreted.py index 0b4c483ac6..d6a6f3c2a4 100644 --- a/tests/test_as_discreted.py +++ b/tests/test_as_discreted.py @@ -25,9 +25,9 @@ "threshold_values": False, "logit_thresh": 0.5, }, - {"pred": torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]]), "label": torch.tensor([[[[0, 1]]]])}, - {"pred": torch.tensor([[[[0.0, 0.0]], [[1.0, 1.0]]]]), "label": torch.tensor([[[[1.0, 0.0]], [[0.0, 1.0]]]])}, - (1, 2, 1, 2), + {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]]), "label": torch.tensor([[[0, 1]]])}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]]), "label": torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])}, + (2, 1, 2), ] TEST_CASE_2 = [ @@ -39,9 +39,9 @@ "threshold_values": [True, False], "logit_thresh": 0.6, }, - {"pred": torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]]), "label": torch.tensor([[[[0, 1], [1, 1]]]])}, - {"pred": torch.tensor([[[[0.0, 1.0], [1.0, 1.0]]]]), "label": torch.tensor([[[[0.0, 1.0], [1.0, 1.0]]]])}, - (1, 1, 2, 2), + {"pred": torch.tensor([[[0.0, 1.0], [2.0, 3.0]]]), "label": torch.tensor([[[0, 1], [1, 1]]])}, + {"pred": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]]), "label": torch.tensor([[[0.0, 1.0], [1.0, 1.0]]])}, + (1, 2, 2), ] TEST_CASE_3 = [ @@ -53,9 +53,9 @@ "threshold_values": False, "logit_thresh": 0.5, }, - {"pred": torch.tensor([[[[0.0, 1.0]], [[2.0, 3.0]]]])}, - {"pred": torch.tensor([[[[0.0, 0.0]], [[1.0, 1.0]]]])}, - (1, 2, 1, 2), + {"pred": torch.tensor([[[0.0, 1.0]], [[2.0, 3.0]]])}, + {"pred": torch.tensor([[[0.0, 0.0]], [[1.0, 1.0]]])}, + (2, 1, 2), ] diff --git a/tests/test_compute_roc_auc.py b/tests/test_compute_roc_auc.py index f2eec69532..0030b44fbf 100644 --- a/tests/test_compute_roc_auc.py +++ b/tests/test_compute_roc_auc.py @@ -85,15 +85,19 @@ class TestComputeROCAUC(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_value(self, y_pred, y, softmax, to_onehot, average, expected_value): - y_pred = Activations(softmax=softmax)(y_pred) - y = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y) + act = Activations(softmax=softmax) + dis = AsDiscrete(to_onehot=to_onehot, n_classes=2) + y_pred = torch.stack([act(i) for i in y_pred], 
dim=0) + y = torch.stack([dis(i) for i in y], dim=0) result = compute_roc_auc(y_pred=y_pred, y=y, average=average) np.testing.assert_allclose(expected_value, result, rtol=1e-5) @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) def test_class_value(self, y_pred, y, softmax, to_onehot, average, expected_value): - y_pred = Activations(softmax=softmax)(y_pred) - y = AsDiscrete(to_onehot=to_onehot, n_classes=2)(y) + act = Activations(softmax=softmax) + dis = AsDiscrete(to_onehot=to_onehot, n_classes=2) + y_pred = torch.stack([act(i) for i in y_pred], dim=0) + y = torch.stack([dis(i) for i in y], dim=0) metric = ROCAUCMetric(average=average) metric(y_pred=y_pred, y=y) result = metric.aggregate() diff --git a/tests/test_decollate.py b/tests/test_decollate.py index 813c849fab..7d4532fbfd 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -35,7 +35,7 @@ ToTensor, ToTensord, ) -from monai.transforms.post.dictionary import Decollated +from monai.transforms.inverse_batch_transform import Decollated from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism from monai.utils.enums import InverseKeys diff --git a/tests/test_deepgrow_interaction.py b/tests/test_deepgrow_interaction.py index 3e9c9f285d..77c37bf5f3 100644 --- a/tests/test_deepgrow_interaction.py +++ b/tests/test_deepgrow_interaction.py @@ -59,7 +59,7 @@ def run_interaction(self, train, compose): engine.add_event_handler(IterationEvents.INNER_ITERATION_COMPLETED, add_one) engine.run() - self.assertIsNotNone(engine.state.batch.get("probability"), "Probability is missing") + self.assertIsNotNone(engine.state.batch[0].get("probability"), "Probability is missing") self.assertEqual(engine.state.best_metric, 9) def test_train_interaction(self): diff --git a/tests/test_ensemble_evaluator.py b/tests/test_ensemble_evaluator.py index 28a2d4f941..7f63cb6401 100644 --- a/tests/test_ensemble_evaluator.py +++ b/tests/test_ensemble_evaluator.py @@ -58,10 +58,10 @@ class CustomEvents(EventEnum): ) @val_engine.on(Events.ITERATION_COMPLETED) - def run_post_transform(engine): + def run_transform(engine): for i in range(5): expected_value = engine.state.iteration + i - torch.testing.assert_allclose(engine.state.output[f"pred{i}"], torch.tensor([[expected_value]])) + torch.testing.assert_allclose(engine.state.output[0][f"pred{i}"].item(), expected_value) @val_engine.on(Events.EPOCH_COMPLETED) def trigger_custom_event(): diff --git a/tests/test_handler_classification_saver.py b/tests/test_handler_classification_saver.py index 30c87df98d..87ce5ca3f8 100644 --- a/tests/test_handler_classification_saver.py +++ b/tests/test_handler_classification_saver.py @@ -18,6 +18,7 @@ import torch from ignite.engine import Engine +from monai.data import decollate_batch from monai.data.csv_saver import CSVSaver from monai.handlers import ClassificationSaver @@ -28,7 +29,8 @@ def test_saved_content(self): # set up engine def _train_func(engine, batch): - return torch.zeros(8) + engine.state.batch = decollate_batch(batch) + return [torch.zeros(1) for _ in range(8)] engine = Engine(_train_func) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index 359f55f3d8..70cc0ca42f 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -19,6 +19,7 @@ import torch.distributed as dist from ignite.engine import Engine 
+from monai.data import decollate_batch from monai.handlers import ClassificationSaver from tests.utils import DistCall, DistTestCase @@ -31,7 +32,8 @@ def test_saved_content(self): # set up engine def _train_func(engine, batch): - return torch.zeros(8 + rank * 2) + engine.state.batch = decollate_batch(batch) + return [torch.zeros(1) for _ in range(8 + rank * 2)] engine = Engine(_train_func) @@ -43,7 +45,7 @@ def _train_func(engine, batch): data = [ { "filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))], - "data_shape": [(1, 1) for _ in range(8 * rank, (8 + rank) * (rank + 1))], + "data_shape": torch.ones((8 + rank * 2, 1, 1)), } ] # rank 1 has more iterations @@ -51,7 +53,7 @@ def _train_func(engine, batch): data.append( { "filename_or_obj": ["testfile" + str(i) for i in range(18, 28)], - "data_shape": [(1, 1) for _ in range(18, 28)], + "data_shape": torch.ones((10, 1, 1)), } ) diff --git a/tests/test_handler_post_processing.py b/tests/test_handler_post_processing.py index 22e6fb0818..89638cdac9 100644 --- a/tests/test_handler_post_processing.py +++ b/tests/test_handler_post_processing.py @@ -20,7 +20,7 @@ # test lambda function as `transform` TEST_CASE_1 = [{"transform": lambda x: dict(pred=x["pred"] + 1.0)}, torch.tensor([[[[1.9975], [1.9997]]]])] -# test composed post transforms as `transform` +# test composed postprocessing transforms as `transform` TEST_CASE_2 = [ { "transform": Compose( @@ -38,24 +38,25 @@ class TestHandlerPostProcessing(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) def test_compute(self, input_params, expected): data = [ - {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": "test1"}, - {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": "test2"}, + {"image": torch.tensor([[[[2.0], [3.0]]]]), "filename": ["test1"]}, + {"image": torch.tensor([[[[6.0], [8.0]]]]), "filename": ["test2"]}, ] - # set up engine, PostProcessing handler works together with post_transform of engine + # set up engine, PostProcessing handler works together with postprocessing transforms of engine engine = SupervisedEvaluator( device=torch.device("cpu:0"), val_data_loader=data, epoch_length=2, network=torch.nn.PReLU(), - post_transform=Compose([Activationsd(keys="pred", sigmoid=True)]), + postprocessing=Compose([Activationsd(keys="pred", sigmoid=True)]), val_handlers=[PostProcessing(**input_params)], ) engine.run() - torch.testing.assert_allclose(engine.state.output["pred"], expected) - filename = engine.state.output.get("filename_bak") - if filename is not None: - self.assertEqual(filename, "test2") + for o, e in zip(engine.state.output, expected): + torch.testing.assert_allclose(o["pred"], e) + filename = o.get("filename_bak") + if filename is not None: + self.assertEqual(filename, "test2") if __name__ == "__main__": diff --git a/tests/test_handler_rocauc.py b/tests/test_handler_rocauc.py index 36bb499cba..46594eb629 100644 --- a/tests/test_handler_rocauc.py +++ b/tests/test_handler_rocauc.py @@ -24,18 +24,17 @@ def test_compute(self): act = Activations(softmax=True) to_onehot = AsDiscrete(to_onehot=True, n_classes=2) - y_pred = torch.Tensor([[0.1, 0.9], [0.3, 1.4]]) - y = torch.Tensor([[0], [1]]) - y_pred = act(y_pred) - y = to_onehot(y) + y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] auc_metric.update([y_pred, y]) - y_pred = torch.Tensor([[0.2, 0.1], [0.1, 0.5]]) - y = torch.Tensor([[0], [1]]) 
- y_pred = act(y_pred) - y = to_onehot(y) - # test a list of channel-first tensors - y_pred, y = list(y_pred), list(y) + y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])] + y = [torch.Tensor([0]), torch.Tensor([1])] + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] + auc_metric.update([y_pred, y]) auc = auc_metric.compute() diff --git a/tests/test_handler_rocauc_dist.py b/tests/test_handler_rocauc_dist.py index e768906158..e728c80be6 100644 --- a/tests/test_handler_rocauc_dist.py +++ b/tests/test_handler_rocauc_dist.py @@ -30,15 +30,19 @@ def test_compute(self): device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu" if dist.get_rank() == 0: - y_pred = torch.tensor([[0.1, 0.9], [0.3, 1.4]], device=device) - y = torch.tensor([[0], [1]], device=device) + y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device)] if dist.get_rank() == 1: - y_pred = torch.tensor([[0.2, 0.1], [0.1, 0.5], [0.3, 0.4]], device=device) - y = torch.tensor([[0], [1], [1]], device=device) - - y_pred = act(y_pred) - y = to_onehot(y) + y_pred = [ + torch.tensor([0.2, 0.1], device=device), + torch.tensor([0.1, 0.5], device=device), + torch.tensor([0.3, 0.4], device=device), + ] + y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)] + + y_pred = [act(p) for p in y_pred] + y = [to_onehot(y_) for y_ in y] auc_metric.update([y_pred, y]) result = auc_metric.compute() diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 5449530b50..78dea0a68b 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -18,6 +18,7 @@ from ignite.engine import Engine from parameterized import parameterized +from monai.data import decollate_batch from monai.handlers import SegmentationSaver TEST_CASE_0 = [".nii.gz"] @@ -32,7 +33,8 @@ def test_saved_content(self, output_ext): # set up engine def _train_func(engine, batch): - return torch.randint(0, 255, (8, 1, 2, 2)).float() + engine.state.batch = decollate_batch(batch) + return [torch.randint(0, 255, (1, 2, 2)).float() for _ in range(8)] engine = Engine(_train_func) @@ -43,7 +45,7 @@ def _train_func(engine, batch): data = [ { "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "patch_index": list(range(8)), + "patch_index": torch.tensor(list(range(8))), } ] engine.run(data, max_epochs=1) @@ -57,7 +59,8 @@ def test_save_resized_content(self, output_ext): # set up engine def _train_func(engine, batch): - return torch.randint(0, 255, (8, 1, 2, 2)).float() + engine.state.batch = decollate_batch(batch) + return [torch.randint(0, 255, (1, 2, 2)).float() for _ in range(8)] engine = Engine(_train_func) @@ -68,9 +71,9 @@ def _train_func(engine, batch): data = [ { "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + "spatial_shape": torch.tensor([[28, 28] for _ in range(8)]), + "affine": torch.tensor([np.diag(np.ones(4)) * 5 for _ in range(8)]), + "original_affine": torch.tensor([np.diag(np.ones(4)) * 1.0 for _ in range(8)]), } ] engine.run(data, max_epochs=1) diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py index 248be9f329..84cdef59a8 100644 --- a/tests/test_handler_stats.py +++ b/tests/test_handler_stats.py 
@@ -32,7 +32,7 @@ def test_metrics_print(self): # set up engine def _train_func(engine, batch): - return torch.tensor(0.0) + return [torch.tensor(0.0)] engine = Engine(_train_func) @@ -67,7 +67,7 @@ def test_loss_print(self): # set up engine def _train_func(engine, batch): - return torch.tensor(0.0) + return [torch.tensor(0.0)] engine = Engine(_train_func) @@ -96,7 +96,7 @@ def test_loss_dict(self): # set up engine def _train_func(engine, batch): - return torch.tensor(0.0) + return [torch.tensor(0.0)] engine = Engine(_train_func) @@ -129,7 +129,7 @@ def test_loss_file(self): # set up engine def _train_func(engine, batch): - return torch.tensor(0.0) + return [torch.tensor(0.0)] engine = Engine(_train_func) diff --git a/tests/test_handler_tb_image.py b/tests/test_handler_tb_image.py index f946fb6060..b5d963eedf 100644 --- a/tests/test_handler_tb_image.py +++ b/tests/test_handler_tb_image.py @@ -18,6 +18,7 @@ from ignite.engine import Engine, Events from parameterized import parameterized +from monai.data import decollate_batch from monai.handlers import TensorBoardImageHandler TEST_CASES = [[[20, 20]], [[2, 20, 20]], [[3, 20, 20]], [[20, 20, 20]], [[2, 20, 20, 20]], [[2, 2, 20, 20, 20]]] @@ -30,7 +31,8 @@ def test_tb_image_shape(self, shape): # set up engine def _train_func(engine, batch): - return torch.zeros((1, 1, 10, 10)) + engine.state.batch = decollate_batch(list(batch)) + return [torch.zeros((1, 10, 10))] engine = Engine(_train_func) @@ -38,7 +40,10 @@ def _train_func(engine, batch): stats_handler = TensorBoardImageHandler(log_dir=tempdir) engine.add_event_handler(Events.ITERATION_COMPLETED, stats_handler) - data = zip(np.random.normal(size=(10, 4, *shape)), np.random.normal(size=(10, 4, *shape))) + data = zip( + torch.as_tensor(np.random.normal(size=(10, 4, *shape))), + torch.as_tensor(np.random.normal(size=(10, 4, *shape))), + ) engine.run(data, epoch_length=10, max_epochs=1) stats_handler.close() diff --git a/tests/test_handler_tb_stats.py b/tests/test_handler_tb_stats.py index 0d8654cb09..1d722e7f66 100644 --- a/tests/test_handler_tb_stats.py +++ b/tests/test_handler_tb_stats.py @@ -25,7 +25,7 @@ def test_metrics_print(self): # set up engine def _train_func(engine, batch): - return batch + 1.0 + return [batch + 1.0] engine = Engine(_train_func) @@ -48,7 +48,7 @@ def test_metrics_writer(self): # set up engine def _train_func(engine, batch): - return batch + 1.0 + return [batch + 1.0] engine = Engine(_train_func) @@ -61,7 +61,7 @@ def _update_metric(engine): # set up testing handler writer = SummaryWriter(log_dir=tempdir) stats_handler = TensorBoardStatsHandler( - writer, output_transform=lambda x: {"loss": x * 2.0}, global_epoch_transform=lambda x: x * 3.0 + writer, output_transform=lambda x: {"loss": x[0] * 2.0}, global_epoch_transform=lambda x: x * 3.0 ) stats_handler.attach(engine) engine.run(range(3), max_epochs=2) diff --git a/tests/test_handler_transform_inverter.py b/tests/test_handler_transform_inverter.py index d4713072ff..f2e75a7153 100644 --- a/tests/test_handler_transform_inverter.py +++ b/tests/test_handler_transform_inverter.py @@ -16,7 +16,7 @@ import torch from ignite.engine import Engine -from monai.data import CacheDataset, DataLoader, create_test_image_3d, pad_list_data_collate +from monai.data import CacheDataset, DataLoader, create_test_image_3d, decollate_batch from monai.engines.utils import IterationEvents from monai.handlers import TransformInverter from monai.transforms import ( @@ -78,7 +78,7 @@ def test_invert(self): # set up engine def 
_train_func(engine, batch):
             self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
-            engine.state.output = batch
+            engine.state.output = engine.state.batch = decollate_batch(batch)
             engine.fire_event(IterationEvents.MODEL_COMPLETED)
             return engine.state.output
@@ -88,7 +88,6 @@ def _train_func(engine, batch):
         # set up testing handler
         TransformInverter(
             transform=transform,
-            loader=loader,
             output_keys=["image_inverted1", "label_inverted1"],
             batch_keys="label",
             meta_keys=["image_inverted1_meta_dict", "label_inverted1_meta_dict"],
@@ -102,7 +101,6 @@ def _train_func(engine, batch):
         # test different nearest interpolation values
         TransformInverter(
             transform=transform,
-            loader=loader,
             output_keys=["image_inverted2", "label_inverted2"],
             batch_keys="image",
             meta_keys=None,
@@ -110,27 +108,39 @@ def _train_func(engine, batch):
             meta_key_postfix="meta_dict",
             nearest_interp=[True, False],
             post_func=[lambda x: x + 10, lambda x: x],
-            collate_fn=pad_list_data_collate,
             num_workers=0 if sys.platform == "darwin" or torch.cuda.is_available() else 2,
         ).attach(engine)

         engine.run(loader, max_epochs=1)
         set_determinism(seed=None)
-        self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
-        self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))
-        # check the nearest inerpolation mode
-        for i in engine.state.output["image_inverted1"]:
+
+        for output in engine.state.output:
+            self.assertTupleEqual(output["image"].shape, (1, 100, 100, 100))
+            self.assertTupleEqual(output["label"].shape, (1, 100, 100, 100))
+            # check the nearest interpolation mode
+            i = output["image_inverted1"]
             torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
             self.assertTupleEqual(i.shape, (1, 100, 101, 107))
-        for i in engine.state.output["label_inverted1"]:
+            i = output["label_inverted1"]
             np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32))
             self.assertTupleEqual(i.shape, (1, 100, 101, 107))

+            # check the case that different items use different interpolation mode to invert transforms
+            d = output["image_inverted2"]
+            # if the interpolation mode is nearest, accumulated diff should be smaller than 1
+            self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
+            self.assertTupleEqual(d.shape, (1, 100, 101, 107))
+
+            d = output["label_inverted2"]
+            # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
+            self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
+            self.assertTupleEqual(d.shape, (1, 100, 101, 107))
+
         # check labels match
-        reverted = engine.state.output["label_inverted1"][-1].astype(np.int32)
+        reverted = engine.state.output[-1]["label_inverted1"].astype(np.int32)
         original = LoadImaged(KEYS)(data[-1])["label"]
         n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
-        reverted_name = engine.state.batch["label_inverted1_meta_dict"][-1]["filename_or_obj"]
+        reverted_name = engine.state.batch[-1]["label_inverted1_meta_dict"]["filename_or_obj"]
         original_name = data[-1]["label"]
         self.assertEqual(reverted_name, original_name)
         print("invert diff", reverted.size - n_good)
@@ -139,17 +149,6 @@ def _train_func(engine, batch):
         # 1824: torch 1.5.1
         self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff.
in 3 possible values") - # check the case that different items use different interpolation mode to invert transforms - d = engine.state.output["image_inverted2"] - # if the interpolation mode is nearest, accumulated diff should be smaller than 1 - self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0) - self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107)) - - d = engine.state.output["label_inverted2"] - # if the interpolation mode is not nearest, accumulated diff should be greater than 10000 - self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0) - self.assertTupleEqual(d.shape, (2, 1, 100, 101, 107)) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_integration_classification_2d.py b/tests/test_integration_classification_2d.py index f8f4ffdc89..fd13218795 100644 --- a/tests/test_integration_classification_2d.py +++ b/tests/test_integration_classification_2d.py @@ -128,8 +128,8 @@ def run_training_test(root_dir, train_x, train_y, val_x, val_y, device="cuda:0", acc_value = torch.eq(y_pred.argmax(dim=1), y) acc_metric = acc_value.sum().item() / len(acc_value) # compute AUC - y_pred = act(y_pred) - y = to_onehot(y) + y_pred = torch.stack([act(i) for i in y_pred]) + y = torch.stack([to_onehot(i) for i in y]) auc_metric = compute_roc_auc(y_pred, y) metric_values.append(auc_metric) if auc_metric > best_metric: diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index ce27649d54..cde6027003 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -277,7 +277,7 @@ def test_training(self): np.testing.assert_allclose(repeated[0], repeated[2]) np.testing.assert_allclose(repeated[0], repeated[3]) - @TimedCall(seconds=180, daemon=False) + @TimedCall(seconds=360, daemon=False) def test_timing(self): self.train_and_infer(idx=3) diff --git a/tests/test_integration_sliding_window.py b/tests/test_integration_sliding_window.py index a87c59bbab..b63f331ba6 100644 --- a/tests/test_integration_sliding_window.py +++ b/tests/test_integration_sliding_window.py @@ -75,7 +75,7 @@ def tearDown(self): if os.path.exists(self.seg_name): os.remove(self.seg_name) - @TimedCall(seconds=10) + @TimedCall(seconds=20) def test_training(self): set_determinism(seed=0) with tempfile.TemporaryDirectory() as tempdir: diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py index 2184c29b99..36681838d1 100644 --- a/tests/test_integration_workflows.py +++ b/tests/test_integration_workflows.py @@ -37,6 +37,7 @@ TensorBoardImageHandler, TensorBoardStatsHandler, ValidationHandler, + from_engine, ) from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.transforms import ( @@ -109,7 +110,7 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4): lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1) summary_writer = SummaryWriter(log_dir=root_dir) - val_post_transforms = Compose( + val_postprocessing = Compose( [ Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold_values=True), @@ -128,7 +129,7 @@ def _forward_completed(self, engine): StatsHandler(output_transform=lambda x: None), TensorBoardStatsHandler(summary_writer=summary_writer, output_transform=lambda x: None), TensorBoardImageHandler( - log_dir=root_dir, batch_transform=lambda x: (x["image"], x["label"]), output_transform=lambda x: x["pred"] + log_dir=root_dir, 
batch_transform=from_engine(["image", "label"]), output_transform=from_engine("pred") ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net}, save_key_metric=True), _TestEvalIterEvents(), @@ -139,16 +140,16 @@ def _forward_completed(self, engine): val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), - post_transform=val_post_transforms, + postprocessing=val_postprocessing, key_val_metric={ - "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) + "val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"])) }, - additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, val_handlers=val_handlers, amp=True if amp else False, ) - train_post_transforms = Compose( + train_postprocessing = Compose( [ Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold_values=True), @@ -178,9 +179,9 @@ def _model_completed(self, engine): train_handlers = [ LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True), ValidationHandler(validator=evaluator, interval=2, epoch_level=True), - StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]), + StatsHandler(tag_name="train_loss", output_transform=lambda x: x[0]["loss"]), TensorBoardStatsHandler( - summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x["loss"] + summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x[0]["loss"] ), CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True), _TestTrainIterEvents(), @@ -194,8 +195,8 @@ def _model_completed(self, engine): optimizer=opt, loss_function=loss, inferer=SimpleInferer(), - post_transform=train_post_transforms, - key_train_metric={"train_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + postprocessing=train_postprocessing, + key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, train_handlers=train_handlers, amp=True if amp else False, ) @@ -233,7 +234,7 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor num_res_units=2, ).to(device) - val_post_transforms = Compose( + val_postprocessing = Compose( [ Activationsd(keys="pred", sigmoid=True), AsDiscreted(keys="pred", threshold_values=True), @@ -244,7 +245,6 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor meta_keys="image_meta_dict", output_dir=root_dir, output_postfix="seg_transform", - save_batch=True, ), ] ) @@ -254,8 +254,8 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor SegmentationSaver( output_dir=root_dir, output_postfix="seg_handler", - batch_transform=lambda batch: batch["image_meta_dict"], - output_transform=lambda output: output["pred"], + batch_transform=from_engine("image_meta_dict"), + output_transform=from_engine("pred"), ), ] @@ -264,11 +264,11 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor val_data_loader=val_loader, network=net, inferer=SlidingWindowInferer(roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5), - post_transform=val_post_transforms, + postprocessing=val_postprocessing, key_val_metric={ - "val_mean_dice": MeanDice(include_background=True, output_transform=lambda x: (x["pred"], x["label"])) + 
"val_mean_dice": MeanDice(include_background=True, output_transform=from_engine(["pred", "label"])) }, - additional_metrics={"val_acc": Accuracy(output_transform=lambda x: (x["pred"], x["label"]))}, + additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))}, val_handlers=val_handlers, amp=True if amp else False, ) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 98caff35a1..613199862b 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -15,7 +15,7 @@ import numpy as np import torch -from monai.data import CacheDataset, DataLoader, create_test_image_3d +from monai.data import CacheDataset, DataLoader, create_test_image_3d, decollate_batch from monai.transforms import ( AddChanneld, CastToTyped, @@ -75,7 +75,6 @@ def test_invert(self): # `image` was not copied, invert the original value directly keys=["image", "label_inverted"], transform=transform, - loader=loader, orig_keys="label", meta_keys=["image_meta_dict", "label_inverted_meta_dict"], orig_meta_keys="label_meta_dict", @@ -87,16 +86,18 @@ def test_invert(self): # execute 1 epoch for d in loader: - d = inverter(d) - # this unit test only covers basic function, test_handler_transform_inverter covers more - self.assertTupleEqual(d["label"].shape[1:], (1, 100, 100, 100)) - # check the nearest inerpolation mode - for i in d["image"]: + d = decollate_batch(d) + for item in d: + item = inverter(item) + # this unit test only covers basic function, test_handler_transform_inverter covers more + self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100)) + # check the nearest inerpolation mode + i = item["image"] torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) - for i in d["label_inverted"]: + self.assertTupleEqual(i.shape[1:], (100, 101, 107)) + i = item["label_inverted"] np.testing.assert_allclose(i.astype(np.uint8).astype(np.float32), i.astype(np.float32)) - self.assertTupleEqual(i.shape, (1, 100, 101, 107)) + self.assertTupleEqual(i.shape[1:], (100, 101, 107)) set_determinism(seed=None) diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 773ca4ad0b..a8835329ba 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -16,57 +16,57 @@ from monai.transforms import KeepLargestConnectedComponent -grid_1 = torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]) -grid_2 = torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]) +grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) +grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) grid_3 = torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], ], [ - [ - 
[1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], - ] + ], +) +grid_4 = torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ], ) @@ -74,70 +74,70 @@ "value_1", {"independent": False, "applied_labels": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_2 = [ "value_2", {"independent": False, "applied_labels": [2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_3 = [ "independent_value_1_2", {"independent": True, "applied_labels": [1, 2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_4 = [ "dependent_value_1_2", {"independent": False, "applied_labels": [1, 2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_5 = [ "value_1", {"independent": True, "applied_labels": [1]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), ] TEST_CASE_6 = [ "independent_value_1_2", {"independent": True, "applied_labels": [1, 2]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), ] TEST_CASE_7 = [ "dependent_value_1_2", {"independent": False, "applied_labels": [1, 2]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), ] TEST_CASE_8 = [ "value_1_connect_1", {"independent": False, "applied_labels": [1], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 
1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_9 = [ "independent_value_1_2_connect_1", {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_10 = [ "dependent_value_1_2_connect_1", {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_11 = [ @@ -147,52 +147,27 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], - ] + ], ), ] @@ -203,52 +178,27 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 
0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], - ] + ], ), ] @@ -259,164 +209,89 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ], ), ] TEST_CASE_14 = [ "onehot_dependent_batch_2_apply_label_1_2_connect_2", {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, + grid_4, torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ], - ] + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ], ), ] TEST_CASE_15 = [ "onehot_dependent_batch_2_apply_label_1_2_connect_1", {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_3, + grid_4, torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 
1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + ], ), ] diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 7298b91e4f..9478cfb965 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -16,62 +16,60 @@ from monai.transforms import KeepLargestConnectedComponentd -grid_1 = { - "img": torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]) -} -grid_2 = { - "img": torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]) -} +grid_1 = {"img": torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])} +grid_2 = {"img": torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])} grid_3 = { "img": torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ], + ) +} +grid_4 = { + "img": torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], 
- [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + ], ) } @@ -79,70 +77,70 @@ "value_1", {"keys": ["img"], "independent": False, "applied_labels": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_2 = [ "value_2", {"keys": ["img"], "independent": False, "applied_labels": [2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_3 = [ "independent_value_1_2", {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_4 = [ "dependent_value_1_2", {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_5 = [ "value_1", {"keys": ["img"], "independent": True, "applied_labels": [1]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), ] TEST_CASE_6 = [ "independent_value_1_2", {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), ] TEST_CASE_7 = [ "dependent_value_1_2", {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, grid_2, - torch.tensor([[[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]]), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), ] TEST_CASE_8 = [ "value_1_connect_1", {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), ] TEST_CASE_9 = [ "independent_value_1_2_connect_1", {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_10 = [ "dependent_value_1_2_connect_1", {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, grid_1, - torch.tensor([[[[0, 0, 1, 0, 0], [0, 2, 1, 1, 
1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]]), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] TEST_CASE_11 = [ @@ -152,52 +150,27 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], ], - ] + ], ), ] @@ -208,52 +181,27 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ], ), ] @@ -264,164 +212,89 @@ torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + [1.0, 
1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], ], - ] + ], ), ] TEST_CASE_14 = [ "onehot_dependent_batch_2_apply_label_1_2_connect_2", {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, + grid_4, torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + ], ), ] TEST_CASE_15 = [ "onehot_dependent_batch_2_apply_label_1_2_connect_1", {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_3, + grid_4, torch.tensor( [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 
0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], ], - ] + ], ), ] diff --git a/tests/test_label_to_contour.py b/tests/test_label_to_contour.py index b118b91999..8f8f3cc054 100644 --- a/tests/test_label_to_contour.py +++ b/tests/test_label_to_contour.py @@ -147,30 +147,30 @@ def test_contour(self): # check 5-dim input data test_cube, expected_output = gen_fixed_cube() - test_result_cube = LabelToContour(**input_param)(test_cube) - self.assertEqual(test_result_cube.shape, test_cube.shape) + for cube in test_cube: + test_result_cube = LabelToContour(**input_param)(cube) + self.assertEqual(test_result_cube.shape, cube.shape) - test_result_np = test_result_cube.data.cpu().numpy() - batch_size, channels = test_cube.shape[0], test_cube.shape[1] - for batch in range(batch_size): + test_result_np = test_result_cube.cpu().numpy() + channels = cube.shape[0] for channel in range(channels): - np.testing.assert_allclose(test_result_np[batch, channel, ...], expected_output) + np.testing.assert_allclose(test_result_np[channel, ...], expected_output) # check 4-dim input data test_img, expected_output = gen_fixed_img() - batch_size, channels = test_img.shape[0], test_img.shape[1] - test_result_img = LabelToContour(**input_param)(test_img) - self.assertEqual(test_result_img.shape, test_img.shape) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContour(**input_param)(img) + self.assertEqual(test_result_img.shape, img.shape) - test_result_np = test_result_img.data.cpu().numpy() - for batch in range(batch_size): + test_result_np = test_result_img.cpu().numpy() for channel in range(channels): - np.testing.assert_allclose(test_result_img[batch, channel, ...], expected_output) + np.testing.assert_allclose(test_result_np[channel, ...], expected_output) # check invalid input data - error_input = torch.rand(1, 2, 3) + error_input = torch.rand(1, 2) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) - error_input = torch.rand(1, 2, 3, 4, 5, 6) + error_input = torch.rand(1, 2, 3, 4, 5) self.assertRaises(ValueError, LabelToContour(**input_param), error_input) diff --git a/tests/test_label_to_contourd.py b/tests/test_label_to_contourd.py index aa4dffe03e..d3795755c7 100644 --- a/tests/test_label_to_contourd.py +++ b/tests/test_label_to_contourd.py @@ -147,30 +147,30 @@ def test_contour(self): # check 5-dim input data test_cube, expected_output = gen_fixed_cube() - test_result_cube = LabelToContourd(**input_param)({"img": test_cube}) - self.assertEqual(test_result_cube["img"].shape, test_cube.shape) + for cube in test_cube: + test_result_cube = LabelToContourd(**input_param)({"img": cube}) + self.assertEqual(test_result_cube["img"].shape, cube.shape) - test_result_np = test_result_cube["img"].data.cpu().numpy() - batch_size, channels = test_cube.shape[0], test_cube.shape[1] - for batch in range(batch_size): + test_result_np = test_result_cube["img"].cpu().numpy() + channels = cube.shape[0] for channel in range(channels): - np.testing.assert_allclose(test_result_np[batch, channel, ...], expected_output) + 
np.testing.assert_allclose(test_result_np[channel, ...], expected_output) # check 4-dim input data test_img, expected_output = gen_fixed_img() - batch_size, channels = test_img.shape[0], test_img.shape[1] - test_result_img = LabelToContourd(**input_param)({"img": test_img}) - self.assertEqual(test_result_img["img"].shape, test_img.shape) + for img in test_img: + channels = img.shape[0] + test_result_img = LabelToContourd(**input_param)({"img": img}) + self.assertEqual(test_result_img["img"].shape, img.shape) - test_result_np = test_result_img["img"].data.cpu().numpy() - for batch in range(batch_size): + test_result_np = test_result_img["img"].cpu().numpy() for channel in range(channels): - np.testing.assert_allclose(test_result_img["img"][batch, channel, ...], expected_output) + np.testing.assert_allclose(test_result_np[channel, ...], expected_output) # check invalid input data - error_input = {"img": torch.rand(1, 2, 3)} + error_input = {"img": torch.rand(1, 2)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) - error_input = {"img": torch.rand(1, 2, 3, 4, 5, 6)} + error_input = {"img": torch.rand(1, 2, 3, 4, 5)} self.assertRaises(ValueError, LabelToContourd(**input_param), error_input) diff --git a/tests/test_mean_ensemble.py b/tests/test_mean_ensemble.py index 32a6856263..7e08846beb 100644 --- a/tests/test_mean_ensemble.py +++ b/tests/test_mean_ensemble.py @@ -19,32 +19,32 @@ TEST_CASE_1 = [ {"weights": None}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) + 1, + [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], + torch.ones(2, 2, 2) + 1, ] TEST_CASE_2 = [ {"weights": None}, - torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2]), - torch.ones(2, 2, 2, 2) + 1, + torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2]), + torch.ones(2, 2, 2) + 1, ] TEST_CASE_3 = [ {"weights": [1, 3]}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) * 2.5, + [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], + torch.ones(2, 2, 2) * 2.5, ] TEST_CASE_4 = [ - {"weights": [[[1, 3]], [[3, 1]]]}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + {"weights": [[1, 3], [3, 1]]}, + [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], + torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), ] TEST_CASE_5 = [ - {"weights": np.array([[[1, 3]], [[3, 1]]])}, - [torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2], - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + {"weights": np.array([[1, 3], [3, 1]])}, + [torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2], + torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), ] TEST_CASE_6 = [ diff --git a/tests/test_mean_ensembled.py b/tests/test_mean_ensembled.py index c7549e5aa4..ea77ef18a0 100644 --- a/tests/test_mean_ensembled.py +++ b/tests/test_mean_ensembled.py @@ -19,14 +19,14 @@ TEST_CASE_1 = [ {"keys": ["pred0", "pred1"], "output_key": "output", "weights": None}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) + 1, + {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, + torch.ones(2, 2, 2) + 1, ] TEST_CASE_2 = [ {"keys": "output", "weights": None}, - {"output": torch.stack([torch.ones(2, 2, 2, 2), torch.ones(2, 2, 2, 2) + 2])}, - torch.ones(2, 2, 2, 2) + 1, + {"output": torch.stack([torch.ones(2, 2, 2), torch.ones(2, 2, 2) + 2])}, + torch.ones(2, 2, 2) + 1, ] TEST_CASE_3 = [ @@ -36,9 +36,9 @@ ] 
TEST_CASE_4 = [ - {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[[1, 3]], [[3, 1]]]}, - {"pred0": torch.ones(2, 2, 2, 2), "pred1": torch.ones(2, 2, 2, 2) + 2}, - torch.ones(2, 2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(1, 2, 1, 1), + {"keys": ["pred0", "pred1"], "output_key": "output", "weights": [[1, 3], [3, 1]]}, + {"pred0": torch.ones(2, 2, 2), "pred1": torch.ones(2, 2, 2) + 2}, + torch.ones(2, 2, 2) * torch.tensor([2.5, 1.5]).reshape(2, 1, 1), ] TEST_CASE_5 = [ diff --git a/tests/test_save_classificationd.py b/tests/test_save_classificationd.py index de2b5fbb30..67dc0320a6 100644 --- a/tests/test_save_classificationd.py +++ b/tests/test_save_classificationd.py @@ -17,7 +17,7 @@ import numpy as np import torch -from monai.data import CSVSaver +from monai.data import CSVSaver, decollate_batch from monai.transforms import Compose, CopyItemsd, SaveClassificationd @@ -58,20 +58,27 @@ def test_saved_content(self): ] ) # simulate inference 2 iterations - post_trans(data[0]) - post_trans(data[1]) + d = decollate_batch(data[0]) + for i in d: + post_trans(i) + d = decollate_batch(data[1]) + for i in d: + post_trans(i) # write into CSV file saver.finalize() # 3rd saver will not delete previous data due to `overwrite=False` - SaveClassificationd( + trans2 = SaveClassificationd( keys="pred", saver=None, meta_keys="image_meta_dict", # specify meta key, so no need to copy anymore output_dir=tempdir, filename="predictions1.csv", overwrite=False, - )(data[2]) + ) + d = decollate_batch(data[2]) + for i in d: + trans2(i) def _test_file(filename, count): filepath = os.path.join(tempdir, filename) diff --git a/tests/test_save_image.py b/tests/test_save_image.py index a279d3a4ec..b50cb083ba 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -13,127 +13,42 @@ import tempfile import unittest -import numpy as np import torch from parameterized import parameterized from monai.transforms import SaveImage -TEST_CASE_0 = [ - torch.randint(0, 255, (8, 1, 2, 3, 4)), - {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, - ".nii.gz", - False, - True, -] - TEST_CASE_1 = [ - torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), - {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, - ".png", - False, - True, -] - -TEST_CASE_2 = [ - np.random.randint(0, 255, (8, 1, 2, 3, 4)), - {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, - ".nii.gz", - False, - True, -] - -TEST_CASE_3 = [ - torch.randint(0, 255, (8, 1, 2, 2)), - { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - }, - ".nii.gz", - True, - True, -] - -TEST_CASE_4 = [ - torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), - { - "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - }, - ".png", - True, - True, -] - -TEST_CASE_5 = [ torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nii.gz"}, ".nii.gz", False, - False, ] -TEST_CASE_6 = [ +TEST_CASE_2 = [ torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False, - False, -] - -TEST_CASE_7 = [ - [torch.randint(0, 255, (1, 2, 3, 4)), torch.randint(0, 255, (1, 2, 3, 4))], - [{"filename_or_obj": "testfile0.nii.gz"}, {"filename_or_obj": "testfile1.nii.gz"}], - ".nii.gz", - False, - False, -] - -TEST_CASE_8 = [ - [torch.randint(0, 255, (1, 2, 3, 4))], - {"filename_or_obj": 
["testfile0.nii.gz"]}, - ".nii.gz", - False, - False, ] class TestSaveImage(unittest.TestCase): - @parameterized.expand( - [ - TEST_CASE_0, - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - ] - ) - def test_saved_content(self, test_data, meta_data, output_ext, resample, save_batch): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( output_dir=tempdir, output_ext=output_ext, resample=resample, - save_batch=save_batch, ) trans(test_data, meta_data) - if save_batch: - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + if meta_data is not None: + filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) else: - if meta_data is not None: - filepath = os.path.join("testfile0", "testfile0" + "_trans" + output_ext) - else: - filepath = os.path.join("0", "0" + "_trans" + output_ext) - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + filepath = os.path.join("0", "0" + "_trans" + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) if __name__ == "__main__": diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index ed240e1113..35bbea9628 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -13,81 +13,21 @@ import tempfile import unittest -import numpy as np import torch from parameterized import parameterized from monai.transforms import SaveImaged -TEST_CASE_0 = [ - { - "img": torch.randint(0, 255, (8, 1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, - }, - ".nii.gz", - False, - True, -] - TEST_CASE_1 = [ - { - "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), - "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)]}, - }, - ".png", - False, - True, -] - -TEST_CASE_2 = [ - { - "img": np.random.randint(0, 255, (8, 1, 2, 3, 4)), - "img_meta_dict": {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}, - }, - ".nii.gz", - False, - True, -] - -TEST_CASE_3 = [ - { - "img": torch.randint(0, 255, (8, 1, 2, 2)), - "img_meta_dict": { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - }, - }, - ".nii.gz", - True, - True, -] - -TEST_CASE_4 = [ - { - "img": torch.randint(0, 255, (8, 1, 2, 3), dtype=torch.uint8), - "img_meta_dict": { - "filename_or_obj": ["testfile" + str(i) + ".png" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - }, - }, - ".png", - True, - True, -] - -TEST_CASE_5 = [ { "img": torch.randint(0, 255, (1, 2, 3, 4)), "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, }, ".nii.gz", False, - False, ] -TEST_CASE_6 = [ +TEST_CASE_2 = [ { "img": torch.randint(0, 255, (1, 2, 3, 4)), "img_meta_dict": {"filename_or_obj": "testfile0.nii.gz"}, @@ -95,28 +35,12 @@ }, ".nii.gz", False, - False, -] - -TEST_CASE_7 = [ - { - "pred": torch.randint(0, 255, (8, 1, 2, 2)), - "img_meta_dict": { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(28, 28)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, 
- }, - }, - ".nii.gz", - True, - True, ] class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) - def test_saved_content(self, test_data, output_ext, resample, save_batch): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged( keys=["img", "pred"], @@ -124,20 +48,14 @@ def test_saved_content(self, test_data, output_ext, resample, save_batch): output_dir=tempdir, output_ext=output_ext, resample=resample, - save_batch=save_batch, allow_missing_keys=True, ) trans(test_data) - if save_batch: - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_trans" + output_ext) - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - else: - patch_index = test_data["img_meta_dict"].get("patch_index", None) - patch_index = f"_{patch_index}" if patch_index is not None else "" - filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext) - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + patch_index = test_data["img_meta_dict"].get("patch_index", None) + patch_index = f"_{patch_index}" if patch_index is not None else "" + filepath = os.path.join("testfile0", "testfile0" + "_trans" + patch_index + output_ext) + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) if __name__ == "__main__": diff --git a/tests/test_split_channel.py b/tests/test_split_channel.py index 8eec3c4e70..91e93aedcc 100644 --- a/tests/test_split_channel.py +++ b/tests/test_split_channel.py @@ -17,21 +17,17 @@ from monai.transforms import SplitChannel -TEST_CASE_1 = [{"channel_dim": None}, torch.randint(0, 2, size=(4, 3, 3, 4)), (4, 1, 3, 4)] +TEST_CASE_1 = [{"channel_dim": 1}, torch.randint(0, 2, size=(4, 3, 3, 4)), (4, 1, 3, 4)] -TEST_CASE_2 = [{"channel_dim": 1}, torch.randint(0, 2, size=(4, 3, 3, 4)), (4, 1, 3, 4)] +TEST_CASE_2 = [{"channel_dim": 0}, np.random.randint(2, size=(3, 3, 4)), (1, 3, 4)] -TEST_CASE_3 = [{"channel_dim": None}, np.random.randint(2, size=(3, 3, 4)), (1, 3, 4)] +TEST_CASE_3 = [{"channel_dim": 2}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] -TEST_CASE_4 = [{"channel_dim": 0}, np.random.randint(2, size=(3, 3, 4)), (1, 3, 4)] - -TEST_CASE_5 = [{"channel_dim": 2}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] - -TEST_CASE_6 = [{"channel_dim": -1}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] +TEST_CASE_4 = [{"channel_dim": -1}, np.random.randint(2, size=(3, 2, 4)), (3, 2, 1)] class TestSplitChannel(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, test_data, expected_shape): result = SplitChannel(**input_param)(test_data) for data in result: diff --git a/tests/test_split_channeld.py b/tests/test_split_channeld.py index 814ef69922..57c7099b9f 100644 --- a/tests/test_split_channeld.py +++ b/tests/test_split_channeld.py @@ -18,42 +18,30 @@ from monai.transforms import SplitChanneld TEST_CASE_1 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": None}, - {"pred": torch.randint(0, 2, size=(4, 3, 3, 4))}, - (4, 1, 3, 4), -] - -TEST_CASE_2 = [ {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 1}, {"pred": torch.randint(0, 2, size=(4, 
3, 3, 4))}, (4, 1, 3, 4), ] -TEST_CASE_3 = [ - {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": None}, - {"pred": np.random.randint(2, size=(3, 3, 4))}, - (1, 3, 4), -] - -TEST_CASE_4 = [ +TEST_CASE_2 = [ {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3"], "channel_dim": 0}, {"pred": np.random.randint(2, size=(3, 3, 4))}, (1, 3, 4), ] -TEST_CASE_5 = [ +TEST_CASE_3 = [ {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": 2}, {"pred": np.random.randint(2, size=(3, 2, 4))}, (3, 2, 1), ] -TEST_CASE_6 = [ +TEST_CASE_4 = [ {"keys": "pred", "output_postfixes": ["cls1", "cls2", "cls3", "cls4"], "channel_dim": -1}, {"pred": np.random.randint(2, size=(3, 2, 4))}, (3, 2, 1), ] -TEST_CASE_7 = [ +TEST_CASE_5 = [ {"keys": "pred", "channel_dim": 1}, {"pred": np.random.randint(2, size=(3, 2, 4))}, (3, 1, 4), @@ -61,7 +49,7 @@ class TestSplitChanneld(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, input_param, test_data, expected_shape): result = SplitChanneld(**input_param)(test_data) for k, v in result.items(): diff --git a/tests/test_vote_ensemble.py b/tests/test_vote_ensemble.py index 92039fe103..74c19d5f48 100644 --- a/tests/test_vote_ensemble.py +++ b/tests/test_vote_ensemble.py @@ -16,11 +16,11 @@ from monai.transforms import VoteEnsemble -# shape: [1, 2, 1, 1] +# shape: [2, 1, 1] TEST_CASE_1 = [ {"num_classes": None}, - [torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[1]], [[0]]]]), torch.tensor([[[[0]], [[1]]]])], - torch.tensor([[[[1.0]], [[0.0]]]]), + [torch.tensor([[[1]], [[0]]]), torch.tensor([[[1]], [[0]]]), torch.tensor([[[0]], [[1]]])], + torch.tensor([[[1.0]], [[0.0]]]), ] # shape: [1, 2, 1, 1] @@ -30,30 +30,37 @@ torch.tensor([[[[1.0]], [[0.0]]]]), ] -# shape: [1, 1, 2, 1] +# shape: [1, 2, 1] TEST_CASE_3 = [ {"num_classes": 3}, - [torch.tensor([[[[0], [2]]]]), torch.tensor([[[[0], [2]]]]), torch.tensor([[[[1], [1]]]])], - torch.tensor([[[[0], [2]]]]), + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), ] -# shape: [1, 1, 2, 1] +# shape: [1, 2, 1] TEST_CASE_4 = [ {"num_classes": 5}, - [torch.tensor([[[[0], [2]]]]), torch.tensor([[[[0], [2]]]]), torch.tensor([[[[1], [1]]]])], - torch.tensor([[[[0], [2]]]]), + [torch.tensor([[[0], [2]]]), torch.tensor([[[0], [2]]]), torch.tensor([[[1], [1]]])], + torch.tensor([[[0], [2]]]), ] -# shape: [2] +# shape: [1] TEST_CASE_5 = [ {"num_classes": 3}, - [torch.tensor([0, 2]), torch.tensor([0, 2]), torch.tensor([1, 1])], - torch.tensor([0, 2]), + [torch.tensor([2]), torch.tensor([2]), torch.tensor([1])], + torch.tensor([2]), +] + +# shape: 1 +TEST_CASE_6 = [ + {"num_classes": 3}, + [torch.tensor(2), torch.tensor(2), torch.tensor(1)], + torch.tensor(2), ] class TestVoteEnsemble(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6]) def test_value(self, input_param, img, expected_value): result = VoteEnsemble(**input_param)(img) torch.testing.assert_allclose(result, expected_value) diff --git a/tests/test_vote_ensembled.py b/tests/test_vote_ensembled.py index f4b93c7887..e94213733f 100644 --- a/tests/test_vote_ensembled.py +++ b/tests/test_vote_ensembled.py @@ 
-38,33 +38,33 @@ torch.tensor([[[[1.0]], [[0.0]]]]), ] -# shape: [1, 1, 2, 1] +# shape: [1, 2, 1] TEST_CASE_3 = [ {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, { - "pred0": torch.tensor([[[[0], [2]]]]), - "pred1": torch.tensor([[[[0], [2]]]]), - "pred2": torch.tensor([[[[1], [1]]]]), + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), }, - torch.tensor([[[[0], [2]]]]), + torch.tensor([[[0], [2]]]), ] -# shape: [1, 1, 2, 1] +# shape: [1, 2, 1] TEST_CASE_4 = [ {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 5}, { - "pred0": torch.tensor([[[[0], [2]]]]), - "pred1": torch.tensor([[[[0], [2]]]]), - "pred2": torch.tensor([[[[1], [1]]]]), + "pred0": torch.tensor([[[0], [2]]]), + "pred1": torch.tensor([[[0], [2]]]), + "pred2": torch.tensor([[[1], [1]]]), }, - torch.tensor([[[[0], [2]]]]), + torch.tensor([[[0], [2]]]), ] -# shape: [2] +# shape: [1] TEST_CASE_5 = [ {"keys": ["pred0", "pred1", "pred2"], "output_key": "output", "num_classes": 3}, - {"pred0": torch.tensor([0, 2]), "pred1": torch.tensor([0, 2]), "pred2": torch.tensor([1, 1])}, - torch.tensor([0, 2]), + {"pred0": torch.tensor([2]), "pred1": torch.tensor([2]), "pred2": torch.tensor([1])}, + torch.tensor([2]), ]
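For reference, a minimal, self-contained sketch (not part of this change set) of the decollate-first pattern that the updated tests above exercise: the collated network output is split into a list of channel-first items without a batch dimension, and postprocessing runs on each item. The keys, shapes, and batch size below are illustrative assumptions; the transform arguments match those used in the diffs.

import torch

from monai.data import decollate_batch
from monai.transforms import Activationsd, AsDiscreted, Compose

# per-item postprocessing, mirroring the `postprocessing` chains configured above
postprocessing = Compose(
    [
        Activationsd(keys="pred", sigmoid=True),
        AsDiscreted(keys="pred", threshold_values=True),
    ]
)

# a collated batch of 4 items: batch dimension first, then channel
batch = {"pred": torch.randn(4, 1, 8, 8), "label": torch.randint(0, 2, (4, 1, 8, 8))}

# decollate into a list of 4 dicts, each holding tensors shaped (1, 8, 8)
items = decollate_batch(batch)
items = [postprocessing(item) for item in items]

Because `engine.state.output` becomes such a list, metric and handler inputs are gathered with `from_engine(["pred", "label"])` rather than lambdas over a batched dict, scalar fields such as the loss are read from the first item (`lambda x: x[0]["loss"]`), and the savers no longer need a `save_batch` mode or batched test cases, as the removals above reflect.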