algorithm_generation.md
Algorithm Generation

The algorithm generation module creates self-contained algorithm folders for model training, inference, and validation with various neural network architectures and training recipes. It takes input configuration ".yaml" files (shown below), dataset summaries (e.g. "data_stats.yaml") produced by our provided data analysis tools, and algorithm templates, and it outputs one algorithm folder per cross-validation fold. By default, the generated algorithm folders follow the MONAI bundle format. Users can run model training, inference, and validation inside those self-contained folders.

```yaml
modality: CT
datalist: "./task.json"
dataroot: "/workspace/data/task"
```

The input configuration files and dataset summaries are critical for algorithm generation. For example, the data modality determines the intensity normalization strategy, the average image shape determines image region-of-interest (ROI) cropping, and the numbers of input/output channels determine the first and last layers of the network.
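As a minimal illustration of that last point (the dictionary keys below are simplified stand-ins; the actual summaries use keys such as `stats_summary#image_stats#channels#max`), the channel-related statistics translate directly into network layer sizes:

```python
# Hypothetical, simplified dataset summary for illustration only.
data_stats = {
    "image_channels": 1,          # e.g. single-channel CT volumes
    "label_classes": [0, 1, 2],   # background plus two foreground labels
}

# The first network layer takes one input per image channel,
# and the last layer emits one probability channel per output class.
input_channels = data_stats["image_channels"]
output_classes = len(data_stats["label_classes"])

print(input_channels, output_classes)  # → 1 3
```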

Algorithms

The default algorithms are built on three different networks, DiNTS, (2D/3D) SegResNet, and SwinUNETR, each paired with a well-tuned training recipe.

| Algorithm | DiNTS | 2D SegResNet | SegResNet | SwinUNETR |
| --- | --- | --- | --- | --- |
| **Network** | Densely-connected lattice-based network | U-shape network architecture with 2D residual blocks | U-shape network architecture with 3D residual blocks | U-shape network architecture with Swin-Transformer based encoder |
| **Training Recipes** | Model Input:<br>- 96 x 96 x 96 for training<br>- 96 x 96 x 96 for inference<br>AMP: True<br>Optimizer: SGD<br>Initial Learning Rate: 0.2<br>Loss: DiceFocalLoss | Model Input:<br>- 320 x 320 for training<br>- 320 x 320 for inference<br>AMP: True<br>Optimizer: SGD<br>Initial Learning Rate: 0.2<br>Loss: DiceFocalLoss | Model Input:<br>- 224 x 224 x 144 for training<br>- 224 x 224 x 144 for inference<br>AMP: True<br>Optimizer: AdamW<br>Initial Learning Rate: 0.0002<br>Loss: DiceCELoss | Model Input:<br>- 96 x 96 x 96 for training<br>- 96 x 96 x 96 for inference<br>AMP: True<br>Optimizer: AdamW<br>Initial Learning Rate: 0.0001<br>Loss: DiceCELoss |
| **Transforms** | - Intensity normalization<br>- Random ROI cropping<br>- Random rotation<br>- Random zoom<br>- Random Gaussian smoothing<br>- Random intensity scaling<br>- Random intensity shifting<br>- Random Gaussian noising<br>- Random flipping | - Intensity normalization<br>- Random ROI cropping<br>- Random rotation<br>- Random zoom<br>- Random Gaussian smoothing<br>- Random intensity scaling<br>- Random intensity shifting<br>- Random Gaussian noising<br>- Random flipping | - Intensity normalization<br>- Random ROI cropping<br>- Random affine transformation<br>- Random Gaussian smoothing<br>- Random intensity scaling<br>- Random intensity shifting<br>- Random Gaussian noising<br>- Random flipping | - Intensity normalization<br>- Random ROI cropping<br>- Random rotation<br>- Random intensity shifting<br>- Random flipping |

For model inference, we use a sliding-window scheme to generate probability maps for the output classes/channels through a softmax/sigmoid layer. The overlap between adjacent windows is more than 25% of the window size. The probability map is resampled back to the original image spacing channel by channel. Next, a segmentation mask is generated via an argmax or thresholding operation along the channel dimension (with or without model ensembling) and saved with the original affine matrix.
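To make the overlap constraint concrete, here is an illustrative standalone function (not the MONAI implementation) that lays out window start positions along one spatial dimension; at 25% overlap, each window advances by at most 75% of the window size:

```python
def sliding_window_starts(dim_size, window_size, overlap=0.25):
    """Illustrative: start offsets of overlapping windows along one axis."""
    step = max(1, int(window_size * (1.0 - overlap)))  # advance by at most 75% of the window
    starts = list(range(0, max(dim_size - window_size, 0) + 1, step))
    if dim_size > window_size and starts[-1] != dim_size - window_size:
        starts.append(dim_size - window_size)  # final window flush with the image edge
    return starts

# A 160-voxel axis covered by 96-voxel windows at 25% overlap:
print(sliding_window_starts(160, 96))  # → [0, 64]
```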

Python Command

The following Python script shows how to generate algorithm bundles using the Python class BundleGen.

```python
# algorithm generation
import os

from monai.apps.auto3dseg import BundleGen

work_dir = "./work_dir"
data_output_yaml = os.path.join(work_dir, "data_stats.yaml")
data_src_cfg = "./task.yaml"

bundle_generator = BundleGen(
    algo_path=work_dir,
    data_stats_filename=data_output_yaml,
    data_src_cfg_name=data_src_cfg,
)

bundle_generator.generate(work_dir, num_fold=5)
```

This code block generates multiple algorithm bundles as follows. The folder-name suffix indicates the i-th fold of N-fold cross-validation.

```
./work_dir/
├── dints_0
├── dints_1
...
├── dints_4
├── segresnet_0
...
├── segresnet_4
├── segresnet2d_0
...
```

Algorithm Templates

The Python class BundleGen uses the default algorithm templates implicitly. The default algorithms are based on four established works (DiNTS, SegResNet, SegResNet2D, and SwinUNETR) and support both 3D CT and MR image segmentation. In each template, some items are empty or null; they are filled in automatically using the dataset information. Part of the configuration file "hyper_parameters.yaml" is shown below. Items such as "bundle_root", "data_file_base_dir", and "patch_size" are filled in automatically without any user interaction.

```yaml
bundle_root: null
ckpt_path: "$@bundle_root + '/model_fold' + str(@training#fold)"
data_file_base_dir: null
data_list_file_path: null

training:
  # hyper-parameters
  amp: true
  determ: false
  fold: 0
  input_channels: null
  learning_rate: 0.2
  num_images_per_batch: 2
  num_iterations: 40000
  num_iterations_per_validation: 500
  num_patches_per_image: 1
  num_sw_batch_size: 2
  output_classes: null
  overlap_ratio: 0.625
  patch_size: null
  patch_size_valid: null
  softmax: true

  loss:
    _target_: DiceFocalLoss
    include_background: true
...
```
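In the MONAI bundle config syntax, `$` marks a Python expression and `@` references another config item, so once `bundle_root` and the fold index are filled in, `ckpt_path` resolves to a plain string concatenation. A minimal sketch, assuming an example `bundle_root` value after template filling:

```python
# Assumed values after template filling for the first DiNTS fold.
bundle_root = "./work_dir/dints_0"
fold = 0

# Mirrors ckpt_path: "$@bundle_root + '/model_fold' + str(@training#fold)"
ckpt_path = bundle_root + "/model_fold" + str(fold)
print(ckpt_path)  # → ./work_dir/dints_0/model_fold0
```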

The actual template filling is done by the "fill_template_config" function of the "Algo" class in the script "scripts/algo.py". Each algorithm's "algo.py" is located inside its bundle template.

```python
class DintsAlgo(BundleAlgo):
    def fill_template_config(self, data_stats_file, output_path, **kwargs):
        ...
        patch_size = [128, 128, 96]
        max_shape = data_stats["stats_summary#image_stats#shape#max"]
        patch_size = [
            max(32, shape_k // 32 * 32) if shape_k < p_k else p_k for p_k, shape_k in zip(patch_size, max_shape)
        ]

        input_channels = data_stats["stats_summary#image_stats#channels#max"]
        output_classes = len(data_stats["stats_summary#label_stats#labels"])

        hyper_parameters.update({"data_file_base_dir": os.path.abspath(data_src_cfg["dataroot"])})
        hyper_parameters.update({"data_list_file_path": os.path.abspath(data_src_cfg["datalist"])})

        hyper_parameters.update({"training#patch_size": patch_size})
        hyper_parameters.update({"training#patch_size_valid": patch_size})
        hyper_parameters.update({"training#input_channels": input_channels})
        hyper_parameters.update({"training#output_classes": output_classes})

        hyper_parameters_search.update({"searching#patch_size": patch_size})
        hyper_parameters_search.update({"searching#patch_size_valid": patch_size})
        hyper_parameters_search.update({"searching#input_channels": input_channels})
        hyper_parameters_search.update({"searching#output_classes": output_classes})

        modality = data_src_cfg.get("modality", "ct").lower()
        spacing = data_stats["stats_summary#image_stats#spacing#median"]

        intensity_upper_bound = float(data_stats["stats_summary#image_foreground_stats#intensity#percentile_99_5"])
        intensity_lower_bound = float(data_stats["stats_summary#image_foreground_stats#intensity#percentile_00_5"])

        ct_intensity_xform = {
            "_target_": "Compose",
            "transforms": [
                {
                    "_target_": "ScaleIntensityRanged",
                    "keys": "@image_key",
                    "a_min": intensity_lower_bound,
                    "a_max": intensity_upper_bound,
                    "b_min": 0.0,
                    "b_max": 1.0,
                    "clip": True,
                },
                {"_target_": "CropForegroundd", "keys": ["@image_key", "@label_key"], "source_key": "@image_key"},
            ],
        }

        mr_intensity_transform = {
            "_target_": "NormalizeIntensityd",
            "keys": "@image_key",
            "nonzero": True,
            "channel_wise": True,
        }

        transforms_train.update({"transforms_train#transforms#3#pixdim": spacing})
        transforms_validate.update({"transforms_validate#transforms#3#pixdim": spacing})
        transforms_infer.update({"transforms_infer#transforms#3#pixdim": spacing})

        if modality.startswith("ct"):
            transforms_train.update({"transforms_train#transforms#5": ct_intensity_xform})
            transforms_validate.update({"transforms_validate#transforms#5": ct_intensity_xform})
            transforms_infer.update({"transforms_infer#transforms#5": ct_intensity_xform})
        else:
            transforms_train.update({"transforms_train#transforms#5": mr_intensity_transform})
            transforms_validate.update({"transforms_validate#transforms#5": mr_intensity_transform})
            transforms_infer.update({"transforms_infer#transforms#5": mr_intensity_transform})
        ...
        return fill_records
```
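The patch-size logic above shrinks each default dimension to the largest multiple of 32 that fits the dataset (with a floor of 32) whenever the dataset's maximum shape is smaller than the default. Extracted as a standalone sketch (the function name here is hypothetical):

```python
def derive_patch_size(default_patch, max_shape):
    """Round a dimension down to a multiple of 32 (floor 32) when the data is smaller."""
    return [
        max(32, shape_k // 32 * 32) if shape_k < p_k else p_k
        for p_k, shape_k in zip(default_patch, max_shape)
    ]

# e.g. thin CT volumes whose largest shape is 512 x 512 x 55:
# the in-plane dims keep the 128-voxel default, the thin axis rounds 55 down to 32
print(derive_patch_size([128, 128, 96], [512, 512, 55]))  # → [128, 128, 32]
```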