diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 516e2d4743..e13848f8fc 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -75,7 +75,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -123,7 +123,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -228,7 +228,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml index 70c3153076..d8623c8087 100644 --- a/.github/workflows/pythonapp-gpu.yml +++ b/.github/workflows/pythonapp-gpu.yml @@ -137,6 +137,6 @@ jobs: shell: bash - name: Upload coverage if: ${{ github.head_ref != 'dev' && github.event.pull_request.merged != true }} - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: files: ./coverage.xml diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml index 7e01f55cd9..d9ce9976b8 100644 --- a/.github/workflows/setupapp.yml +++ b/.github/workflows/setupapp.yml @@ -72,7 +72,7 @@ jobs: if pgrep python; then pkill python; fi shell: bash - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml @@ -119,7 +119,7 @@ jobs: BUILD_MONAI=1 ./runtests.sh --build --quick --min coverage xml --ignore-errors - name: Upload coverage - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 with: fail_ci_if_error: false files: ./coverage.xml diff --git a/.gitignore b/.gitignore index 437677d2bb..76c6ab0d12 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd # clang format tool .clang-format-bin/ +# ctags +tags + # VSCode .vscode/ *.zip diff --git a/docs/source/networks.rst b/docs/source/networks.rst index 64a3a4c9d1..e2e509a99b 100644 --- a/docs/source/networks.rst +++ b/docs/source/networks.rst @@ -630,6 +630,11 @@ Nets .. autoclass:: ViTAutoEnc :members: +`MaskedAutoEncoderViT` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: MaskedAutoEncoderViT + :members: + `FullyConnectedNet` ~~~~~~~~~~~~~~~~~~~ .. autoclass:: FullyConnectedNet diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 41bb4ae79a..d2585daf63 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1180,6 +1180,18 @@ Utility :members: :special-members: __call__ +`TorchIO` +""""""""" +.. autoclass:: TorchIO + :members: + :special-members: __call__ + +`RandTorchIO` +""""""""""""" +.. autoclass:: RandTorchIO + :members: + :special-members: __call__ + `MapLabelValue` """"""""""""""" .. autoclass:: MapLabelValue @@ -2253,6 +2265,18 @@ Utility (Dict) :members: :special-members: __call__ +`TorchIOd` +"""""""""" +.. autoclass:: TorchIOd + :members: + :special-members: __call__ + +`RandTorchIOd` +"""""""""""""" +.. autoclass:: RandTorchIOd + :members: + :special-members: __call__ + `MapLabelValued` """""""""""""""" .. autoclass:: MapLabelValued diff --git a/environment-dev.yml b/environment-dev.yml index a4651ec7e4..4a1723e8a5 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,6 +7,7 @@ channels: dependencies: - numpy>=1.24,<2.0 - pytorch>=1.9 + - torchio - torchvision - pytorch-cuda>=11.6 - pip diff --git a/monai/bundle/__init__.py b/monai/bundle/__init__.py index a4a2176f14..3f3c8d545e 100644 --- a/monai/bundle/__init__.py +++ b/monai/bundle/__init__.py @@ -43,4 +43,4 @@ MACRO_KEY, load_bundle_config, ) -from .workflows import BundleWorkflow, ConfigWorkflow +from .workflows import BundleWorkflow, ConfigWorkflow, PythonicWorkflow diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index df69b021e1..b55c62174b 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str """ return self._resolve_one_item(id=id, **kwargs) + def remove_resolved_content(self, id: str) -> Any | None: + """ + Remove the resolved ``ConfigItem`` by id. + + Args: + id: id name of the expected item. + + """ + return self.resolved_content.pop(id) if id in self.resolved_content else None + @classmethod def normalize_id(cls, id: str | int) -> str: """ diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py index 3ecd5dfbc5..75cf7b0b09 100644 --- a/monai/bundle/workflows.py +++ b/monai/bundle/workflows.py @@ -44,12 +44,18 @@ class BundleWorkflow(ABC): workflow_type: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. - default to `train` for train workflow. + default to `None` for only using meta properties. workflow: specifies the workflow type: "train" or "training" for a training workflow, or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. logging_file: config file for `logging` module in the program. for more details: https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. @@ -97,29 +103,50 @@ def __init__( meta_file = None workflow_type = workflow if workflow is not None else workflow_type - if workflow_type is None and properties_path is None: - self.properties = copy(MetaProperties) - self.workflow_type = None - self.meta_file = meta_file - return + if workflow_type is not None: + if workflow_type.lower() in self.supported_train_type: + workflow_type = "train" + elif workflow_type.lower() in self.supported_infer_type: + workflow_type = "infer" + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if properties_path is not None: properties_path = Path(properties_path) if not properties_path.is_file(): raise ValueError(f"Property file {properties_path} does not exist.") with open(properties_path) as json_file: - self.properties = json.load(json_file) - self.workflow_type = None - self.meta_file = meta_file - return - if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr] - self.properties = {**TrainProperties, **MetaProperties} - self.workflow_type = "train" - elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr] - self.properties = {**InferProperties, **MetaProperties} - self.workflow_type = "infer" + try: + properties = json.load(json_file) + self.properties: dict = {} + if workflow_type is not None and workflow_type in properties: + self.properties = properties[workflow_type] + if "meta" in properties: + self.properties.update(properties["meta"]) + elif workflow_type is None: + if "meta" in properties: + self.properties = properties["meta"] + logger.info( + "No workflow type specified, default to load meta properties from property file." + ) + else: + logger.warning("No 'meta' key found in properties while workflow_type is None.") + except KeyError as e: + raise ValueError(f"{workflow_type} not found in property file {properties_path}") from e + except json.JSONDecodeError as e: + raise ValueError(f"Error decoding JSON from property file {properties_path}") from e else: - raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + if workflow_type == "train": + self.properties = {**TrainProperties, **MetaProperties} + elif workflow_type == "infer": + self.properties = {**InferProperties, **MetaProperties} + elif workflow_type is None: + self.properties = copy(MetaProperties) + logger.info("No workflow type and property file specified, default to 'meta' properties.") + else: + raise ValueError(f"Unsupported workflow type: '{workflow_type}'.") + self.workflow_type = workflow_type self.meta_file = meta_file @abstractmethod @@ -226,6 +253,124 @@ def check_properties(self) -> list[str] | None: return [n for n, p in self.properties.items() if p.get(BundleProperty.REQUIRED, False) and not hasattr(self, n)] +class PythonicWorkflow(BundleWorkflow): + """ + Base class for the pythonic workflow specification in bundle, it can be a training, evaluation or inference workflow. + It defines the basic interfaces for the bundle workflow behavior: `initialize`, `finalize`, etc. + This also provides the interface to get / set public properties to interact with a bundle workflow through + defined `get_` accessor methods or directly defining members of the object. + For how to set the properties, users can define the `_set_` methods or directly set the members of the object. + The `initialize` method is called to set up the workflow before running. This method sets up internal state + and prepares properties. If properties are modified after the workflow has been initialized, `self._is_initialized` + is set to `False`. Before running the workflow again, `initialize` should be called to ensure that the workflow is + properly set up with the new property values. + + Args: + workflow_type: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for only using meta properties. + workflow: specifies the workflow type: "train" or "training" for a training workflow, + or "infer", "inference", "eval", "evaluation" for a inference workflow, + other unsupported string will raise a ValueError. + default to `None` for common workflow. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "meta". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. + config_file: path to the config file, typically used to store hyperparameters. + meta_file: filepath of the metadata file, if this is a list of file paths, their contents will be merged in order. + logging_file: config file for `logging` module in the program. for more details: + https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig. + + """ + + supported_train_type: tuple = ("train", "training") + supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation") + + def __init__( + self, + workflow_type: str | None = None, + properties_path: PathLike | None = None, + config_file: str | Sequence[str] | None = None, + meta_file: str | Sequence[str] | None = None, + logging_file: str | None = None, + **override: Any, + ): + meta_file = str(Path(os.getcwd()) / "metadata.json") if meta_file is None else meta_file + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, meta_file=meta_file, logging_file=logging_file + ) + self._props_vals: dict = {} + self._set_props_vals: dict = {} + self.parser = ConfigParser() + if config_file is not None: + self.parser.read_config(f=config_file) + if self.meta_file is not None: + self.parser.read_meta(f=self.meta_file) + + # the rest key-values in the _args are to override config content + self.parser.update(pairs=override) + self._is_initialized: bool = False + + def initialize(self, *args: Any, **kwargs: Any) -> Any: + """ + Initialize the bundle workflow before running. + """ + self._props_vals = {} + self._is_initialized = True + + def _get_property(self, name: str, property: dict) -> Any: + """ + With specified property name and information, get the expected property value. + If the property is already generated, return from the bucket directly. + If user explicitly set the property, return it directly. + Otherwise, generate the expected property as a class private property with prefix "_". + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + """ + if not self._is_initialized: + raise RuntimeError("Please execute 'initialize' before getting any properties.") + value = None + if name in self._set_props_vals: + value = self._set_props_vals[name] + elif name in self._props_vals: + value = self._props_vals[name] + elif name in self.parser.config[self.parser.meta_key]: # type: ignore[index] + id = self.properties.get(name, None).get(BundlePropertyConfig.ID, None) + value = self.parser[id] + else: + try: + value = getattr(self, f"get_{name}")() + except AttributeError as e: + if property[BundleProperty.REQUIRED]: + raise ValueError( + f"unsupported property '{name}' is required in the bundle properties," + f"need to implement a method 'get_{name}' to provide the property." + ) from e + self._props_vals[name] = value + return value + + def _set_property(self, name: str, property: dict, value: Any) -> Any: + """ + With specified property name and information, set value for the expected property. + Stores user-reset initialized objects that should not be re-initialized and marks the workflow as not initialized. + + Args: + name: the name of target property. + property: other information for the target property, defined in `TrainProperties` or `InferProperties`. + value: value to set for the property. + + """ + self._set_props_vals[name] = value + self._is_initialized = False + + class ConfigWorkflow(BundleWorkflow): """ Specification for the config-based bundle workflow. @@ -262,7 +407,13 @@ class ConfigWorkflow(BundleWorkflow): or "infer", "inference", "eval", "evaluation" for a inference workflow, other unsupported string will raise a ValueError. default to `None` for common workflow. - properties_path: the path to the JSON file of properties. + properties_path: the path to the JSON file of properties. If `workflow_type` is specified, properties will be + loaded from the file based on the provided `workflow_type` and meta. If no `workflow_type` is specified, + properties will default to loading from "train". If `properties_path` is None, default properties + will be sourced from "monai/bundle/properties.py" based on the workflow_type: + For a training workflow, properties load from `TrainProperties` and `MetaProperties`. + For a inference workflow, properties load from `InferProperties` and `MetaProperties`. + For workflow_type = None : only `MetaProperties` will be loaded. override: id-value pairs to override or add the corresponding config content. e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg`` @@ -324,7 +475,6 @@ def __init__( self.parser.read_config(f=config_file) if self.meta_file is not None: self.parser.read_meta(f=self.meta_file) - # the rest key-values in the _args are to override config content self.parser.update(pairs=override) self.init_id = init_id @@ -394,8 +544,23 @@ def check_properties(self) -> list[str] | None: ret.extend(wrong_props) return ret - def _run_expr(self, id: str, **kwargs: dict) -> Any: - return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None + def _run_expr(self, id: str, **kwargs: dict) -> list[Any]: + """ + Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored, + allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process. + """ + ret = [] + if id in self.parser: + # suppose all the expressions are in a list, run and reset the expressions + if isinstance(self.parser[id], list): + for i in range(len(self.parser[id])): + sub_id = f"{id}{ID_SEP_KEY}{i}" + ret.append(self.parser.get_parsed_content(sub_id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(sub_id) + else: + ret.append(self.parser.get_parsed_content(id, **kwargs)) + self.parser.ref_resolver.remove_resolved_content(id) + return ret def _get_prop_id(self, name: str, property: dict) -> Any: prop_id = property[BundlePropertyConfig.ID] diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py index 3629659db1..0c36da6d3d 100644 --- a/monai/engines/workflow.py +++ b/monai/engines/workflow.py @@ -12,7 +12,7 @@ from __future__ import annotations import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence, Sized from typing import TYPE_CHECKING, Any import torch @@ -121,24 +121,24 @@ def __init__( to_kwargs: dict | None = None, amp_kwargs: dict | None = None, ) -> None: - if iteration_update is not None: - super().__init__(iteration_update) - else: - super().__init__(self._iteration) + super().__init__(self._iteration if iteration_update is None else iteration_update) if isinstance(data_loader, DataLoader): - sampler = data_loader.__dict__["sampler"] + sampler = getattr(data_loader, "sampler", None) + + # set the epoch value for DistributedSampler objects when an epoch starts if isinstance(sampler, DistributedSampler): @self.on(Events.EPOCH_STARTED) def set_sampler_epoch(engine: Engine) -> None: sampler.set_epoch(engine.state.epoch) - if epoch_length is None: + # if the epoch_length isn't given, attempt to get it from the length of the data loader + if epoch_length is None and isinstance(data_loader, Sized): + try: epoch_length = len(data_loader) - else: - if epoch_length is None: - raise ValueError("If data_loader is not PyTorch DataLoader, must specify the epoch_length.") + except TypeError: # raised when data_loader has an iterable dataset with no length, or is some other type + pass # deliberately leave epoch_length as None # set all sharable data for the workflow based on Ignite engine.state self.state: Any = State( @@ -147,7 +147,7 @@ def set_sampler_epoch(engine: Engine) -> None: iteration=0, epoch=0, max_epochs=max_epochs, - epoch_length=epoch_length, + epoch_length=epoch_length, # None when the dataset is iterable and so has no length output=None, batch=None, metrics={}, diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 3f02fae6b8..4108820bec 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -23,6 +23,7 @@ from monai.losses.focal_loss import FocalLoss from monai.losses.spatial_mask import MaskedLoss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after @@ -39,8 +40,16 @@ class DiceLoss(_Loss): The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of the inter-over-union calculation to smooth results respectively, these values should be small. - The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric - Medical Image Segmentation, 3DV, 2016. + The original papers: + + Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric + Medical Image Segmentation. 3DV 2016. + + Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with + Soft Labels. NeurIPS 2023. + + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. """ @@ -58,6 +67,7 @@ def __init__( smooth_dr: float = 1e-5, batch: bool = False, weight: Sequence[float] | float | int | torch.Tensor | None = None, + soft_label: bool = False, ) -> None: """ Args: @@ -89,6 +99,8 @@ def __init__( of the sequence should be the same as the number of classes. If not ``include_background``, the number of classes should not include the background category class 0). The value/values should be no less than 0. Defaults to None. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -114,6 +126,7 @@ def __init__( weight = torch.as_tensor(weight) if weight is not None else None self.register_buffer("class_weight", weight) self.class_weight: None | torch.Tensor + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, dim=reduce_axis) - - if self.squared_pred: - ground_o = torch.sum(target**2, dim=reduce_axis) - pred_o = torch.sum(input**2, dim=reduce_axis) - else: - ground_o = torch.sum(target, dim=reduce_axis) - pred_o = torch.sum(input, dim=reduce_axis) - - denominator = ground_o + pred_o - - if self.jaccard: - denominator = 2.0 * (denominator - intersection) + ord = 2 if self.squared_pred else 1 + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label) + if not self.jaccard: + fp *= 0.5 + fn *= 0.5 + numerator = 2 * tp + self.smooth_nr + denominator = 2 * (tp + fp + fn) + self.smooth_dr - f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr) + f: torch.Tensor = 1 - numerator / denominator num_of_classes = target.shape[1] if self.class_weight is not None and num_of_classes != 1: @@ -272,6 +279,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -295,6 +303,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. If True, the class-weighted intersection and union areas are first summed across the batches. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -319,6 +329,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def w_func(self, grnd): if self.w_type == str(Weight.SIMPLE): @@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: reduce_axis = [0] + reduce_axis - intersection = torch.sum(target * input, reduce_axis) - ground_o = torch.sum(target, reduce_axis) - pred_o = torch.sum(input, reduce_axis) - - denominator = ground_o + pred_o + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label) + fp *= 0.5 + fn *= 0.5 + denominator = 2 * (tp + fp + fn) + ground_o = torch.sum(target, reduce_axis) w = self.w_func(ground_o.float()) infs = torch.isinf(w) if self.batch: @@ -388,7 +399,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: w = w + infs * max_values final_reduce_dim = 0 if self.batch else 1 - numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py index 4f22bf84b4..154f34c526 100644 --- a/monai/losses/tversky.py +++ b/monai/losses/tversky.py @@ -17,6 +17,7 @@ import torch from torch.nn.modules.loss import _Loss +from monai.losses.utils import compute_tp_fp_fn from monai.networks import one_hot from monai.utils import LossReduction @@ -28,6 +29,9 @@ class TverskyLoss(_Loss): Sadegh et al. (2017) Tversky loss function for image segmentation using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721) + Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with + Soft Labels. MICCAI 2023. + Adapted from: https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631 @@ -46,6 +50,7 @@ def __init__( smooth_nr: float = 1e-5, smooth_dr: float = 1e-5, batch: bool = False, + soft_label: bool = False, ) -> None: """ Args: @@ -70,6 +75,8 @@ def __init__( batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, a Dice loss value is computed independently from each item in the batch before any `reduction`. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``. @@ -93,6 +100,7 @@ def __init__( self.smooth_nr = float(smooth_nr) self.smooth_dr = float(smooth_dr) self.batch = batch + self.soft_label = soft_label def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ @@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if target.shape != input.shape: raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})") - p0 = input - p1 = 1 - p0 - g0 = target - g1 = 1 - g0 - # reducing only spatial dimensions (not batch nor channels) reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist() if self.batch: # reducing spatial dimensions and batch reduce_axis = [0] + reduce_axis - tp = torch.sum(p0 * g0, reduce_axis) - fp = self.alpha * torch.sum(p0 * g1, reduce_axis) - fn = self.beta * torch.sum(p1 * g0, reduce_axis) + tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False) + fp *= self.alpha + fn *= self.beta numerator = tp + self.smooth_nr denominator = tp + fp + fn + self.smooth_dr diff --git a/monai/losses/utils.py b/monai/losses/utils.py new file mode 100644 index 0000000000..782fd9c9c2 --- /dev/null +++ b/monai/losses/utils.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import torch.linalg as LA + + +def compute_tp_fp_fn( + input: torch.Tensor, + target: torch.Tensor, + reduce_axis: list[int], + ord: int, + soft_label: bool, + decoupled: bool = True, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. + reduce_axis: the axis to be reduced. + ord: the order of the vector norm. + soft_label: whether the target contains non-binary values (soft labels) or not. + If True a soft label formulation of the loss will be used. + decoupled: whether the input and the target should be decoupled when computing fp and fn. + Only for the original implementation when soft_label is False. + + Adapted from: + https://github.com/zifuwanggg/JDTLosses + """ + + # the original implementation that is erroneous with soft labels + if ord == 1 and not soft_label: + tp = torch.sum(input * target, dim=reduce_axis) + # the original implementation of Dice and Jaccard loss + if decoupled: + fp = torch.sum(input, dim=reduce_axis) - tp + fn = torch.sum(target, dim=reduce_axis) - tp + # the original implementation of Tversky loss + else: + fp = torch.sum(input * (1 - target), dim=reduce_axis) + fn = torch.sum((1 - input) * target, dim=reduce_axis) + # the new implementation that is correct with soft labels + # and it is identical to the original implementation with hard labels + else: + pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis) + ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis) + difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis) + + if ord > 1: + pred_o = torch.pow(pred_o, exponent=ord) + ground_o = torch.pow(ground_o, exponent=ord) + difference = torch.pow(difference, exponent=ord) + + tp = (pred_o + ground_o - difference) / 2 + fp = pred_o - tp + fn = ground_o - tp + + return tp, fp, fn diff --git a/monai/networks/blocks/pos_embed_utils.py b/monai/networks/blocks/pos_embed_utils.py index 21586e56da..a9c5176bc2 100644 --- a/monai/networks/blocks/pos_embed_utils.py +++ b/monai/networks/blocks/pos_embed_utils.py @@ -56,7 +56,7 @@ def build_sincos_position_embedding( grid_h = torch.arange(h, dtype=torch.float32) grid_w = torch.arange(w, dtype=torch.float32) - grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij") + grid_h, grid_w = torch.meshgrid(grid_h, grid_w) if embed_dim % 4 != 0: raise AssertionError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding") @@ -75,7 +75,7 @@ def build_sincos_position_embedding( grid_w = torch.arange(w, dtype=torch.float32) grid_d = torch.arange(d, dtype=torch.float32) - grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij") + grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d) if embed_dim % 6 != 0: raise AssertionError("Embed dimension must be divisible by 6 for 3D sin-cos position embedding") diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index ac96b077bd..86e1b1d3ae 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -154,10 +154,12 @@ def __init__( ) self.input_size = input_size - def forward(self, x): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C + attn_mask (torch.Tensor, optional): mask to apply to the attention matrix. + B x (s_dim_1 * ... * s_dim_n). Defaults to None. Return: torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C @@ -176,7 +178,13 @@ def forward(self, x): if self.use_flash_attention: x = F.scaled_dot_product_attention( - query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal + query=q, + key=k, + value=v, + attn_mask=attn_mask, + scale=self.scale, + dropout_p=self.dropout_rate, + is_causal=self.causal, ) else: att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale @@ -186,10 +194,16 @@ def forward(self, x): att_mat = self.rel_positional_embedding(x, att_mat, q) if self.causal: + if attn_mask is not None: + raise ValueError("Causal attention does not support attention masks.") att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf")) - att_mat = att_mat.softmax(dim=-1) + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(1).unsqueeze(2) + attn_mask = attn_mask.expand(-1, self.num_heads, -1, -1) + att_mat = att_mat.masked_fill(attn_mask == 0, float("-inf")) + att_mat = att_mat.softmax(dim=-1) if self.save_attn: # no gradients and new tensor; # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 05eb3b07ab..6f0da73e7b 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -90,8 +90,10 @@ def __init__( use_flash_attention=use_flash_attention, ) - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor: - x = x + self.attn(self.norm1(x)) + def forward( + self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: x = x + self.cross_attn(self.norm_cross_attn(x), context=context) x = x + self.mlp(self.norm2(x)) diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index b876e6a3fc..c1917e5293 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -53,6 +53,7 @@ from .generator import Generator from .highresnet import HighResBlock, HighResNet from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet +from .masked_autoencoder_vit import MaskedAutoEncoderViT from .mednext import ( MedNeXt, MedNext, diff --git a/monai/networks/nets/masked_autoencoder_vit.py b/monai/networks/nets/masked_autoencoder_vit.py new file mode 100644 index 0000000000..e76f097346 --- /dev/null +++ b/monai/networks/nets/masked_autoencoder_vit.py @@ -0,0 +1,211 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +import torch +import torch.nn as nn + +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock +from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding +from monai.networks.blocks.transformerblock import TransformerBlock +from monai.networks.layers import trunc_normal_ +from monai.utils import ensure_tuple_rep +from monai.utils.module import look_up_option + +SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"} + +__all__ = ["MaskedAutoEncoderViT"] + + +class MaskedAutoEncoderViT(nn.Module): + """ + Masked Autoencoder (ViT), based on: "Kaiming et al., + Masked Autoencoders Are Scalable Vision Learners " + Only a subset of the patches passes through the encoder. The decoder tries to reconstruct + the masked patches, resulting in improved training speed. + """ + + def __init__( + self, + in_channels: int, + img_size: Sequence[int] | int, + patch_size: Sequence[int] | int, + hidden_size: int = 768, + mlp_dim: int = 512, + num_layers: int = 12, + num_heads: int = 12, + masking_ratio: float = 0.75, + decoder_hidden_size: int = 384, + decoder_mlp_dim: int = 512, + decoder_num_layers: int = 4, + decoder_num_heads: int = 12, + proj_type: str = "conv", + pos_embed_type: str = "sincos", + decoder_pos_embed_type: str = "sincos", + dropout_rate: float = 0.0, + spatial_dims: int = 3, + qkv_bias: bool = False, + save_attn: bool = False, + ) -> None: + """ + Args: + in_channels: dimension of input channels or the number of channels for input. + img_size: dimension of input image. + patch_size: dimension of patch size + hidden_size: dimension of hidden layer. Defaults to 768. + mlp_dim: dimension of feedforward layer. Defaults to 512. + num_layers: number of transformer blocks. Defaults to 12. + num_heads: number of attention heads. Defaults to 12. + masking_ratio: ratio of patches to be masked. Defaults to 0.75. + decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384. + decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512. + decoder_num_layers: number of transformer blocks for decoder. Defaults to 4. + decoder_num_heads: number of attention heads for decoder. Defaults to 12. + proj_type: position embedding layer type. Defaults to "conv". + pos_embed_type: position embedding layer type. Defaults to "sincos". + decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos". + dropout_rate: fraction of the input units to drop. Defaults to 0.0. + spatial_dims: number of spatial dimensions. Defaults to 3. + qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False. + save_attn: to make accessible the attention in self attention block. Defaults to False. + Examples:: + # for single channel input with image size of (96,96,96), and sin-cos positional encoding + >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16), + pos_embed_type='sincos') + # for 3-channel with image size of (128,128,128) and a learnable positional encoding + >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable') + # for 3-channel with image size of (224,224) and a masking ratio of 0.25 + >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25, + spatial_dims=2) + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise ValueError(f"dropout_rate should be between 0 and 1, got {dropout_rate}.") + + if hidden_size % num_heads != 0: + raise ValueError("hidden_size should be divisible by num_heads.") + + if decoder_hidden_size % decoder_num_heads != 0: + raise ValueError("decoder_hidden_size should be divisible by decoder_num_heads.") + + self.patch_size = ensure_tuple_rep(patch_size, spatial_dims) + self.img_size = ensure_tuple_rep(img_size, spatial_dims) + self.spatial_dims = spatial_dims + for m, p in zip(self.img_size, self.patch_size): + if m % p != 0: + raise ValueError(f"patch_size={patch_size} should be divisible by img_size={img_size}.") + + self.decoder_hidden_size = decoder_hidden_size + + if masking_ratio <= 0 or masking_ratio >= 1: + raise ValueError(f"masking_ratio should be in the range (0, 1), got {masking_ratio}.") + + self.masking_ratio = masking_ratio + self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) + + self.patch_embedding = PatchEmbeddingBlock( + in_channels=in_channels, + img_size=img_size, + patch_size=patch_size, + hidden_size=hidden_size, + num_heads=num_heads, + proj_type=proj_type, + pos_embed_type=pos_embed_type, + dropout_rate=dropout_rate, + spatial_dims=self.spatial_dims, + ) + blocks = [ + TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn) + for _ in range(num_layers) + ] + self.blocks = nn.Sequential(*blocks, nn.LayerNorm(hidden_size)) + + # decoder + self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size) + + self.mask_tokens = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size)) + + self.decoder_pos_embed_type = look_up_option(decoder_pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES) + self.decoder_pos_embedding = nn.Parameter(torch.zeros(1, self.patch_embedding.n_patches, decoder_hidden_size)) + + decoder_blocks = [ + TransformerBlock(decoder_hidden_size, decoder_mlp_dim, decoder_num_heads, dropout_rate, qkv_bias, save_attn) + for _ in range(decoder_num_layers) + ] + self.decoder_blocks = nn.Sequential(*decoder_blocks, nn.LayerNorm(decoder_hidden_size)) + self.decoder_pred = nn.Linear(decoder_hidden_size, int(np.prod(self.patch_size)) * in_channels) + + self._init_weights() + + def _init_weights(self): + """ + similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and + classification tokens + """ + if self.decoder_pos_embed_type == "none": + pass + elif self.decoder_pos_embed_type == "learnable": + trunc_normal_(self.decoder_pos_embedding, mean=0.0, std=0.02, a=-2.0, b=2.0) + elif self.decoder_pos_embed_type == "sincos": + grid_size = [] + for in_size, pa_size in zip(self.img_size, self.patch_size): + grid_size.append(in_size // pa_size) + + self.decoder_pos_embedding = build_sincos_position_embedding( + grid_size, self.decoder_hidden_size, self.spatial_dims + ) + + else: + raise ValueError(f"decoder_pos_embed_type {self.decoder_pos_embed_type} not supported.") + + # initialize patch_embedding like nn.Linear (instead of nn.Conv2d) + trunc_normal_(self.mask_tokens, mean=0.0, std=0.02, a=-2.0, b=2.0) + trunc_normal_(self.cls_token, mean=0.0, std=0.02, a=-2.0, b=2.0) + + def _masking(self, x, masking_ratio: float | None = None): + batch_size, num_tokens, _ = x.shape + percentage_to_keep = 1 - masking_ratio if masking_ratio is not None else 1 - self.masking_ratio + selected_indices = torch.multinomial( + torch.ones(batch_size, num_tokens), int(percentage_to_keep * num_tokens), replacement=False + ) + x_masked = x[torch.arange(batch_size).unsqueeze(1), selected_indices] # gather the selected tokens + mask = torch.ones(batch_size, num_tokens, dtype=torch.int).to(x.device) + mask[torch.arange(batch_size).unsqueeze(-1), selected_indices] = 0 + + return x_masked, selected_indices, mask + + def forward(self, x, masking_ratio: float | None = None): + x = self.patch_embedding(x) + x, selected_indices, mask = self._masking(x, masking_ratio=masking_ratio) + + cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + x = self.blocks(x) + + # decoder + x = self.decoder_embed(x) + + x_ = self.mask_tokens.repeat(x.shape[0], mask.shape[1], 1) + x_[torch.arange(x.shape[0]).unsqueeze(-1), selected_indices] = x[:, 1:, :] # no cls token + x_ = x_ + self.decoder_pos_embedding + x = torch.cat([x[:, :1, :], x_], dim=1) + x = self.decoder_blocks(x) + x = self.decoder_pred(x) + + x = x[:, 1:, :] + return x, mask diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 832135ad06..77f0d2ec2f 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -13,7 +13,6 @@ import itertools from collections.abc import Sequence -from typing import Final import numpy as np import torch @@ -51,8 +50,6 @@ class SwinUNETR(nn.Module): " """ - patch_size: Final[int] = 2 - @deprecated_arg( name="img_size", since="1.3", @@ -65,18 +62,24 @@ def __init__( img_size: Sequence[int] | int, in_channels: int, out_channels: int, + patch_size: int = 2, depths: Sequence[int] = (2, 2, 2, 2), num_heads: Sequence[int] = (3, 6, 12, 24), + window_size: Sequence[int] | int = 7, + qkv_bias: bool = True, + mlp_ratio: float = 4.0, feature_size: int = 24, norm_name: tuple | str = "instance", drop_rate: float = 0.0, attn_drop_rate: float = 0.0, dropout_path_rate: float = 0.0, normalize: bool = True, + norm_layer: type[LayerNorm] = nn.LayerNorm, + patch_norm: bool = False, use_checkpoint: bool = False, spatial_dims: int = 3, - downsample="merging", - use_v2=False, + downsample: str | nn.Module = "merging", + use_v2: bool = False, ) -> None: """ Args: @@ -86,14 +89,20 @@ def __init__( It will be removed in an upcoming version. in_channels: dimension of input channels. out_channels: dimension of output channels. + patch_size: size of the patch token. feature_size: dimension of network feature size. depths: number of layers in each stage. num_heads: number of attention heads. + window_size: local window size. + qkv_bias: add a learnable bias to query, key, value. + mlp_ratio: ratio of mlp hidden dim to embedding dim. norm_name: feature normalization type and arguments. drop_rate: dropout rate. attn_drop_rate: attention dropout rate. dropout_path_rate: drop path rate. normalize: normalize output intermediate features in each stage. + norm_layer: normalization layer. + patch_norm: whether to apply normalization to the patch embedding. Default is False. use_checkpoint: use gradient checkpointing for reduced memory usage. spatial_dims: number of spatial dims. downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a @@ -116,13 +125,15 @@ def __init__( super().__init__() - img_size = ensure_tuple_rep(img_size, spatial_dims) - patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) - window_size = ensure_tuple_rep(7, spatial_dims) - if spatial_dims not in (2, 3): raise ValueError("spatial dimension should be 2 or 3.") + self.patch_size = patch_size + + img_size = ensure_tuple_rep(img_size, spatial_dims) + patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims) + window_size = ensure_tuple_rep(window_size, spatial_dims) + self._check_input_size(img_size) if not (0 <= drop_rate <= 1): @@ -146,12 +157,13 @@ def __init__( patch_size=patch_sizes, depths=depths, num_heads=num_heads, - mlp_ratio=4.0, - qkv_bias=True, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=dropout_path_rate, - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, + patch_norm=patch_norm, use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2cdd965c91..d15042181b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -531,6 +531,8 @@ RandIdentity, RandImageFilter, RandLambda, + RandTorchIO, + RandTorchVision, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -540,6 +542,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -620,6 +623,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandTorchIOd, + RandTorchIOD, + RandTorchIODict, RandTorchVisiond, RandTorchVisionD, RandTorchVisionDict, @@ -653,6 +659,9 @@ ToPILd, ToPILD, ToPILDict, + TorchIOd, + TorchIOD, + TorchIODict, TorchVisiond, TorchVisionD, TorchVisionDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 1b3c59afdb..2963c8a2f8 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -18,10 +18,10 @@ import sys import time import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Union import numpy as np import torch @@ -99,11 +99,14 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "TorchIO", "MapLabelValue", "IntensityStats", "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1051,12 +1054,11 @@ def __call__( class ConvertToMultiChannelBasedOnBratsClasses(Transform): """ - Convert labels to multi channels based on brats18 classes: - label 1 is the necrotic and non-enhancing tumor core - label 2 is the peritumoral edema - label 4 is the GD-enhancing tumor - The possible classes are TC (Tumor core), WT (Whole tumor) - and ET (Enhancing tumor). + Convert labels to multi channels based on `brats18 `_ classes, + which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor): + label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion, + label 2 is the peritumoral edema, which is counted only under WT subregion, + label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -1136,12 +1138,44 @@ def __call__( return concatenate((img, points_image), axis=0) -class TorchVision: +class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class RandTorchVision(Transform, RandomizableTrait): + """ + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ backend = [TransformBackends.TORCH] @@ -1172,6 +1206,68 @@ def __call__(self, img: NdarrayOrTensor): return out +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + + """ + return self.trans(img) + + class MapLabelValue: """ Utility to map label values to another set of values. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 65c721e48e..7dd2397a74 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -60,6 +60,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -136,6 +137,9 @@ "RandLambdaD", "RandLambdaDict", "RandLambdad", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandTorchVisionD", "RandTorchVisionDict", "RandTorchVisiond", @@ -172,6 +176,9 @@ "ToTensorD", "ToTensorDict", "ToTensord", + "TorchIOD", + "TorchIODict", + "TorchIOd", "TorchVisionD", "TorchVisionDict", "TorchVisiond", @@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class TorchIOd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) + + class MapLabelValued(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. @@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesd ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld +TorchIOD = TorchIODict = TorchIOd TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond +RandTorchIOD = RandTorchIODict = RandTorchIOd RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd diff --git a/monai/utils/module.py b/monai/utils/module.py index 1ad001fc87..d3f2ff09f2 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -649,7 +649,7 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s current_ver_string: if None, the current system GPU CUDA compute capability will be used. Returns: - True if the current system GPU CUDA compute capability is greater than the specified version. + True if the current system GPU CUDA compute capability is greater than or equal to the specified version. """ if current_ver_string is None: cuda_available = torch.cuda.is_available() @@ -667,11 +667,11 @@ def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: s ver, has_ver = optional_import("packaging.version", name="parse") if has_ver: - return ver(".".join((f"{major}", f"{minor}"))) < ver(f"{current_ver_string}") # type: ignore + return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2) while len(parts) < 2: parts += ["0"] c_major, c_minor = parts[:2] c_mn = int(c_major), int(c_minor) mn = int(major), int(minor) - return c_mn >= mn + return c_mn > mn diff --git a/requirements-dev.txt b/requirements-dev.txt index 72654d3534..bffe304df4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows" types-setuptools mypy>=1.5.0, <1.12.0 ninja +torchio torchvision psutil cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10" diff --git a/setup.cfg b/setup.cfg index 694dc969d9..0c69051218 100644 --- a/setup.cfg +++ b/setup.cfg @@ -55,15 +55,16 @@ all = tensorboard gdown>=4.7.3 pytorch-ignite==0.4.11 + torchio torchvision itk>=5.2 tqdm>=4.47.0 lmdb psutil - cucim-cu12; python_version >= '3.9' and python_version <= '3.10' + cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10' openslide-python - tifffile - imagecodecs + tifffile; platform_system == "Linux" or platform_system == "Darwin" + imagecodecs; platform_system == "Linux" or platform_system == "Darwin" pandas einops transformers>=4.36.0, <4.41.0; python_version <= '3.10' @@ -77,7 +78,7 @@ all = pynrrd pydicom h5py - nni + nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna onnx>=1.13.0 onnxruntime; python_version <= '3.10' @@ -102,6 +103,8 @@ gdown = gdown>=4.7.3 ignite = pytorch-ignite==0.4.11 +torchio = + torchio torchvision = torchvision itk = @@ -113,13 +116,13 @@ lmdb = psutil = psutil cucim = - cucim-cu12 + cucim-cu12; platform_system == "Linux" and python_version >= '3.9' and python_version <= '3.10' openslide = openslide-python tifffile = - tifffile + tifffile; platform_system == "Linux" or platform_system == "Darwin" imagecodecs = - imagecodecs + imagecodecs; platform_system == "Linux" or platform_system == "Darwin" pandas = pandas einops = @@ -149,7 +152,7 @@ pydicom = h5py = h5py nni = - nni + nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine optuna = optuna onnx = diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py index b2c44c12c6..fcfc5b2951 100644 --- a/tests/nonconfig_workflow.py +++ b/tests/nonconfig_workflow.py @@ -13,7 +13,7 @@ import torch -from monai.bundle import BundleWorkflow +from monai.bundle import BundleWorkflow, PythonicWorkflow from monai.data import DataLoader, Dataset from monai.engines import SupervisedEvaluator from monai.inferers import SlidingWindowInferer @@ -26,8 +26,9 @@ LoadImaged, SaveImaged, ScaleIntensityd, + ScaleIntensityRanged, ) -from monai.utils import BundleProperty, set_determinism +from monai.utils import BundleProperty, CommonKeys, set_determinism class NonConfigWorkflow(BundleWorkflow): @@ -176,3 +177,62 @@ def _set_property(self, name, property, value): self._numpy_version = value elif property[BundleProperty.REQUIRED]: raise ValueError(f"unsupported property '{name}' is required in the bundle properties.") + + +class PythonicWorkflowImpl(PythonicWorkflow): + """ + Test class simulates the bundle workflow defined by Python script directly. + """ + + def __init__( + self, + workflow_type: str = "inference", + config_file: str | None = None, + properties_path: str | None = None, + meta_file: str | None = None, + ): + super().__init__( + workflow_type=workflow_type, properties_path=properties_path, config_file=config_file, meta_file=meta_file + ) + self.dataflow: dict = {} + + def initialize(self): + self._props_vals = {} + self._is_initialized = True + self.net = UNet( + spatial_dims=3, + in_channels=1, + out_channels=2, + channels=(16, 32, 64, 128), + strides=(2, 2, 2), + num_res_units=2, + ).to(self.device) + preprocessing = Compose( + [ + EnsureChannelFirstd(keys=["image"]), + ScaleIntensityd(keys="image"), + ScaleIntensityRanged(keys="image", a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True), + ] + ) + self.dataset = Dataset(data=[self.dataflow], transform=preprocessing) + self.postprocessing = Compose([Activationsd(keys="pred", softmax=True), AsDiscreted(keys="pred", argmax=True)]) + + def run(self): + data = self.dataset[0] + inputs = data[CommonKeys.IMAGE].unsqueeze(0).to(self.device) + self.net.eval() + with torch.no_grad(): + data[CommonKeys.PRED] = self.inferer(inputs, self.net) + self.dataflow.update({CommonKeys.PRED: self.postprocessing(data)[CommonKeys.PRED]}) + + def finalize(self): + pass + + def get_bundle_root(self): + return "." + + def get_device(self): + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def get_inferer(self): + return SlidingWindowInferer(roi_size=self.parser.roi_size, sw_batch_size=1, overlap=0) diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py index 835c8e5c1d..27e1ee97a8 100644 --- a/tests/test_bundle_trt_export.py +++ b/tests/test_bundle_trt_export.py @@ -53,7 +53,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTExport(unittest.TestCase): def setUp(self): diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py index 1727fcdf53..893b9dc991 100644 --- a/tests/test_bundle_workflow.py +++ b/tests/test_bundle_workflow.py @@ -13,6 +13,7 @@ import os import shutil +import sys import tempfile import unittest from copy import deepcopy @@ -22,12 +23,12 @@ import torch from parameterized import parameterized -from monai.bundle import ConfigWorkflow +from monai.bundle import ConfigWorkflow, create_workflow from monai.data import Dataset from monai.inferers import SimpleInferer, SlidingWindowInferer from monai.networks.nets import UNet -from monai.transforms import Compose, LoadImage -from tests.nonconfig_workflow import NonConfigWorkflow +from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged +from tests.nonconfig_workflow import NonConfigWorkflow, PythonicWorkflowImpl TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")] @@ -35,6 +36,8 @@ TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")] +TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")] + TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."] @@ -45,7 +48,9 @@ def setUp(self): self.expected_shape = (128, 128, 128) test_image = np.random.rand(*self.expected_shape) self.filename = os.path.join(self.data_dir, "image.nii") + self.filename1 = os.path.join(self.data_dir, "image1.nii") nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename) + nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1) def tearDown(self): shutil.rmtree(self.data_dir) @@ -108,12 +113,42 @@ def test_inference_config(self, config_file): # test property path inferer = ConfigWorkflow( config_file=config_file, + workflow_type="infer", properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"), logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), **override, ) self._test_inferer(inferer) - self.assertEqual(inferer.workflow_type, None) + self.assertEqual(inferer.workflow_type, "infer") + + @parameterized.expand([TEST_CASE_4]) + def test_responsive_inference_config(self, config_file): + input_loader = LoadImaged(keys="image") + output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg") + + # test standard MONAI model-zoo config workflow + inferer = ConfigWorkflow( + workflow_type="infer", + config_file=config_file, + logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"), + ) + # FIXME: temp add the property for test, we should add it to some formal realtime infer properties + inferer.add_property(name="dataflow", required=True, config_id="dataflow") + + inferer.initialize() + inferer.dataflow.update(input_loader({"image": self.filename})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz"))) + + # bundle is instantiated and idle, just change the input for next inference + inferer.dataflow.clear() + inferer.dataflow.update(input_loader({"image": self.filename1})) + inferer.run() + output_saver(inferer.dataflow) + self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz"))) + + inferer.finalize() @parameterized.expand([TEST_CASE_3]) def test_train_config(self, config_file): @@ -164,6 +199,72 @@ def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_erro with self.assertRaisesRegex(FileNotFoundError, expected_error): NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file) + def test_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + workflow = PythonicWorkflowImpl( + workflow_type="infer", config_file=config_file, meta_file=meta_file, properties_path=property_path + ) + workflow.initialize() + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + workflow.finalize() + + def test_create_pythonic_workflow(self): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + config_file = {"roi_size": (64, 64, 32)} + meta_file = os.path.join(os.path.dirname(__file__), "testing_data", "metadata.json") + property_path = os.path.join(os.path.dirname(__file__), "testing_data", "python_workflow_properties.json") + sys.path.append(os.path.dirname(__file__)) + workflow = create_workflow( + "nonconfig_workflow.PythonicWorkflowImpl", + workflow_type="infer", + config_file=config_file, + meta_file=meta_file, + properties_path=property_path, + ) + # Load input data + input_loader = LoadImaged(keys="image") + workflow.dataflow.update(input_loader({"image": self.filename})) + self.assertEqual(workflow.bundle_root, ".") + self.assertEqual(workflow.device, device) + self.assertEqual(workflow.version, "0.1.0") + # check config override correctly + self.assertEqual(workflow.inferer.roi_size, (64, 64, 32)) + + # check set property override correctly + workflow.inferer = SlidingWindowInferer(roi_size=config_file["roi_size"], sw_batch_size=1, overlap=0.5) + workflow.initialize() + self.assertEqual(workflow.inferer.overlap, 0.5) + + workflow.run() + # update input data and run again + workflow.dataflow.update(input_loader({"image": self.filename1})) + workflow.run() + pred = workflow.dataflow["pred"] + self.assertEqual(pred.shape[2:], self.expected_shape) + self.assertEqual(pred.meta["filename_or_obj"], self.filename1) + + # test add properties + workflow.add_property(name="net", required=True, desc="network for the training.") + self.assertIn("net", workflow.properties) + workflow.finalize() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py index 712d887c3b..a7b1edec3c 100644 --- a/tests/test_convert_to_trt.py +++ b/tests/test_convert_to_trt.py @@ -38,7 +38,7 @@ @skip_if_windows @skip_if_no_cuda @skip_if_quick -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestConvertToTRT(unittest.TestCase): def setUp(self): diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py index 14aa6ec241..cea6ccf113 100644 --- a/tests/test_dice_loss.py +++ b/tests/test_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 5738f4a089..9706c2e746 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -34,6 +34,22 @@ }, 0.416597, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307748, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0}, { diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index cfa711e4c0..fb554e391c 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -18,8 +18,10 @@ import nibabel as nib import numpy as np +import torch.nn as nn from monai.data import DataLoader, Dataset, IterableDataset +from monai.engines import SupervisedEvaluator from monai.transforms import Compose, LoadImaged, SimulateDelayd @@ -59,6 +61,17 @@ def test_shape(self): for d in dataloader: self.assertTupleEqual(d["image"].shape[1:], expected_shape) + def test_supervisedevaluator(self): + """ + Test that a SupervisedEvaluator is compatible with IterableDataset in conjunction with DataLoader. + """ + data = list(range(10)) + dl = DataLoader(IterableDataset(data)) + evaluator = SupervisedEvaluator(device="cpu", val_data_loader=dl, network=nn.Identity()) + evaluator.run() # fails if the epoch length or other internal setup is not done correctly + + self.assertEqual(evaluator.state.iteration, len(data)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_masked_autoencoder_vit.py b/tests/test_masked_autoencoder_vit.py new file mode 100644 index 0000000000..f8f6977cc2 --- /dev/null +++ b/tests/test_masked_autoencoder_vit.py @@ -0,0 +1,160 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT +from tests.utils import skip_if_quick + +TEST_CASE_MaskedAutoEncoderViT = [] +for masking_ratio in [0.5]: + for dropout_rate in [0.6]: + for in_channels in [4]: + for hidden_size in [768]: + for img_size in [96, 128]: + for patch_size in [16]: + for num_heads in [12]: + for mlp_dim in [3072]: + for num_layers in [4]: + for decoder_hidden_size in [384]: + for decoder_mlp_dim in [512]: + for decoder_num_layers in [4]: + for decoder_num_heads in [16]: + for pos_embed_type in ["sincos", "learnable"]: + for proj_type in ["conv", "perceptron"]: + for nd in (2, 3): + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * nd, + "patch_size": (patch_size,) * nd, + "hidden_size": hidden_size, + "mlp_dim": mlp_dim, + "num_layers": num_layers, + "decoder_hidden_size": decoder_hidden_size, + "decoder_mlp_dim": decoder_mlp_dim, + "decoder_num_layers": decoder_num_layers, + "decoder_num_heads": decoder_num_heads, + "pos_embed_type": pos_embed_type, + "masking_ratio": masking_ratio, + "decoder_pos_embed_type": pos_embed_type, + "num_heads": num_heads, + "proj_type": proj_type, + "dropout_rate": dropout_rate, + }, + (2, in_channels, *([img_size] * nd)), + ( + 2, + (img_size // patch_size) ** nd, + in_channels * (patch_size**nd), + ), + ] + if nd == 2: + test_case[0]["spatial_dims"] = 2 # type: ignore + TEST_CASE_MaskedAutoEncoderViT.append(test_case) + +TEST_CASE_ill_args = [ + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "pos_embed_type": "sin"}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "decoder_pos_embed_type": "sin"}], + [{"in_channels": 1, "img_size": (32, 32, 32), "patch_size": (64, 64, 64)}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "num_layers": 12, "num_heads": 14}], + [{"in_channels": 1, "img_size": (97, 97, 97), "patch_size": (16, 16, 16)}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": 1.1}], + [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (64, 64, 64), "masking_ratio": -0.1}], +] + + +@skip_if_quick +class TestMaskedAutoencoderViT(unittest.TestCase): + + @parameterized.expand(TEST_CASE_MaskedAutoEncoderViT) + def test_shape(self, input_param, input_shape, expected_shape): + net = MaskedAutoEncoderViT(**input_param) + with eval_mode(net): + result, _ = net(torch.randn(input_shape)) + self.assertEqual(result.shape, expected_shape) + + def test_frozen_pos_embedding(self): + net = MaskedAutoEncoderViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16)) + + self.assertEqual(net.decoder_pos_embedding.requires_grad, False) + + @parameterized.expand(TEST_CASE_ill_args) + def test_ill_arg(self, input_param): + with self.assertRaises(ValueError): + MaskedAutoEncoderViT(**input_param) + + def test_access_attn_matrix(self): + # input format + in_channels = 1 + img_size = (96, 96, 96) + patch_size = (16, 16, 16) + in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) + + # no data in the matrix + no_matrix_acess_blk = MaskedAutoEncoderViT(in_channels=in_channels, img_size=img_size, patch_size=patch_size) + no_matrix_acess_blk(torch.randn(in_shape)) + assert isinstance(no_matrix_acess_blk.blocks[0].attn.att_mat, torch.Tensor) + # no of elements is zero + assert no_matrix_acess_blk.blocks[0].attn.att_mat.nelement() == 0 + + # be able to acess the attention matrix + matrix_acess_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, save_attn=True + ) + matrix_acess_blk(torch.randn(in_shape)) + + assert matrix_acess_blk.blocks[0].attn.att_mat.shape == (in_shape[0], 12, 55, 55) + + def test_masking_ratio(self): + # input format + in_channels = 1 + img_size = (96, 96, 96) + patch_size = (16, 16, 16) + in_shape = (1, in_channels, img_size[0], img_size[1], img_size[2]) + + # masking ratio 0.25 + masking_ratio_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.25, save_attn=True + ) + masking_ratio_blk(torch.randn(in_shape)) + desired_num_tokens = int( + (img_size[0] // patch_size[0]) + * (img_size[1] // patch_size[1]) + * (img_size[2] // patch_size[2]) + * (1 - 0.25) + ) + assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens + + # masking ratio 0.33 + masking_ratio_blk = MaskedAutoEncoderViT( + in_channels=in_channels, img_size=img_size, patch_size=patch_size, masking_ratio=0.33, save_attn=True + ) + masking_ratio_blk(torch.randn(in_shape)) + desired_num_tokens = int( + (img_size[0] // patch_size[0]) + * (img_size[1] // patch_size[1]) + * (img_size[2] // patch_size[2]) + * (1 - 0.33) + ) + + assert masking_ratio_blk.blocks[0].attn.att_mat.shape[-1] - 1 == desired_num_tokens + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_module_list.py b/tests/test_module_list.py index d21ba53b7c..833441cbca 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -58,13 +58,17 @@ def test_transform_api(self): continue with self.subTest(n=n): basename = n[:-1] # Transformd basename is Transform + + # remove aliases to check, do this before the assert below so that a failed assert does skip this + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + for docname in (f"{basename}", f"{basename}d"): if docname in to_exclude_docs: continue if (contents is not None) and f"`{docname}`" not in f"{contents}": self.assertTrue(False, f"please add `{docname}` to docs/source/transforms.rst") - for postfix in ("D", "d", "Dict"): - remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) diff --git a/tests/test_rand_torchio.py b/tests/test_rand_torchio.py new file mode 100644 index 0000000000..ab212d4a11 --- /dev/null +++ b/tests/test_rand_torchio.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIO +from monai.utils import optional_import, set_determinism + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [ + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_torchiod.py b/tests/test_rand_torchiod.py new file mode 100644 index 0000000000..52bcf7c576 --- /dev/null +++ b/tests/test_rand_torchiod.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIOd +from monai.utils import optional_import, set_determinism +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIOd(**input_param)(input_data) + self.assertFalse(np.allclose(input_data["img1"], result["img1"], atol=1e-6, rtol=1e-6)) + assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 88919fd8b1..338f1bf840 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -122,6 +122,24 @@ def test_causal(self): # check upper triangular part of the attention matrix is zero assert torch.triu(block.att_mat, diagonal=1).sum() == 0 + def test_masked_selfattention(self): + n = 64 + block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, sequence_length=16, save_attn=True) + input_shape = (1, n, 128) + # generate a mask randomly with zeros and ones of shape (1, n) + mask = torch.randint(0, 2, (1, n)).bool() + block(torch.randn(input_shape), attn_mask=mask) + att_mat = block.att_mat.squeeze() + # ensure all masked columns are zeros + assert torch.allclose(att_mat[:, ~mask.squeeze(0)], torch.zeros_like(att_mat[:, ~mask.squeeze(0)])) + + def test_causal_and_mask(self): + with self.assertRaises(ValueError): + block = SABlock(hidden_size=128, num_heads=1, causal=True, sequence_length=64) + inputs = torch.randn(2, 64, 128) + mask = torch.randint(0, 2, (2, 64)).bool() + block(inputs, attn_mask=mask) + @skipUnless(has_einops, "Requires einops") def test_access_attn_matrix(self): # input format diff --git a/tests/test_torchio.py b/tests/test_torchio.py new file mode 100644 index 0000000000..d2d598ca4c --- /dev/null +++ b/tests/test_torchio.py @@ -0,0 +1,41 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import TorchIO +from monai.utils import optional_import + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + result = TorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py new file mode 100644 index 0000000000..892287461c --- /dev/null +++ b/tests/test_torchiod.py @@ -0,0 +1,47 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.transforms import TorchIOd +from monai.utils import optional_import +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [ + [ + {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, + {"img": TEST_TENSOR}, + ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, + ] +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_value(self, input_param, input_data, expected_value): + result = TorchIOd(**input_param)(input_data) + assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py index e1323c201f..f7779fec9b 100644 --- a/tests/test_trt_compile.py +++ b/tests/test_trt_compile.py @@ -50,7 +50,7 @@ def forward(self, x: list[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: f @skip_if_quick @unittest.skipUnless(trt_imported, "tensorrt is required") @unittest.skipUnless(polygraphy_imported, "polygraphy is required") -@SkipIfBeforeComputeCapabilityVersion((7, 0)) +@SkipIfBeforeComputeCapabilityVersion((7, 5)) class TestTRTCompile(unittest.TestCase): def setUp(self): diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py index 0365503ea2..73a841a55d 100644 --- a/tests/test_tversky_loss.py +++ b/tests/test_tversky_loss.py @@ -34,6 +34,22 @@ }, 0.416657, ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.0, + ], + [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False}, + { + "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]), + }, + 0.307773, + ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0}, { diff --git a/tests/test_version_after.py b/tests/test_version_after.py index 34a5054974..b6cb741382 100644 --- a/tests/test_version_after.py +++ b/tests/test_version_after.py @@ -38,7 +38,7 @@ TEST_CASES_SM = [ # (major, minor, sm, expected) - (6, 1, "6.1", False), + (6, 1, "6.1", True), (6, 1, "6.0", False), (6, 0, "8.6", True), (7, 0, "8", True), diff --git a/tests/testing_data/fl_infer_properties.json b/tests/testing_data/fl_infer_properties.json index 72e97cd2c6..6b40edd2ab 100644 --- a/tests/testing_data/fl_infer_properties.json +++ b/tests/testing_data/fl_infer_properties.json @@ -1,67 +1,76 @@ { - "bundle_root": { - "description": "root path of the bundle.", - "required": true, - "id": "bundle_root" + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "dataset_dir": { + "description": "directory path of the dataset.", + "required": true, + "id": "dataset_dir" + }, + "dataset": { + "description": "PyTorch dataset object for the inference / evaluation logic.", + "required": true, + "id": "dataset" + }, + "evaluator": { + "description": "inference / evaluation workflow engine.", + "required": true, + "id": "evaluator" + }, + "network_def": { + "description": "network module for the inference.", + "required": true, + "id": "network_def" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + }, + "dataset_data": { + "description": "data source for the inference / evaluation dataset.", + "required": false, + "id": "dataset::data", + "refer_id": null + }, + "handlers": { + "description": "event-handlers for the inference / evaluation logic.", + "required": false, + "id": "handlers", + "refer_id": "evaluator::val_handlers" + }, + "preprocessing": { + "description": "preprocessing for the input data.", + "required": false, + "id": "preprocessing", + "refer_id": "dataset::transform" + }, + "postprocessing": { + "description": "postprocessing for the model output data.", + "required": false, + "id": "postprocessing", + "refer_id": "evaluator::postprocessing" + }, + "key_metric": { + "description": "the key metric during evaluation.", + "required": false, + "id": "key_metric", + "refer_id": "evaluator::key_val_metric" + } }, - "device": { - "description": "target device to execute the bundle workflow.", - "required": true, - "id": "device" - }, - "dataset_dir": { - "description": "directory path of the dataset.", - "required": true, - "id": "dataset_dir" - }, - "dataset": { - "description": "PyTorch dataset object for the inference / evaluation logic.", - "required": true, - "id": "dataset" - }, - "evaluator": { - "description": "inference / evaluation workflow engine.", - "required": true, - "id": "evaluator" - }, - "network_def": { - "description": "network module for the inference.", - "required": true, - "id": "network_def" - }, - "inferer": { - "description": "MONAI Inferer object to execute the model computation in inference.", - "required": true, - "id": "inferer" - }, - "dataset_data": { - "description": "data source for the inference / evaluation dataset.", - "required": false, - "id": "dataset::data", - "refer_id": null - }, - "handlers": { - "description": "event-handlers for the inference / evaluation logic.", - "required": false, - "id": "handlers", - "refer_id": "evaluator::val_handlers" - }, - "preprocessing": { - "description": "preprocessing for the input data.", - "required": false, - "id": "preprocessing", - "refer_id": "dataset::transform" - }, - "postprocessing": { - "description": "postprocessing for the model output data.", - "required": false, - "id": "postprocessing", - "refer_id": "evaluator::postprocessing" - }, - "key_metric": { - "description": "the key metric during evaluation.", - "required": false, - "id": "key_metric", - "refer_id": "evaluator::key_val_metric" + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } } } diff --git a/tests/testing_data/python_workflow_properties.json b/tests/testing_data/python_workflow_properties.json new file mode 100644 index 0000000000..cd4295839a --- /dev/null +++ b/tests/testing_data/python_workflow_properties.json @@ -0,0 +1,26 @@ +{ + "infer": { + "bundle_root": { + "description": "root path of the bundle.", + "required": true, + "id": "bundle_root" + }, + "device": { + "description": "target device to execute the bundle workflow.", + "required": true, + "id": "device" + }, + "inferer": { + "description": "MONAI Inferer object to execute the model computation in inference.", + "required": true, + "id": "inferer" + } + }, + "meta": { + "version": { + "description": "version of the inference configuration.", + "required": true, + "id": "_meta_::version" + } + } +} diff --git a/tests/testing_data/responsive_inference.json b/tests/testing_data/responsive_inference.json new file mode 100644 index 0000000000..16d953d38e --- /dev/null +++ b/tests/testing_data/responsive_inference.json @@ -0,0 +1,101 @@ +{ + "imports": [ + "$from collections import defaultdict" + ], + "bundle_root": "will override", + "device": "$torch.device('cpu')", + "network_def": { + "_target_": "UNet", + "spatial_dims": 3, + "in_channels": 1, + "out_channels": 2, + "channels": [ + 2, + 2, + 4, + 8, + 4 + ], + "strides": [ + 2, + 2, + 2, + 2 + ], + "num_res_units": 2, + "norm": "batch" + }, + "network": "$@network_def.to(@device)", + "dataflow": "$defaultdict()", + "preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "EnsureChannelFirstd", + "keys": "image" + }, + { + "_target_": "ScaleIntensityd", + "keys": "image" + }, + { + "_target_": "RandRotated", + "_disabled_": true, + "keys": "image" + } + ] + }, + "dataset": { + "_target_": "Dataset", + "data": [ + "@dataflow" + ], + "transform": "@preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 0 + }, + "inferer": { + "_target_": "SlidingWindowInferer", + "roi_size": [ + 64, + 64, + 32 + ], + "sw_batch_size": 4, + "overlap": 0.25 + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "Activationsd", + "keys": "pred", + "softmax": true + }, + { + "_target_": "AsDiscreted", + "keys": "pred", + "argmax": true + } + ] + }, + "evaluator": { + "_target_": "SupervisedEvaluator", + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "amp": false, + "epoch_length": 1 + }, + "run": [ + "$@evaluator.run()", + "$@dataflow.update(@evaluator.state.output[0])" + ] +}