Note: This repo will be continually updated upon future advancements and we welcome open-source contributions! Currently, it shares the TensorFlow 2.5 version of the Hierarchical Probabilistic 3D U-Net (with attention mechanisms, a nested decoder structure and deep supervision), titled M1, as explored in the publication(s) listed below. The source code used to train this model in our original setup carries a large number of dependencies on internal datasets, tooling, infrastructure and hardware, so its release is currently not feasible. However, an equivalent minimal adaptation has been made available. We encourage users to test out M1, identify potential areas for significant improvement and propose PRs for inclusion in this repo.
Pre-Trained Model Using 1950 bpMRI Scans with PI-RADS v2 Annotations [Training:Validation Ratio = 80:20]:
To infer lesion predictions on test samples using the pre-trained variant of this algorithm (architecture in commit 58b784f), please visit https://grand-challenge.org/algorithms/prostate-mri-cad-cspca/
Main Scripts
• Preprocessing Functions: tf2.5/scripts/preprocess.py
• Tensor-Based Augmentations: tf2.5/scripts/model/augmentations.py
• Training Script Template: tf2.5/scripts/train_model.py
• Basic Callbacks (e.g. LR Schedules): tf2.5/scripts/callbacks.py
• Loss Functions: tf2.5/scripts/model/losses.py
• Network Architecture: tf2.5/scripts/model/unets/networks.py
Requirements
• Complete Docker Container: anindox8/m1:latest
• Key Python Packages: tf2.5/requirements.txt
Train-time schematic for the hierarchical Bayesian/probabilistic configuration of M1. $L_S$ denotes the segmentation loss between prediction $p$ and ground-truth $Y$. Additionally, $L_{KL}$, denoting the Kullback–Leibler divergence loss between prior distribution $P$ and posterior distribution $Q$, is used at train-time (refer to arXiv:1905.13077). For each execution of the model, latent samples $z_i \sim Q$ (train-time) or $z_i \sim P$ (test-time) are successively drawn at increasing scales of the model to predict one segmentation mask $p$.
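To make the two loss terms concrete, the sketch below computes the KL term and the latent sampling in NumPy. It assumes each latent scale is an axis-aligned (diagonal) Gaussian given as a (mean, log-variance) pair, that the KL term is taken as KL(Q || P) and summed over scales, and that it is added to $L_S$ with a weight `beta`; the helper names and `beta` are hypothetical and not part of this repo, whose TensorFlow implementation may differ.

```python
import numpy as np


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(Q || P) between two diagonal Gaussians, summed over latent dimensions."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return float(np.sum(kl))


def sample_latent(mu, logvar, rng):
    """Reparameterised draw z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(np.shape(mu))
    return mu + np.exp(0.5 * logvar) * eps


def total_loss(seg_loss, posteriors, priors, beta=1.0):
    """L_S + beta * sum_i KL(Q_i || P_i); inputs are per-scale (mu, logvar) pairs."""
    kl = sum(kl_diag_gaussians(mu_q, lv_q, mu_p, lv_p)
             for (mu_q, lv_q), (mu_p, lv_p) in zip(posteriors, priors))
    return seg_loss + beta * kl


# Usage: a single latent scale with 3 dimensions (hypothetical sizes and values)
rng = np.random.default_rng(0)
posterior = (np.zeros(3), np.zeros(3))        # Q: mean, log-variance
prior = (np.full(3, 0.5), np.zeros(3))        # P: mean, log-variance
z_train = sample_latent(*posterior, rng)      # train-time: z_i ~ Q
z_test = sample_latent(*prior, rng)           # test-time:  z_i ~ P
loss = total_loss(seg_loss=0.3, posteriors=[posterior], priors=[prior])
```

In the hierarchical setting, `posteriors` and `priors` would hold one pair per latent scale, so the KL penalty accumulates across all scales at which $z_i$ is drawn.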
Architecture schematic of M1, with attention mechanisms and a nested decoder structure with deep supervision. When dense_skip=False, all black/skip nodes disappear and M1 simplifies to an Attention U-Net with SEResNet blocks.
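The two building blocks named above can be sketched in NumPy: a squeeze-and-excitation (SE) channel rescaling, which an SE-ResNet block applies to its residual branch before the identity shortcut is added, and an additive attention gate in the style of the Attention U-Net linked below. This is an illustrative sketch, not the repo's layers: it assumes channels-last volumes, writes 1x1x1 convolutions as matrix products, omits biases, assumes the gating signal is already resampled to the skip connection's spatial size, and treats the reduction ratio r as the role we assume `se_reduction` plays. All function names are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def squeeze_excite(x, w1, w2):
    """SE rescaling of a channels-last volume x with shape (D, H, W, C).

    w1: (C, C // r) and w2: (C // r, C), where r is the channel reduction ratio.
    """
    s = x.mean(axis=(0, 1, 2))                    # squeeze: global average pool -> (C,)
    e = sigmoid(np.maximum(s @ w1, 0.0) @ w2)     # excite: FC -> ReLU -> FC -> sigmoid
    return x * e                                  # rescale every channel by its gate


def attention_gate(x, g, w_x, w_g, psi):
    """Additive attention gate, alpha = sigmoid(psi(ReLU(W_x x + W_g g))), biases omitted.

    x: skip features (D, H, W, Cx); g: gating features (D, H, W, Cg), assumed to be
    resampled to x's spatial shape already. 1x1x1 convolutions are written as matmuls.
    """
    q = np.maximum(x @ w_x + g @ w_g, 0.0)        # (D, H, W, Ci) intermediate features
    alpha = sigmoid(q @ psi)                      # (D, H, W, 1) voxel-wise coefficients
    return x * alpha, alpha


# Usage with random weights and a hypothetical reduction ratio r = 8
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8, 8, 32))            # skip-connection features
g = rng.standard_normal((4, 8, 8, 64))            # gating signal, assumed already upsampled
y = squeeze_excite(x, rng.standard_normal((32, 4)), rng.standard_normal((4, 32)))
gated, alpha = attention_gate(y, g, rng.standard_normal((32, 16)),
                              rng.standard_normal((64, 16)), rng.standard_normal((16, 1)))
```

SE reweights whole channels using global context, whereas the attention gate reweights individual voxels of the skip features, so the two are complementary.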
Minimal Example of Model Setup in TensorFlow 2.5:
(More Details: Training CNNs in TF2: Walkthrough; TF2 Datasets: Best Practices; TensorFlow Probability)
# Assumption: executed from tf2.5/scripts, so that the repo's `model` package is importable
import numpy as np
import tensorflow as tf
import model.unets.networks
from model import unets, losses
# Hypothetical placeholders: set these to your own training-set size and batch size
TRAIN_SAMPLES = 1560   # e.g. 80% of 1950 scans
BATCH_SIZE    = 4
# U-Net Definition (Note: Hyperparameters are Data-Centric -> Require Adequate Tuning for Optimal Performance)
unet_model = unets.networks.M1(
    input_spatial_dims = (20,160,160),
    input_channels     = 3,
    num_classes        = 2,
    filters            = (32,64,128,256,512),
    strides            = ((1,1,1),(1,2,2),(1,2,2),(2,2,2),(2,2,2)),
    kernel_sizes       = ((1,3,3),(1,3,3),(3,3,3),(3,3,3),(3,3,3)),
    prob_latent_dims   = (3,2,1,0),
    dropout_rate       = 0.50,
    dropout_mode       = 'monte-carlo',
    se_reduction       = (8,8,8,8,8),
    att_sub_samp       = ((1,1,1),(1,1,1),(1,1,1),(1,1,1)),
    kernel_initializer = tf.keras.initializers.Orthogonal(gain=1),
    bias_initializer   = tf.keras.initializers.TruncatedNormal(mean=0, stddev=1e-3),
    kernel_regularizer = tf.keras.regularizers.l2(1e-4),
    bias_regularizer   = tf.keras.regularizers.l2(1e-4),
    cascaded           = False,
    dense_skip         = True,
    probabilistic      = True,
    deep_supervision   = True,
    summary            = True)
# Schedule Cosine Annealing Learning Rate with Warm Restarts (first cycle spans 10 epochs)
LR_SCHEDULE = tf.keras.optimizers.schedules.CosineDecayRestarts(
    initial_learning_rate=1e-3, t_mul=2.00, m_mul=1.00, alpha=1e-3,
    first_decay_steps=int(np.ceil(TRAIN_SAMPLES/BATCH_SIZE))*10)
# Compile Model w/ Optimizer and Loss Function(s)
unet_model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=LR_SCHEDULE, amsgrad=True),
                   loss      = losses.Focal(alpha=[0.75, 0.25], gamma=2.00).loss)
# Train Model
unet_model.fit(...)
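To see what the schedule and loss configured above compute, the sketch below re-derives the warm-restart behaviour of tf.keras.optimizers.schedules.CosineDecayRestarts in plain Python and implements a common class-weighted focal loss in NumPy. `first_decay_steps=100` is a hypothetical value, and the focal-loss formulation (mean over voxels, `alpha[c]` weighting class c) is an assumption; `losses.Focal` in tf2.5/scripts/model/losses.py may normalise differently.

```python
import math

import numpy as np


def cosine_decay_restarts(step, initial_lr=1e-3, first_decay_steps=100,
                          t_mul=2.0, m_mul=1.0, alpha=1e-3):
    """Learning rate at `step` for SGDR-style cosine decay with warm restarts."""
    frac = step / first_decay_steps
    if t_mul == 1.0:
        i_restart = math.floor(frac)               # equal-length cycles
        frac -= i_restart
    else:                                          # cycle lengths grow geometrically
        i_restart = math.floor(math.log(1.0 - frac * (1.0 - t_mul)) / math.log(t_mul))
        sum_r = (1.0 - t_mul ** i_restart) / (1.0 - t_mul)
        frac = (frac - sum_r) / t_mul ** i_restart  # position inside the current cycle
    cosine = 0.5 * (m_mul ** i_restart) * (1.0 + math.cos(math.pi * frac))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


def focal_loss(y_true, y_pred, alpha=(0.75, 0.25), gamma=2.0, eps=1e-7):
    """Class-weighted focal loss averaged over voxels.

    y_true: one-hot targets, y_pred: softmax probabilities, both (..., num_classes);
    alpha[c] weights class c and (1 - p_c) ** gamma down-weights easy voxels.
    """
    p = np.clip(y_pred, eps, 1.0 - eps)
    w = np.asarray(alpha) * (1.0 - p) ** gamma
    return float(np.mean(np.sum(-w * y_true * np.log(p), axis=-1)))


# Usage: first cycle covers steps 0-100, the second (twice as long) steps 100-300
lrs = [cosine_decay_restarts(s) for s in (0, 50, 99, 100, 200, 300)]
loss = focal_loss(np.array([[1.0, 0.0], [0.0, 1.0]]), np.array([[0.9, 0.1], [0.3, 0.7]]))
```

With t_mul=2.00 each cycle is twice as long as the previous one, so with the 10-epoch first cycle in the example above, restarts fall at epochs 10, 30, 70, and so on. m_mul=1.00 restores the full initial rate at every restart, and alpha=1e-3 sets the floor to 1e-3 × initial_learning_rate = 1e-6.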
If you use this repo or some part of its codebase, please cite the following articles (see bibtex):
• A. Saha, J. S. Bosma, J. Linmans, M. Hosseinzadeh, H. Huisman (2021), "Anatomical and Diagnostic Bayesian Segmentation in Prostate MRI - Should Different Clinical Objectives Mandate Different Loss Functions?", Medical Imaging Meets NeurIPS Workshop, 35th Conference on Neural Information Processing Systems (NeurIPS), Sydney, Australia. (architecture in commit 914ec9d)
• A. Saha, M. Hosseinzadeh, H. Huisman (2021), "End-to-End Prostate Cancer Detection in bpMRI via 3D CNNs: Effect of Attention Mechanisms, Clinical Priori and Decoupled False Positive Reduction", Medical Image Analysis:102155. (architecture in commit 58b784f)
• A. Saha, M. Hosseinzadeh, H. Huisman (2020), "Encoding Clinical Priori in 3D Convolutional Neural Networks for Prostate Cancer Detection in bpMRI", Medical Imaging Meets NeurIPS Workshop, 34th Conference on Neural Information Processing Systems (NeurIPS), Vancouver, Canada. (architecture in commit 58b784f)
Related U-Net Architectures:
• nnU-Net: https://github.com/MIC-DKFZ/nnUNet
• Attention U-Net: https://github.com/ozan-oktay/Attention-Gated-Networks
• UNet++: https://github.com/MrGiovanni/UNetPlusPlus
• Hierarchical Probabilistic U-Net: https://github.com/deepmind/deepmind-research/tree/master/hierarchical_probabilistic_unet