From 8c6c83b6b70fbba1ae4d4aff5b169038bf25b85d Mon Sep 17 00:00:00 2001 From: jaehwan Date: Thu, 9 May 2024 16:08:06 +0900 Subject: [PATCH] rebase code --- README.md | 2 - configs/augmentations.yaml | 10 - configs/config.yaml | 44 -- configs/preprocessing.yaml | 17 - datasets/HaN.py | 112 --- datasets/SamplerFactory.py | 18 - datasets/label_dict.py | 44 -- libs/losses/LossFactory.py | 239 ------- libs/models/AttentionUnet3D.py | 776 -------------------- libs/models/AttentionUnet3D_cross.py | 782 --------------------- libs/models/BaseModelClass.py | 135 ---- libs/models/ER_Net.py | 214 ------ libs/models/HighResNet3D.py | 237 ------- libs/models/ModelFactory.py | 34 - libs/models/Unet3D.py | 214 ------ libs/models/VNet.py | 251 ------- libs/models/mednextv1/MedNextV1.py | 388 ---------- libs/models/mednextv1/blocks.py | 214 ------ libs/models/mednextv1/create_mednext_v1.py | 83 --- libs/models/nonlocalUnet3D.py | 450 ------------ libs/optimizers/OptimizerFactory.py | 23 - libs/schedulers/SchedulerFactory.py | 97 --- preproc.py | 121 ---- train.py | 218 ------ utils/AugmentFactory.py | 34 - utils/EvalFactory.py | 43 -- utils/TaskFactory.py | 46 -- utils/TrainFactory.py | 318 --------- 28 files changed, 5164 deletions(-) delete mode 100644 README.md delete mode 100644 configs/augmentations.yaml delete mode 100644 configs/config.yaml delete mode 100644 configs/preprocessing.yaml delete mode 100644 datasets/HaN.py delete mode 100644 datasets/SamplerFactory.py delete mode 100644 datasets/label_dict.py delete mode 100644 libs/losses/LossFactory.py delete mode 100644 libs/models/AttentionUnet3D.py delete mode 100644 libs/models/AttentionUnet3D_cross.py delete mode 100644 libs/models/BaseModelClass.py delete mode 100644 libs/models/ER_Net.py delete mode 100644 libs/models/HighResNet3D.py delete mode 100644 libs/models/ModelFactory.py delete mode 100644 libs/models/Unet3D.py delete mode 100644 libs/models/VNet.py delete mode 100644 libs/models/mednextv1/MedNextV1.py delete mode 100644 libs/models/mednextv1/blocks.py delete mode 100644 libs/models/mednextv1/create_mednext_v1.py delete mode 100644 libs/models/nonlocalUnet3D.py delete mode 100644 libs/optimizers/OptimizerFactory.py delete mode 100644 libs/schedulers/SchedulerFactory.py delete mode 100644 preproc.py delete mode 100644 train.py delete mode 100644 utils/AugmentFactory.py delete mode 100644 utils/EvalFactory.py delete mode 100644 utils/TaskFactory.py delete mode 100644 utils/TrainFactory.py diff --git a/README.md b/README.md deleted file mode 100644 index 0171a03..0000000 --- a/README.md +++ /dev/null @@ -1,2 +0,0 @@ -# HanSeg_2023 -The Head and Neck oragan-at-risk CT & MR segmentation challenge. Contribution to the Grand Challenge (MICCAI 2023) diff --git a/configs/augmentations.yaml b/configs/augmentations.yaml deleted file mode 100644 index 8ea4d13..0000000 --- a/configs/augmentations.yaml +++ /dev/null @@ -1,10 +0,0 @@ -RandomAffine: - scales: !!python/tuple [ 0.5, 1.5 ] - degrees: !!python/tuple [ -10, 10 ] - isotropic: false - image_interpolation: linear - p: 0.5 - -RandomFlip: - axes: 0 - flip_probability: 0.5 \ No newline at end of file diff --git a/configs/config.yaml b/configs/config.yaml deleted file mode 100644 index 26595a0..0000000 --- a/configs/config.yaml +++ /dev/null @@ -1,44 +0,0 @@ -title: main -project_dir: '/home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/src/HanSeg_2023/experiments' -seed: 42 -device: cuda:2 -experiment: - name: Anchor # Anchor, MidLine, SmallandHard #Segmentation metadata - -data_loader: - dataset: /home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/data/preprocessed/HaN-Seg/set_1/preprocessing_7CB7AA3181 - kfold: 1 # 1 is not use kfold validation other is use kfold validation - augmentations: configs/augmentations.yaml - batch_size: 1 - num_workers: 16 - patch_loader: False - patch_shape: - - 128 - - 128 - - 64 - resize_shape: - - 288 # x - - 288 # y - - 64 # z - sampler_type: UniformSampler - -model: - name: Unet3D - -loss: - name: DiceCELoss - -lr_scheduler: - name: LambdaLR - -optimizer: - learning_rate: 0.001 - name: AdamW - -trainer: - reload: False - checkpoint: '' - do_train: True - do_test: False - do_inference: False - epochs: 1000 diff --git a/configs/preprocessing.yaml b/configs/preprocessing.yaml deleted file mode 100644 index f8bb235..0000000 --- a/configs/preprocessing.yaml +++ /dev/null @@ -1,17 +0,0 @@ -# Description: Preprocessing configuration file -source_dir: /home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/data/raw/HaN-Seg/set_1 -save_dir: /home/jhhan/01_research/01_MICCAI/01_grandchellenge/han_seg/data/preprocessed/HaN-Seg/set_1 - -experiment: - name: Anchor - -preprocessing: - Clamp: - out_min: -500 - out_max: 1000 - Resample: - target: !!python/tuple [1, 1, 1] - RescaleIntensity: - out_min_max: !!python/tuple [0, 1] - -check_preprocessing: True \ No newline at end of file diff --git a/datasets/HaN.py b/datasets/HaN.py deleted file mode 100644 index 81d018e..0000000 --- a/datasets/HaN.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import json -import logging -import logging.config -from pathlib import Path -import nrrd -import numpy as np -import torch -import torchio as tio -from datasets.label_dict import LABEL_dict, Anchor_dict # from datasets/label_dict.py -from torch.utils.data import DataLoader -from datasets.SamplerFactory import SamplerFactory - -class HaN(tio.SubjectsDataset): - """ - MICCAI dataset - """ - def __init__(self, config, splits, transform=None, sampler=None, **kwargs): - self.config = config - self.splits = splits - self.root = Path(self.config.data_loader.dataset) - if not isinstance(splits, list): - splits = [splits] - self.seed = self.config.seed - self.sampler = sampler - subjects_list = self._get_subjects_list(self.root, splits) - super().__init__(subjects_list, transform, **kwargs) - - def _numpy_reader(self, path): - data = torch.from_numpy(np.load(path)).float() - affine = torch.eye(4, requires_grad=False) - return data, affine - - def _split_data(self, data_list): - # train and val data split - np.random.seed(self.seed) - split_ratio = 0.8 - train_size = int(split_ratio * len(data_list)) - val_size = int((len(data_list) - train_size)) - train_data = data_list[:train_size] - val_data = data_list[train_size:train_size+val_size] - return train_data, val_data - - def _generate_jsondata(self, train_data: list, val_data: list, test_data=None): - if test_data: - test_data = test_data - else: - test_data = val_data - - json_data = { - 'train': train_data, - 'val': val_data, - "test": test_data - } - return json_data - - def _get_subjects_list(self, root, splits): - # TODO : check the path - patient_data_list = os.listdir(root) - patient_data_list = [entry for entry in patient_data_list if os.path.isdir(os.path.join(root, entry))] - - # TODO: change the method, reading whole list and split train/val/test and kfold - if self.config.data_loader.kfold == 1: - train_data, val_data = self._split_data(patient_data_list) - json_splits = self._generate_jsondata(train_data, val_data) - - # consists of the data sets - subjects = [] - for split in splits: - for patient in json_splits[split]: - # generate labels - ct_data_path = os.path.join(root, patient, patient + '_IMG_CT.nrrd') - # mr_data_path = os.path.join(root, patient, patient + '_IMG_MR_T1.nrrd') - label_path = os.path.join(root, patient, patient + f'_{self.config.experiment.name}.seg.nrrd') - if not os.path.isfile(ct_data_path): - raise ValueError(f'Missing CT data file for patient {patient} ({ct_data_path})') - # if not os.path.isfile(mr_data_path): - # raise ValueError(f'Missing MR_TI data file for patient {patient} ({mr_data_path})') - if not os.path.isfile(label_path): - raise ValueError(f'Missing LABEL file for patient {patient} ({label_path})') - - subject_dict = { - 'partition': split, - 'patient': patient, - 'ct': tio.ScalarImage(ct_data_path), - # 'mr': tio.ScalarImage(mr_data_path, reader=self._nrrd_reader), - 'label': tio.LabelMap(label_path,), - } - - subjects.append(tio.Subject(**subject_dict)) - print(f"Loaded {len(subjects)} patients for split {split}") - return subjects - - def get_loader(self, config): - # patch-based training - if config.patch_loader: - sampler = SamplerFactory(config).get() - queue = tio.Queue( - subjects_dataset=self, - max_length=300, - samples_per_volume=10, - sampler=sampler, - num_workers=config.num_workers, - shuffle_subjects=True, - shuffle_patches=True, - start_background=False, - ) - loader = DataLoader(queue, batch_size=config.batch_size, num_workers=0, pin_memory=True) - else: # subject-based training - dataset = tio.SubjectsDataset(self._subjects, transform=self._transform) - loader = DataLoader(dataset, batch_size=config.batch_size, num_workers=config.num_workers, pin_memory=True) - return loader diff --git a/datasets/SamplerFactory.py b/datasets/SamplerFactory.py deleted file mode 100644 index bd57dd8..0000000 --- a/datasets/SamplerFactory.py +++ /dev/null @@ -1,18 +0,0 @@ -import torchio as tio - -class SamplerFactory: - def __init__(self, config): - self.config = config - # the config is not whole config, it is "self.config.data_loader" - self.sampler_type = config.sampler_type - - def get(self): - if self.sampler_type == 'UniformSampler': - sampler = tio.UniformSampler(patch_size=self.config.patch_shape) - # elif self.sampler_type == 'WeightedSampler': - # sampler = tio.WeightedSampler(patch_size=self.config.patch_shape, probability_map='sampling_map') - elif self.sampler_type == 'WeightedSampler': - probabilities = {0: 0.1, 1: 0.9} - sampler = tio.LabelSampler(patch_size=self.config.patch_shape, label_name='label', label_probabilities=probabilities) - - return sampler \ No newline at end of file diff --git a/datasets/label_dict.py b/datasets/label_dict.py deleted file mode 100644 index fc86510..0000000 --- a/datasets/label_dict.py +++ /dev/null @@ -1,44 +0,0 @@ -LABEL_dict = { - "background": 0, - "A_Carotid_L": 1, - "A_Carotid_R": 2, - "Arytenoid": 3, - "Bone_Mandible": 4, - "Brainstem": 5, - "BuccalMucosa": 6, - "Cavity_Oral": 7, - "Cochlea_L": 8, - "Cochlea_R": 9, - "Cricopharyngeus": 10, - "Esophagus_S": 11, - "Eye_AL": 12, - "Eye_AR": 13, - "Eye_PL": 14, - "Eye_PR": 15, - "Glnd_Lacrimal_L": 16, - "Glnd_Lacrimal_R": 17, - "Glnd_Submand_L": 18, - "Glnd_Submand_R": 19, - "Glnd_Thyroid": 20, - "Glottis": 21, - "Larynx_SG": 22, - "Lips": 23, - "OpticChiasm": 24, - "OpticNrv_L": 25, - "OpticNrv_R": 26, - "Parotid_L": 27, - "Parotid_R": 28, - "Pituitary": 29, - "SpinalCord": 30, -} - -Anchor_dict = { - "background": 0, # 0 - "Bone_Mandible": 4, # 1 - "Brainstem": 5, # 2 - "Eye_AL": 12, # 3 - "Eye_AR": 13, # 4 - "Eye_PL": 14, # 5 - "Eye_PR": 15, # 6 - "SpinalCord": 30, # 7 -} \ No newline at end of file diff --git a/libs/losses/LossFactory.py b/libs/losses/LossFactory.py deleted file mode 100644 index 6cecacb..0000000 --- a/libs/losses/LossFactory.py +++ /dev/null @@ -1,239 +0,0 @@ -import numpy as np -import torch -from torch import nn, Tensor -import torch.nn.functional as F -from monai.losses import DiceCELoss - -class LossFactory: - def __init__(self, names, classes, weights=None): - self.names = names - if not isinstance(self.names, list): - self.names = [self.names] - - print(f'Losses used: {self.names}') - self.classes = classes - self.weights = weights - self.losses = {} - for name in self.names: - loss = self.get_loss(name) - self.losses[name] = loss - - def get_loss(self, name): - if name == 'JaccardLoss': - loss_fn = JaccardLoss(weight=self.weights) - elif name == 'Dice3DLoss': - loss_fn = Dice3DLoss(weight=self.weights) - elif name == 'DiceCELoss': - loss_fn = DiceCELoss(self.classes, to_onehot_y=False, softmax=True) - elif name == 'DiceLoss': - loss_fn = DiceLoss(self.classes) - elif name == 'BoundaryLoss': - loss_fn = BoundaryLoss(self.classes) - else: - raise Exception(f"Loss function {name} can't be found.") - return loss_fn - -class JaccardLoss(torch.nn.Module): - def __init__(self, weight=None, size_average=True, per_volume=False, apply_sigmoid=True, - min_pixels=5): - super().__init__() - self.size_average = size_average - self.weight = weight - self.per_volume = per_volume - self.apply_sigmoid = apply_sigmoid - self.min_pixels = min_pixels - - def forward(self, pred, gt): - assert pred.shape[1] == 1, 'this loss works with a binary prediction' - if self.apply_sigmoid: - pred = torch.sigmoid(pred) - - batch_size = pred.size()[0] - eps = 1e-6 - if not self.per_volume: - batch_size = 1 - dice_gt = gt.contiguous().view(batch_size, -1).float() - dice_pred = pred.contiguous().view(batch_size, -1) - intersection = torch.sum(dice_pred * dice_gt, dim=1) - union = torch.sum(dice_pred + dice_gt, dim=1) - intersection - loss = 1 - (intersection + eps) / (union + eps) - return loss - -class DiceLoss(nn.Module): - # TODO: Check about partition_weights, see original code - # what i didn't understand is that for dice loss, partition_weights gets - # multiplied inside the forward and also in the factory_loss function - # I think that this is wrong, and removed it from the forward - def __init__(self, classes): - super().__init__() - self.eps = 1e-06 - self.classes = classes - - def forward(self, pred, gt): - # included = [v for k, v in self.classes.items() if k not in ['UNLABELED']] - gt_onehot = torch.nn.functional.one_hot(gt.squeeze().long(), num_classes=self.classes) - # if gt.shape[0] = 1: # we need to add a further axis after the previous squeeze() - gt_onehot = gt_onehot.unsqueeze(0) - gt_onehot = torch.movedim(gt_onehot, -1, 1) - input_soft = F.softmax(pred, dim=1) - dims = (2, 3, 4) - - intersection = torch.sum(input_soft * gt_onehot, dims) - cardinality = torch.sum(input_soft + gt_onehot, dims) - dice_score = 2. * intersection / (cardinality + self.eps) - return 1. - torch.mean(dice_score) - - - - - -def flatten(tensor): - """Flattens a given tensor such that the channel axis is first. - The shapes are transformed as follows: - (N, C, D, H, W) -> (C, N * D * H * W) - """ - # number of channels - C = tensor.size(1) - # new axis order - axis_order = (1, 0) + tuple(range(2, tensor.dim())) - # Transpose: (N, C, D, H, W) -> (C, N, D, H, W) - transposed = tensor.permute(axis_order) - # Flatten: (C, N, D, H, W) -> (C, N * D * H * W) - return transposed.contiguous().view(C, -1) - - -def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None): - """ - Computes DiceCoefficient as defined in https://arxiv.org/abs/1606.04797 given a multi channel input and target. - Assumes the input is a normalized probability, e.g. a result of Sigmoid or Softmax function. - Ref: https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/losses.py - - Args: - input (torch.Tensor): NxCxSpatial input tensor - target (torch.Tensor): NxCxSpatial target tensor - epsilon (float): prevents division by zero - weight (torch.Tensor): Cx1 tensor of weight per channel/class - """ - - # input and target shapes must match - assert input.size() == target.size(), "'input' and 'target' must have the same shape" - - input = flatten(input) - target = flatten(target) - target = target.float() - - # compute per channel Dice Coefficient - intersect = (input * target).sum(-1) - if weight is not None: - intersect = weight * intersect - - # here we can use standard dice (input + target).sum(-1) or extension (see V-Net) (input^2 + target^2).sum(-1) - denominator = (input * input).sum(-1) + (target * target).sum(-1) - # return 2 * (intersect / denominator.clamp(min=epsilon)) - return 2 * ((intersect + epsilon) / (denominator + epsilon)) - - -class Dice3DLoss(torch.nn.Module): - """Computes Dice Loss according to https://arxiv.org/abs/1606.04797. - For multi-class segmentation `weight` parameter can be used to assign different weights per class. - - Args: - input (torch.Tensor): NxCxDxHxW input tensor - target (torch.Tensor): NxCxDxHxW target tensor - """ - - def __init__(self, weight=None): - super(Dice3DLoss, self).__init__() - self.register_buffer('weight', weight) - self.weight = weight - - def forward(self, input, target): - # compute per channel Dice coefficient - per_channel_dice = compute_per_channel_dice(input, target, weight=self.weight) - - # average Dice score across all channels/classes - return 1. - torch.mean(per_channel_dice), per_channel_dice - - -class BoundaryLoss(torch.nn.Module): - def __init__(self, num_classes=2): - super(BoundaryLoss, self).__init__() - self.num_classes = num_classes - - def DiceCoeff(self, pred, target, smooth=1e-7): - inter = (pred * target).sum() - return (2 * inter + smooth) / (pred.sum() + target.sum() + smooth) - - def DiceLoss(self, pred, target, smooth=1e-7): - return torch.sum(1 - self.DiceCoeff(pred, target, smooth)) - - def extract_surface(self, volume): - # Pad the volume with zeros on all sides - padded_volume = torch.nn.functional.pad(volume, (1, 1, 1, 1, 1, 1), mode='constant', value=0) - - # Compute the gradient along all three dimensions - dz = padded_volume[1:-1, 1:-1, 2:] - padded_volume[1:-1, 1:-1, :-2] - dy = padded_volume[1:-1, 2:, 1:-1] - padded_volume[1:-1, :-2, 1:-1] - dx = padded_volume[2:, 1:-1, 1:-1] - padded_volume[:-2, 1:-1, 1:-1] - - # Compute the magnitude of the gradient vector - mag = torch.sqrt(dx ** 2 + dy ** 2 + dz ** 2) - mag[mag > 0] = 1 - return mag - - def forward(self, preds, targets): - # sigmoid probability map - preds_ = torch.sigmoid(preds) - p_set = self.extract_surface(preds_[0, 0, ...]).cuda() - t_set = self.extract_surface(targets[0, 0, ...]).cuda() - loss = self.DiceLoss(p_set, t_set) - - return loss - -class CrossEntropyLoss(torch.nn.Module): - def __init__(self, weights=None, apply_sigmoid=True): - super().__init__() - self.weights = weights - self.apply_sigmoid = apply_sigmoid - self.loss_fn = nn.CrossEntropyLoss(weight=self.weights) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, pred, gt): - pred = self.sigmoid(pred) - return self.loss_fn(pred, gt) - - -class BoundaryLoss(torch.nn.Module): - def __init__(self, num_classes=2): - super(BoundaryLoss, self).__init__() - self.num_classes = num_classes - - def DiceCoeff(self, pred, target, smooth=1e-7): - inter = (pred * target).sum() - return (2 * inter + smooth) / (pred.sum() + target.sum() + smooth) - - def DiceLoss(self, pred, target, smooth=1e-7): - return torch.sum(1 - self.DiceCoeff(pred, target, smooth)) - - def extract_surface(self, volume): - # Pad the volume with zeros on all sides - padded_volume = torch.nn.functional.pad(volume, (1, 1, 1, 1, 1, 1), mode='constant', value=0) - - # Compute the gradient along all three dimensions - dz = padded_volume[1:-1, 1:-1, 2:] - padded_volume[1:-1, 1:-1, :-2] - dy = padded_volume[1:-1, 2:, 1:-1] - padded_volume[1:-1, :-2, 1:-1] - dx = padded_volume[2:, 1:-1, 1:-1] - padded_volume[:-2, 1:-1, 1:-1] - - # Compute the magnitude of the gradient vector - mag = torch.sqrt(dx ** 2 + dy ** 2 + dz ** 2) - mag[mag > 0] = 1 - return mag - - def forward(self, preds, targets): - # sigmoid probability map - preds_ = torch.sigmoid(preds) - p_set = self.extract_surface(preds_[0, 0, ...]).cuda() - t_set = self.extract_surface(targets[0, 0, ...]).cuda() - loss = self.DiceLoss(p_set, t_set) - - return loss \ No newline at end of file diff --git a/libs/models/AttentionUnet3D.py b/libs/models/AttentionUnet3D.py deleted file mode 100644 index 3f59ea2..0000000 --- a/libs/models/AttentionUnet3D.py +++ /dev/null @@ -1,776 +0,0 @@ -from torch.nn import init -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class _GridAttentionBlockND(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(_GridAttentionBlockND, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] - - # Downsampling rate for the input featuremap - if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor - elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) - else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # Output transform - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0, bias=True) - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - # Define the operation - if mode == 'concatenation': - self.operation_function = self._concatenation - elif mode == 'concatenation_debug': - self.operation_function = self._concatenation_debug - elif mode == 'concatenation_residual': - self.operation_function = self._concatenation_residual - else: - raise NotImplementedError('Unknown operation function.') - - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - def _concatenation_debug(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.softplus(theta_x + phi_g) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - - def _concatenation_residual(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - f = self.psi(f).view(batch_size, 1, -1) - sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock3D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(GridAttentionBlock3D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - ) -def weights_init_kaiming(m): - classname = m.__class__.__name__ - # print(classname) - if classname.find('Conv') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif classname.find('Linear') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif classname.find('BatchNorm') != -1: - init.normal_(m.weight.data, 1.0, 0.02) - init.constant_(m.bias.data, 0.0) - - -def init_weights(net, init_type='normal'): - if init_type == 'kaiming': - net.apply(weights_init_kaiming) - -class conv2DBatchNorm(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(conv2DBatchNorm, self).__init__() - - self.cb_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)),) - - def forward(self, inputs): - outputs = self.cb_unit(inputs) - return outputs - - -class deconv2DBatchNorm(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(deconv2DBatchNorm, self).__init__() - - self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)),) - - def forward(self, inputs): - outputs = self.dcb_unit(inputs) - return outputs - - -class conv2DBatchNormRelu(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(conv2DBatchNormRelu, self).__init__() - - self.cbr_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)), - nn.ReLU(inplace=True),) - - def forward(self, inputs): - outputs = self.cbr_unit(inputs) - return outputs - - -class deconv2DBatchNormRelu(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(deconv2DBatchNormRelu, self).__init__() - - self.dcbr_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)), - nn.ReLU(inplace=True),) - - def forward(self, inputs): - outputs = self.dcbr_unit(inputs) - return outputs - - -class unetConv2(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): - super(unetConv2, self).__init__() - self.n = n - self.ks = ks - self.stride = stride - self.padding = padding - s = stride - p = padding - if is_batchnorm: - for i in range(1, n+1): - conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), - nn.BatchNorm2d(out_size), - nn.ReLU(inplace=True),) - setattr(self, 'conv%d'%i, conv) - in_size = out_size - - else: - for i in range(1, n+1): - conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), - nn.ReLU(inplace=True),) - setattr(self, 'conv%d'%i, conv) - in_size = out_size - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - x = inputs - for i in range(1, self.n+1): - conv = getattr(self, 'conv%d'%i) - x = conv(x) - - return x - - -class UnetConv3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): - super(UnetConv3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - outputs = self.conv2(outputs) - return outputs - - -class FCNConv3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): - super(FCNConv3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - outputs = self.conv2(outputs) - outputs = self.conv3(outputs) - return outputs - - -class UnetGatingSignal3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm): - super(UnetGatingSignal3, self).__init__() - self.fmap_size = (4, 4, 4) - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), - nn.InstanceNorm3d(in_size//2), - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool3d(output_size=self.fmap_size), - ) - self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], - out_features=out_size, bias=True) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool3d(output_size=self.fmap_size), - ) - self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], - out_features=out_size, bias=True) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - batch_size = inputs.size(0) - outputs = self.conv1(inputs) - outputs = outputs.view(batch_size, -1) - outputs = self.fc1(outputs) - return outputs - - -class UnetGridGatingSignal3(nn.Module): - def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True): - super(UnetGridGatingSignal3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True), - ) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), - nn.ReLU(inplace=True), - ) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - return outputs - - -class unetUp(nn.Module): - def __init__(self, in_size, out_size, is_deconv): - super(unetUp, self).__init__() - self.conv = unetConv2(in_size, out_size, False) - if is_deconv: - self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) - else: - self.up = nn.UpsamplingBilinear2d(scale_factor=2) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('unetConv2') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -class UnetUp3(nn.Module): - def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True): - super(UnetUp3, self).__init__() - if is_deconv: - self.conv = UnetConv3(in_size, out_size, is_batchnorm) - self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) - else: - self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) - self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -class UnetUp3_CT(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm=True): - super(UnetUp3_CT, self).__init__() - self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -# Squeeze-and-Excitation Network -class SqEx(nn.Module): - - def __init__(self, n_features, reduction=6): - super(SqEx, self).__init__() - - if n_features % reduction != 0: - raise ValueError('n_features must be divisible by reduction (default = 4)') - - self.linear1 = nn.Linear(n_features, n_features // reduction, bias=False) - self.nonlin1 = nn.ReLU(inplace=True) - self.linear2 = nn.Linear(n_features // reduction, n_features, bias=False) - self.nonlin2 = nn.Sigmoid() - - def forward(self, x): - - y = F.avg_pool3d(x, kernel_size=x.size()[2:5]) - y = y.permute(0, 2, 3, 4, 1) - y = self.nonlin1(self.linear1(y)) - y = self.nonlin2(self.linear2(y)) - y = y.permute(0, 4, 1, 2, 3) - y = x * y - return y - - -class UnetUp3_SqEx(nn.Module): - def __init__(self, in_size, out_size, is_deconv, is_batchnorm): - super(UnetUp3_SqEx, self).__init__() - if is_deconv: - self.sqex = SqEx(n_features=in_size+out_size) - self.conv = UnetConv3(in_size, out_size, is_batchnorm) - self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) - else: - self.sqex = SqEx(n_features=in_size+out_size) - self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) - self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - concat = torch.cat([outputs1, outputs2], 1) - gated = self.sqex(concat) - return self.conv(gated) - - -class residualBlock(nn.Module): - expansion = 1 - - def __init__(self, in_channels, n_filters, stride=1, downsample=None): - super(residualBlock, self).__init__() - - self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False) - self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False) - self.downsample = downsample - self.stride = stride - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - residual = x - - out = self.convbnrelu1(x) - out = self.convbn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - return out - - -class residualBottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_channels, n_filters, stride=1, downsample=None): - super(residualBottleneck, self).__init__() - self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False) - self.convbn2 = nn.Conv2DBatchNorm(n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False) - self.convbn3 = nn.Conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, bias=False) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.convbn1(x) - out = self.convbn2(out) - out = self.convbn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - - - -class SeqModelFeatureExtractor(nn.Module): - def __init__(self, submodule, extracted_layers): - super(SeqModelFeatureExtractor, self).__init__() - - self.submodule = submodule - self.extracted_layers = extracted_layers - - def forward(self, x): - outputs = [] - for name, module in self.submodule._modules.items(): - x = module(x) - if name in self.extracted_layers: - outputs += [x] - return outputs + [x] - - -class HookBasedFeatureExtractor(nn.Module): - def __init__(self, submodule, layername, upscale=False): - super(HookBasedFeatureExtractor, self).__init__() - - self.submodule = submodule - self.submodule.eval() - self.layername = layername - self.outputs_size = None - self.outputs = None - self.inputs = None - self.inputs_size = None - self.upscale = upscale - - def get_input_array(self, m, i, o): - if isinstance(i, tuple): - self.inputs = [i[index].data.clone() for index in range(len(i))] - self.inputs_size = [input.size() for input in self.inputs] - else: - self.inputs = i.data.clone() - self.inputs_size = self.input.size() - print('Input Array Size: ', self.inputs_size) - - def get_output_array(self, m, i, o): - if isinstance(o, tuple): - self.outputs = [o[index].data.clone() for index in range(len(o))] - self.outputs_size = [output.size() for output in self.outputs] - else: - self.outputs = o.data.clone() - self.outputs_size = self.outputs.size() - print('Output Array Size: ', self.outputs_size) - - def rescale_output_array(self, newsize): - us = nn.Upsample(size=newsize[2:], mode='bilinear') - if isinstance(self.outputs, list): - for index in range(len(self.outputs)): self.outputs[index] = us(self.outputs[index]).data() - else: - self.outputs = us(self.outputs).data() - - def forward(self, x): - target_layer = self.submodule._modules.get(self.layername) - - # Collect the output tensor - h_inp = target_layer.register_forward_hook(self.get_input_array) - h_out = target_layer.register_forward_hook(self.get_output_array) - self.submodule(x) - h_inp.remove() - h_out.remove() - - # Rescale the feature-map if it's required - if self.upscale: self.rescale_output_array(x.size()) - - return self.inputs, self.outputs - - -class UnetDsv3(nn.Module): - def __init__(self, in_size, out_size, scale_factor): - super(UnetDsv3, self).__init__() - self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0), - nn.Upsample(scale_factor=scale_factor, mode='trilinear'), ) - - def forward(self, input): - return self.dsv(input) - -class Attention_UNet3D(nn.Module): - - def __init__(self, feature_scale=4, n_classes=1, is_deconv=True, in_channels=1, - nonlocal_mode='concatenation', attention_dsample=(2, 2, 2), is_batchnorm=True): - super(Attention_UNet3D, self).__init__() - self.is_deconv = is_deconv - self.in_channels = in_channels - self.is_batchnorm = is_batchnorm - self.feature_scale = feature_scale - - filters = [64, 128, 256, 512, 1024] - filters = [int(x / self.feature_scale) for x in filters] - - # downsampling - self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) - - # attention blocks - self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - - # upsampling - self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) - self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) - self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) - self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) - - # deep supervision - self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) - self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) - self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) - self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) - - # final conv (without any concat) - self.final = nn.Conv3d(n_classes*4, n_classes, 1) - - # initialise weights - for m in self.modules(): - if isinstance(m, nn.Conv3d): - init_weights(m, init_type='kaiming') - elif isinstance(m, nn.BatchNorm3d): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - # Feature Extraction - conv1 = self.conv1(inputs) - maxpool1 = self.maxpool1(conv1) - - conv2 = self.conv2(maxpool1) - maxpool2 = self.maxpool2(conv2) - - conv3 = self.conv3(maxpool2) - maxpool3 = self.maxpool3(conv3) - - conv4 = self.conv4(maxpool3) - maxpool4 = self.maxpool4(conv4) - - # Gating Signal Generation - center = self.center(maxpool4) - gating = self.gating(center) - - # Attention Mechanism - # Upscaling Part (Decoder) - g_conv4, att4 = self.attentionblock4(conv4, gating) - up4 = self.up_concat4(g_conv4, center) - g_conv3, att3 = self.attentionblock3(conv3, up4) - up3 = self.up_concat3(g_conv3, up4) - g_conv2, att2 = self.attentionblock2(conv2, up3) - up2 = self.up_concat2(g_conv2, up3) - up1 = self.up_concat1(conv1, up2) - - # Deep Supervision - dsv4 = self.dsv4(up4) - dsv3 = self.dsv3(up3) - dsv2 = self.dsv2(up2) - dsv1 = self.dsv1(up1) - final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) - - return final - - - @staticmethod - def apply_argmax_softmax(pred): - log_p = F.softmax(pred, dim=1) - - return log_p - - -class MultiAttentionBlock(nn.Module): - def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): - super(MultiAttentionBlock, self).__init__() - self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor= sub_sample_factor) - self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), - nn.BatchNorm3d(in_size), - nn.ReLU(inplace=True) - ) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, input, gating_signal): - gate_1, attention_1 = self.gate_block_1(input, gating_signal) - gate_2, attention_2 = self.gate_block_2(input, gating_signal) - - return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) - - -if __name__ == '__main__': - model = Attention_UNet3D() \ No newline at end of file diff --git a/libs/models/AttentionUnet3D_cross.py b/libs/models/AttentionUnet3D_cross.py deleted file mode 100644 index a6b8a7b..0000000 --- a/libs/models/AttentionUnet3D_cross.py +++ /dev/null @@ -1,782 +0,0 @@ -from torch.nn import init -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class _GridAttentionBlockND(nn.Module): - def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(_GridAttentionBlockND, self).__init__() - - assert dimension in [2, 3] - assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual'] - - # Downsampling rate for the input featuremap - if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor - elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor) - else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension - - # Default parameter set - self.mode = mode - self.dimension = dimension - self.sub_sample_kernel_size = self.sub_sample_factor - - # Number of channels (pixel dimensions) - self.in_channels = in_channels - self.gating_channels = gating_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - bn = nn.BatchNorm3d - self.upsample_mode = 'trilinear' - elif dimension == 2: - conv_nd = nn.Conv2d - bn = nn.BatchNorm2d - self.upsample_mode = 'bilinear' - else: - raise NotImplemented - - # Output transform - self.W = nn.Sequential( - conv_nd(in_channels=self.in_channels, out_channels=self.in_channels, kernel_size=1, stride=1, padding=0), - bn(self.in_channels), - ) - - # Theta^T * x_ij + Phi^T * gating_signal + bias - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False) - self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0, bias=True) - self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True) - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - # Define the operation - if mode == 'concatenation': - self.operation_function = self._concatenation - elif mode == 'concatenation_debug': - self.operation_function = self._concatenation_debug - elif mode == 'concatenation_residual': - self.operation_function = self._concatenation_residual - else: - raise NotImplementedError('Unknown operation function.') - - - def forward(self, x, g): - ''' - :param x: (b, c, t, h, w) - :param g: (b, g_d) - :return: - ''' - # g: torch.Size([1, 256, 8, 10, 8]) x: torch.Size([1, 129, 16, 20, 16]) - output = self.operation_function(x, g) - return output - - def _concatenation(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - def _concatenation_debug(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.softplus(theta_x + phi_g) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - sigm_psi_f = F.sigmoid(self.psi(f)) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - - def _concatenation_residual(self, x, g): - input_size = x.size() - batch_size = input_size[0] - assert batch_size == g.size(0) - - # theta => (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw) - # phi => (b, g_d) -> (b, i_c) - theta_x = self.theta(x) - theta_x_size = theta_x.size() - - # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w') - # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c, t/s1, h/s2, w/s3) - phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode) - f = F.relu(theta_x + phi_g, inplace=True) - - # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3) - f = self.psi(f).view(batch_size, 1, -1) - sigm_psi_f = F.softmax(f, dim=2).view(batch_size, 1, *theta_x.size()[2:]) - - # upsample the attentions and multiply - sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode) - y = sigm_psi_f.expand_as(x) * x - W_y = self.W(y) - - return W_y, sigm_psi_f - - -class GridAttentionBlock3D(_GridAttentionBlockND): - def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation', - sub_sample_factor=(2,2,2)): - super(GridAttentionBlock3D, self).__init__(in_channels, - inter_channels=inter_channels, - gating_channels=gating_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - ) -def weights_init_kaiming(m): - classname = m.__class__.__name__ - # print(classname) - if classname.find('Conv') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif classname.find('Linear') != -1: - init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') - elif classname.find('BatchNorm') != -1: - init.normal_(m.weight.data, 1.0, 0.02) - init.constant_(m.bias.data, 0.0) - - -def init_weights(net, init_type='normal'): - if init_type == 'kaiming': - net.apply(weights_init_kaiming) - -class conv2DBatchNorm(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(conv2DBatchNorm, self).__init__() - - self.cb_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)),) - - def forward(self, inputs): - outputs = self.cb_unit(inputs) - return outputs - - -class deconv2DBatchNorm(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(deconv2DBatchNorm, self).__init__() - - self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)),) - - def forward(self, inputs): - outputs = self.dcb_unit(inputs) - return outputs - - -class conv2DBatchNormRelu(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(conv2DBatchNormRelu, self).__init__() - - self.cbr_unit = nn.Sequential(nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)), - nn.ReLU(inplace=True),) - - def forward(self, inputs): - outputs = self.cbr_unit(inputs) - return outputs - - -class deconv2DBatchNormRelu(nn.Module): - def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): - super(deconv2DBatchNormRelu, self).__init__() - - self.dcbr_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, - padding=padding, stride=stride, bias=bias), - nn.BatchNorm2d(int(n_filters)), - nn.ReLU(inplace=True),) - - def forward(self, inputs): - outputs = self.dcbr_unit(inputs) - return outputs - - -class unetConv2(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1): - super(unetConv2, self).__init__() - self.n = n - self.ks = ks - self.stride = stride - self.padding = padding - s = stride - p = padding - if is_batchnorm: - for i in range(1, n+1): - conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), - nn.BatchNorm2d(out_size), - nn.ReLU(inplace=True),) - setattr(self, 'conv%d'%i, conv) - in_size = out_size - - else: - for i in range(1, n+1): - conv = nn.Sequential(nn.Conv2d(in_size, out_size, ks, s, p), - nn.ReLU(inplace=True),) - setattr(self, 'conv%d'%i, conv) - in_size = out_size - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - x = inputs - for i in range(1, self.n+1): - conv = getattr(self, 'conv%d'%i) - x = conv(x) - - return x - - -class UnetConv3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): - super(UnetConv3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - outputs = self.conv2(outputs) - return outputs - - -class FCNConv3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): - super(FCNConv3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True),) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - self.conv3 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - outputs = self.conv2(outputs) - outputs = self.conv3(outputs) - return outputs - - -class UnetGatingSignal3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm): - super(UnetGatingSignal3, self).__init__() - self.fmap_size = (4, 4, 4) - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), - nn.InstanceNorm3d(in_size//2), - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool3d(output_size=self.fmap_size), - ) - self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], - out_features=out_size, bias=True) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, in_size//2, (1,1,1), (1,1,1), (0,0,0)), - nn.ReLU(inplace=True), - nn.AdaptiveAvgPool3d(output_size=self.fmap_size), - ) - self.fc1 = nn.Linear(in_features=(in_size//2) * self.fmap_size[0] * self.fmap_size[1] * self.fmap_size[2], - out_features=out_size, bias=True) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - batch_size = inputs.size(0) - outputs = self.conv1(inputs) - outputs = outputs.view(batch_size, -1) - outputs = self.fc1(outputs) - return outputs - - -class UnetGridGatingSignal3(nn.Module): - def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True): - super(UnetGridGatingSignal3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), - nn.InstanceNorm3d(out_size), - nn.ReLU(inplace=True), - ) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, (1,1,1), (0,0,0)), - nn.ReLU(inplace=True), - ) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - return outputs - - -class unetUp(nn.Module): - def __init__(self, in_size, out_size, is_deconv): - super(unetUp, self).__init__() - self.conv = unetConv2(in_size, out_size, False) - if is_deconv: - self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1) - else: - self.up = nn.UpsamplingBilinear2d(scale_factor=2) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('unetConv2') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -class UnetUp3(nn.Module): - def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True): - super(UnetUp3, self).__init__() - if is_deconv: - self.conv = UnetConv3(in_size, out_size, is_batchnorm) - self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) - else: - self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) - self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -class UnetUp3_CT(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm=True): - super(UnetUp3_CT, self).__init__() - self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -# Squeeze-and-Excitation Network -class SqEx(nn.Module): - - def __init__(self, n_features, reduction=6): - super(SqEx, self).__init__() - - if n_features % reduction != 0: - raise ValueError('n_features must be divisible by reduction (default = 4)') - - self.linear1 = nn.Linear(n_features, n_features // reduction, bias=False) - self.nonlin1 = nn.ReLU(inplace=True) - self.linear2 = nn.Linear(n_features // reduction, n_features, bias=False) - self.nonlin2 = nn.Sigmoid() - - def forward(self, x): - - y = F.avg_pool3d(x, kernel_size=x.size()[2:5]) - y = y.permute(0, 2, 3, 4, 1) - y = self.nonlin1(self.linear1(y)) - y = self.nonlin2(self.linear2(y)) - y = y.permute(0, 4, 1, 2, 3) - y = x * y - return y - - -class UnetUp3_SqEx(nn.Module): - def __init__(self, in_size, out_size, is_deconv, is_batchnorm): - super(UnetUp3_SqEx, self).__init__() - if is_deconv: - self.sqex = SqEx(n_features=in_size+out_size) - self.conv = UnetConv3(in_size, out_size, is_batchnorm) - self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) - else: - self.sqex = SqEx(n_features=in_size+out_size) - self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) - self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - concat = torch.cat([outputs1, outputs2], 1) - gated = self.sqex(concat) - return self.conv(gated) - - -class residualBlock(nn.Module): - expansion = 1 - - def __init__(self, in_channels, n_filters, stride=1, downsample=None): - super(residualBlock, self).__init__() - - self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False) - self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False) - self.downsample = downsample - self.stride = stride - self.relu = nn.ReLU(inplace=True) - - def forward(self, x): - residual = x - - out = self.convbnrelu1(x) - out = self.convbn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - return out - - -class residualBottleneck(nn.Module): - expansion = 4 - - def __init__(self, in_channels, n_filters, stride=1, downsample=None): - super(residualBottleneck, self).__init__() - self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False) - self.convbn2 = nn.Conv2DBatchNorm(n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False) - self.convbn3 = nn.Conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, bias=False) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.convbn1(x) - out = self.convbn2(out) - out = self.convbn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - - - -class SeqModelFeatureExtractor(nn.Module): - def __init__(self, submodule, extracted_layers): - super(SeqModelFeatureExtractor, self).__init__() - - self.submodule = submodule - self.extracted_layers = extracted_layers - - def forward(self, x): - outputs = [] - for name, module in self.submodule._modules.items(): - x = module(x) - if name in self.extracted_layers: - outputs += [x] - return outputs + [x] - - -class HookBasedFeatureExtractor(nn.Module): - def __init__(self, submodule, layername, upscale=False): - super(HookBasedFeatureExtractor, self).__init__() - - self.submodule = submodule - self.submodule.eval() - self.layername = layername - self.outputs_size = None - self.outputs = None - self.inputs = None - self.inputs_size = None - self.upscale = upscale - - def get_input_array(self, m, i, o): - if isinstance(i, tuple): - self.inputs = [i[index].data.clone() for index in range(len(i))] - self.inputs_size = [input.size() for input in self.inputs] - else: - self.inputs = i.data.clone() - self.inputs_size = self.input.size() - print('Input Array Size: ', self.inputs_size) - - def get_output_array(self, m, i, o): - if isinstance(o, tuple): - self.outputs = [o[index].data.clone() for index in range(len(o))] - self.outputs_size = [output.size() for output in self.outputs] - else: - self.outputs = o.data.clone() - self.outputs_size = self.outputs.size() - print('Output Array Size: ', self.outputs_size) - - def rescale_output_array(self, newsize): - us = nn.Upsample(size=newsize[2:], mode='bilinear') - if isinstance(self.outputs, list): - for index in range(len(self.outputs)): self.outputs[index] = us(self.outputs[index]).data() - else: - self.outputs = us(self.outputs).data() - - def forward(self, x): - target_layer = self.submodule._modules.get(self.layername) - - # Collect the output tensor - h_inp = target_layer.register_forward_hook(self.get_input_array) - h_out = target_layer.register_forward_hook(self.get_output_array) - self.submodule(x) - h_inp.remove() - h_out.remove() - - # Rescale the feature-map if it's required - if self.upscale: self.rescale_output_array(x.size()) - - return self.inputs, self.outputs - - -class UnetDsv3(nn.Module): - def __init__(self, in_size, out_size, scale_factor): - super(UnetDsv3, self).__init__() - self.dsv = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0), - nn.Upsample(scale_factor=scale_factor, mode='trilinear'), ) - - def forward(self, input): - return self.dsv(input) - -class Attention_UNet3D(nn.Module): - - def __init__(self, emb_shape, feature_scale=4, n_classes=1, is_deconv=True, in_channels=1, - nonlocal_mode='concatenation', attention_dsample=(2, 2, 2), is_batchnorm=True): - super(Attention_UNet3D, self).__init__() - self.is_deconv = is_deconv - self.in_channels = in_channels - self.is_batchnorm = is_batchnorm - self.feature_scale = feature_scale - # edited - self.emb_shape = torch.as_tensor(emb_shape) - self.pos_emb_layer = nn.Linear(6, torch.prod(self.emb_shape).item()) - - filters = [64, 128, 256, 512, 1024] - filters = [int(x / self.feature_scale) for x in filters] - - # downsampling - self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) - - self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) - self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) - - # attention blocks - self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - self.attentionblock4 = MultiAttentionBlock(in_size=filters[3]+1, gate_size=filters[4], inter_size=filters[3], - nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) - - # upsampling - self.up_concat4 = UnetUp3_CT(filters[4]+1, filters[3], is_batchnorm) - self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) - self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) - self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) - - # deep supervision - self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) - self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) - self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) - self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) - - # final conv (without any concat) - self.final = nn.Conv3d(n_classes*4, n_classes, 1) - - # initialise weights - for m in self.modules(): - if isinstance(m, nn.Conv3d): - init_weights(m, init_type='kaiming') - elif isinstance(m, nn.BatchNorm3d): - init_weights(m, init_type='kaiming') - - def forward(self, inputs, emb_codes): - # Feature Extraction - conv1 = self.conv1(inputs) - maxpool1 = self.maxpool1(conv1) - - conv2 = self.conv2(maxpool1) - maxpool2 = self.maxpool2(conv2) - - conv3 = self.conv3(maxpool2) - maxpool3 = self.maxpool3(conv3) - - conv4 = self.conv4(maxpool3) - maxpool4 = self.maxpool4(conv4) - - # Gating Signal Generation - center = self.center(maxpool4) - gating = self.gating(center) - - # cross attention - emb_pos = self.pos_emb_layer(emb_codes).view(-1, 1, *self.emb_shape) - conv4 = torch.cat((conv4, emb_pos), dim=1) - # Attention Mechanism - # Upscaling Part (Decoder) - g_conv4, att4 = self.attentionblock4(conv4, gating) - up4 = self.up_concat4(g_conv4, center) - g_conv3, att3 = self.attentionblock3(conv3, up4) - up3 = self.up_concat3(g_conv3, up4) - g_conv2, att2 = self.attentionblock2(conv2, up3) - up2 = self.up_concat2(g_conv2, up3) - up1 = self.up_concat1(conv1, up2) - - # Deep Supervision - dsv4 = self.dsv4(up4) - dsv3 = self.dsv3(up3) - dsv2 = self.dsv2(up2) - dsv1 = self.dsv1(up1) - final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) - - return final - - - @staticmethod - def apply_argmax_softmax(pred): - log_p = F.softmax(pred, dim=1) - - return log_p - - -class MultiAttentionBlock(nn.Module): - def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): - super(MultiAttentionBlock, self).__init__() - self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor= sub_sample_factor) - self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, - inter_channels=inter_size, mode=nonlocal_mode, - sub_sample_factor=sub_sample_factor) - self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), - nn.BatchNorm3d(in_size), - nn.ReLU(inplace=True) - ) - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, input, gating_signal): - gate_1, attention_1 = self.gate_block_1(input, gating_signal) - gate_2, attention_2 = self.gate_block_2(input, gating_signal) - - return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) - - -if __name__ == '__main__': - model = Attention_UNet3D() \ No newline at end of file diff --git a/libs/models/BaseModelClass.py b/libs/models/BaseModelClass.py deleted file mode 100644 index 58a6428..0000000 --- a/libs/models/BaseModelClass.py +++ /dev/null @@ -1,135 +0,0 @@ -""" -Implementation of BaseModel taken and modified from here -https://github.com/kwotsin/mimicry/blob/master/torch_mimicry/nets/basemodel/basemodel.py -""" - -import os -from abc import ABC, abstractmethod -import torch -import torch.nn as nn - - -class BaseModel(nn.Module, ABC): - r""" - BaseModel with basic functionalities for checkpointing and restoration. - """ - - def __init__(self): - super().__init__() - self.best_loss = 1000000 - - @abstractmethod - def forward(self, x): - pass - - @abstractmethod - def test(self): - """ - To be implemented by the subclass so that - models can perform a forward propagation - :return: - """ - pass - - @property - def device(self): - return next(self.parameters()).device - - def restore_checkpoint(self, ckpt_file, optimizer=None): - r""" - Restores checkpoint from a pth file and restores optimizer state. - - Args: - ckpt_file (str): A PyTorch pth file containing model weights. - optimizer (Optimizer): A vanilla optimizer to have its state restored from. - - Returns: - int: Global step variable where the model was last checkpointed. - """ - if not ckpt_file: - raise ValueError("No checkpoint file to be restored.") - - try: - ckpt_dict = torch.load(ckpt_file) - except RuntimeError: - ckpt_dict = torch.load(ckpt_file, map_location=lambda storage, loc: storage) - # Restore model weights - self.load_state_dict(ckpt_dict['model_state_dict']) - - # Restore optimizer status if existing. Evaluation doesn't need this - # TODO return optimizer????? - if optimizer: - optimizer.load_state_dict(ckpt_dict['optimizer_state_dict']) - - # Return global step - return ckpt_dict['epoch'] - - def save_checkpoint(self, - directory, - epoch, loss, - optimizer=None, - name=None): - r""" - Saves checkpoint at a certain global step during training. Optimizer state - is also saved together. - - Args: - directory (str): Path to save checkpoint to. - epoch (int): The training. epoch - optimizer (Optimizer): Optimizer state to be saved concurrently. - name (str): The name to save the checkpoint file as. - - Returns: - None - """ - # Create directory to save to - if not os.path.exists(directory): - os.makedirs(directory) - - # Build checkpoint dict to save. - ckpt_dict = { - 'model_state_dict': - self.state_dict(), - 'optimizer_state_dict': - optimizer.state_dict() if optimizer is not None else None, - 'epoch': - epoch - } - - # Save the file with specific name - if name is None: - name = "{}_{}_epoch.pth".format( - os.path.basename(directory), # netD or netG - 'last') - - torch.save(ckpt_dict, os.path.join(directory, name)) - if self.best_loss > loss: - self.best_loss = loss - name = "{}_BEST.pth".format( - os.path.basename(directory)) - torch.save(ckpt_dict, os.path.join(directory, name)) - - def count_params(self): - r""" - Computes the number of parameters in this model. - - Args: None - - Returns: - int: Total number of weight parameters for this model. - int: Total number of trainable parameters for this model. - - """ - num_total_params = sum(p.numel() for p in self.parameters()) - num_trainable_params = sum(p.numel() for p in self.parameters() - if p.requires_grad) - - return num_total_params, num_trainable_params - - def inference(self, input_tensor): - self.eval() - with torch.no_grad(): - output = self.forward(input_tensor) - if isinstance(output, tuple): - output = output[0] - return output.cpu().detach() \ No newline at end of file diff --git a/libs/models/ER_Net.py b/libs/models/ER_Net.py deleted file mode 100644 index dbd86ac..0000000 --- a/libs/models/ER_Net.py +++ /dev/null @@ -1,214 +0,0 @@ -import torch -import torch.nn as nn -from functools import partial - -import torch.nn.functional as F - -nonlinearity = partial(F.relu, inplace=True) - - - -def downsample(): - return nn.MaxPool3d(kernel_size=2, stride=2) -def deconv(in_channels, out_channels): - return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2) - -def initialize_weights(*models): - for model in models: - for m in model.modules(): - if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear): - nn.init.kaiming_normal(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm3d): - m.weight.data.fill_(1) - m.bias.data.zero_() - - -class ResDecoder(nn.Module): - def __init__(self, in_channels): - super(ResDecoder, self).__init__() - self.conv1 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm3d(in_channels) - self.conv2 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm3d(in_channels) - self.relu = nn.ReLU(inplace=False) - self.conv1x1 = nn.Conv3d(in_channels, in_channels, kernel_size=1) - - def forward(self, x): - residual = self.conv1x1(x) - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out += residual - out = self.relu(out) - return out -class SFConv(nn.Module): - def __init__(self, features, M=2, r=4, L=32): - """ Constructor - Args: - features: input channel dimensionality. - WH: input spatial dimensionality, used for GAP kernel size. - M: the number of branchs. - G: num of convolution groups. - r: the radio for compute d, the length of z. - stride: stride, default 1. - L: the minimum dim of the vector z in paper, default 32. - """ - super(SFConv, self).__init__() - d = max(int(features / r), L) - self.M = M - self.features = features - # self.convs = nn.ModuleList([]) - # for i in range(M): - # self.convs.append(nn.Sequential( - # nn.Conv2d(features, features, kernel_size=3 + i * 2, stride=stride, padding=1 + i, groups=G), - # nn.BatchNorm2d(features), - # nn.ReLU(inplace=False) - # )) - # self.gap = nn.AvgPool2d(int(WH/stride)) - self.fc = nn.Linear(features, d) - self.fcs = nn.ModuleList([]) - for i in range(M): - self.fcs.append( - nn.Linear(d, features) - ) - self.softmax = nn.Softmax(dim=1) - def forward(self, x1, x2): - # for i, conv in enumerate(self.convs): - # fea = conv(x).unsqueeze_(dim=1) - # if i == 0: - # feas = fea - # else: - # feas = torch.cat([feas, fea], dim=1) - feas = torch.cat((x1.unsqueeze_(dim=1), x2.unsqueeze_(dim=1)), dim=1) - fea_U = torch.sum(feas, dim=1) - # fea_s = self.gap(fea_U).squeeze_() - fea_s = fea_U.mean(-1).mean(-1).mean((-1)) - fea_z = self.fc(fea_s) - for i, fc in enumerate(self.fcs): - vector = fc(fea_z).unsqueeze_(dim=1) - if i == 0: - attention_vectors = vector - else: - attention_vectors = torch.cat([attention_vectors, vector], dim=1) - attention_vectors = self.softmax(attention_vectors) - attention_vectors = attention_vectors.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - fea_v = (feas * attention_vectors).sum(dim=1) - return fea_v - - - -class SF_Decoder(nn.Module): - def __init__(self, out_channels): - super(SF_Decoder, self).__init__() - self.conv1 = SFConv(out_channels) - self.bn1 = nn.BatchNorm3d(out_channels) - # self.conv2 = nn.Conv3d(out_channels, out_channels // 2, kernel_size=3, padding=1) - # self.bn2 = nn.BatchNorm3d(out_channels // 2) - self.relu = nn.ReLU(inplace=False) - self.ResDecoder = ResDecoder(out_channels) - # self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) - - def forward(self, x1, x2): - # residual = self.conv1x1(x) - out = self.relu(self.bn1(self.conv1(x1, x2))) - out = self.ResDecoder(out) - - # out = self.relu(self.bn2(self.conv2(out))) - # out += residual - # out = self.relu(out) - return out - -class ResEncoder(nn.Module): - def __init__(self, in_channels, out_channels): - super(ResEncoder, self).__init__() - self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm3d(out_channels) - self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm3d(out_channels) - self.relu = nn.ReLU(inplace=False) - self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) - - def forward(self, x): - residual = self.conv1x1(x) - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out += residual - out = self.relu(out) - return out -class ER_Net(nn.Module): - def __init__(self, classes, channels): - # def __init__(self): - - super(ER_Net, self).__init__() - self.encoder1 = ResEncoder(channels, 32) - self.encoder2 = ResEncoder(32, 64) - self.encoder3 = ResEncoder(64, 128) - self.bridge = ResEncoder(128, 256) - - self.conv1_1 = nn.Conv3d(256, 1, kernel_size=1) - self.conv2_2 = nn.Conv3d(128, 1, kernel_size=1) - self.conv3_3 = nn.Conv3d(64, 1, kernel_size=1) - - self.convTrans1 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2) - self.convTrans2 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2) - self.convTrans3 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2) - - self.decoder3 = SF_Decoder(128) - self.decoder2 = SF_Decoder(64) - self.decoder1 = SF_Decoder(32) - self.down = downsample() - self.up3 = deconv(256, 128) - self.up2 = deconv(128, 64) - self.up1 = deconv(64, 32) - self.final = nn.Conv3d(32, classes, kernel_size=1, padding=0) - initialize_weights(self) - - def forward(self, x): - enc1 = self.encoder1(x) - down1 = self.down(enc1) - - enc2 = self.encoder2(down1) - down2 = self.down(enc2) - - con3_3 = self.conv3_3(enc2) - convTrans3 = self.convTrans3(con3_3) - x3 = -1 * (torch.sigmoid(convTrans3)) + 1 - x3 = x3.expand(-1, 32, -1, -1, -1).mul(enc1) - x3 = x3 + enc1 - - enc3 = self.encoder3(down2) - down3 = self.down(enc3) - - con2_2 = self.conv2_2(enc3) - convTrans2 = self.convTrans2(con2_2) - x2 = -1 * (torch.sigmoid(convTrans2)) + 1 - x2 = x2.expand(-1, 64, -1, -1, -1).mul(enc2) - x2 = x2 + enc2 - - bridge = self.bridge(down3) - - conv1_1 = self.conv1_1(bridge) - convTrans1 = self.convTrans1(conv1_1) - - x = -1 * (torch.sigmoid(convTrans1)) + 1 - x = x.expand(-1, 128, -1, -1, -1).mul(enc3) - x = x + enc3 - - up3 = self.up3(bridge) - # up3 = SKII_Decoder(up3,) - - # up3 = torch.cat((up3, x), dim=1) - dec3 = self.decoder3(up3, x) - - up2 = self.up2(dec3) - # up2 = torch.cat((up2, x2), dim=1) - dec2 = self.decoder2(up2, x2) - - up1 = self.up1(dec2) - # up1 = torch.cat((up1, x3), dim=1) - dec1 = self.decoder1(up1, x3) - - final = self.final(dec1) - # final = F.sigmoid(final) - return final \ No newline at end of file diff --git a/libs/models/HighResNet3D.py b/libs/models/HighResNet3D.py deleted file mode 100644 index ff706a3..0000000 --- a/libs/models/HighResNet3D.py +++ /dev/null @@ -1,237 +0,0 @@ -import torch -import torch.nn as nn -from .BaseModelClass import BaseModel - -""" -Implementation based on the paper: -https://arxiv.org/pdf/1707.01992.pdf -""" - - -class ConvInit(nn.Module): - def __init__(self, in_channels): - super(ConvInit, self).__init__() - self.num_features = 16 - self.in_channels = in_channels - - self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=3, padding=1) - bn1 = torch.nn.BatchNorm3d(self.num_features) - relu1 = nn.ReLU() - - self.norm = nn.Sequential(bn1, relu1) - - def forward(self, x): - y1 = self.conv1(x) - y2 = self.norm(y1) - - return y1, y2 - - -class ConvRed(nn.Module): - def __init__(self, in_channels): - super(ConvRed, self).__init__() - self.num_features = 16 - self.in_channels = in_channels - - bn1 = torch.nn.BatchNorm3d(self.num_features) - relu1 = nn.ReLU() - conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=3, padding=1) - self.conv_red = nn.Sequential(bn1, relu1, conv1) - - def forward(self, x): - return self.conv_red(x) - - -class DilatedConv2(nn.Module): - def __init__(self, in_channels): - super(DilatedConv2, self).__init__() - self.num_features = 32 - self.in_channels = in_channels - bn1 = torch.nn.BatchNorm3d(self.in_channels) - relu1 = nn.ReLU() - conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=3, padding=2, dilation=2) - - self.conv_dil = nn.Sequential(bn1, relu1, conv1) - - def forward(self, x): - return self.conv_dil(x) - - -class DilatedConv4(nn.Module): - def __init__(self, in_channels): - super(DilatedConv4, self).__init__() - self.num_features = 48 - self.in_channels = in_channels - - bn1 = torch.nn.BatchNorm3d(self.in_channels) - relu1 = nn.ReLU() - conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=3, padding=4, dilation=4) - - self.conv_dil = nn.Sequential(bn1, relu1, conv1) - - def forward(self, x): - return self.conv_dil(x) - - -class Conv1x1x1(nn.Module): - def __init__(self, in_channels, classes): - super(Conv1x1x1, self).__init__() - self.num_features = classes - self.in_channels = in_channels - - bn1 = torch.nn.BatchNorm3d(in_channels) - relu1 = nn.ReLU() - conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=1) - - self.conv_dil = nn.Sequential(bn1, relu1, conv1) - - def forward(self, x): - return self.conv_dil(x) - - -class HighResNet3D(BaseModel): - def __init__(self, in_channels=1, classes=4, shortcut_type="A", dropout_layer=True): - super(HighResNet3D, self).__init__() - self.in_channels = in_channels - self.shortcut_type = shortcut_type - self.classes = classes - self.init_channels = 16 - self.red_channels = 16 - self.dil2_channels = 32 - self.dil4_channels = 48 - self.conv_out_channels = 80 - - if self.shortcut_type == "B": - self.res_pad_1 = Conv1x1x1(self.red_channels, self.dil2_channels) - self.res_pad_2 = Conv1x1x1(self.dil2_channels, self.dil4_channels) - - self.conv_init = ConvInit(in_channels) - - self.red_blocks1 = self.create_red(self.init_channels) - self.red_blocks2 = self.create_red(self.red_channels) - self.red_blocks3 = self.create_red(self.red_channels) - - self.dil2block1 = self.create_dil2(self.red_channels) - self.dil2block2 = self.create_dil2(self.dil2_channels) - self.dil2block3 = self.create_dil2(self.dil2_channels) - - self.dil4block1 = self.create_dil4(self.dil2_channels) - self.dil4block2 = self.create_dil4(self.dil4_channels) - self.dil4block3 = self.create_dil4(self.dil4_channels) - - if dropout_layer: - conv_out = nn.Conv3d(self.dil4_channels, self.conv_out_channels, kernel_size=1) - drop3d = nn.Dropout3d() - conv1x1x1 = Conv1x1x1(self.conv_out_channels, self.classes) - self.conv_out = nn.Sequential(conv_out, drop3d, conv1x1x1) - else: - self.conv_out = Conv1x1x1(self.dil4_channels, self.classes) - - def shortcut_pad(self, x, desired_channels): - if self.shortcut_type == 'A': - batch_size, channels, dim0, dim1, dim2 = x.shape - extra_channels = desired_channels - channels - zero_channels = int(extra_channels / 2) - zeros_half = x.new_zeros(batch_size, zero_channels, dim0, dim1, dim2) - y = torch.cat((zeros_half, x, zeros_half), dim=1) - elif self.shortcut_type == 'B': - if desired_channels == self.dil2_channels: - y = self.res_pad_1(x) - elif desired_channels == self.dil4_channels: - y = self.res_pad_2(x) - return y - - def create_red(self, in_channels): - conv_red_1 = ConvRed(in_channels) - conv_red_2 = ConvRed(self.red_channels) - return nn.Sequential(conv_red_1, conv_red_2) - - def create_dil2(self, in_channels): - conv_dil2_1 = DilatedConv2(in_channels) - conv_dil2_2 = DilatedConv2(self.dil2_channels) - return nn.Sequential(conv_dil2_1, conv_dil2_2) - - def create_dil4(self, in_channels): - conv_dil4_1 = DilatedConv4(in_channels) - conv_dil4_2 = DilatedConv4(self.dil4_channels) - return nn.Sequential(conv_dil4_1, conv_dil4_2) - - def red_forward(self, x): - x, x_res = self.conv_init(x) - x_red_1 = self.red_blocks1(x) - x_red_2 = self.red_blocks2(x_red_1 + x_res) - x_red_3 = self.red_blocks3(x_red_2 + x_red_1) - return x_red_3, x_red_2 - - def dilation2(self, x_red_3, x_red_2): - x_dil2_1 = self.dil2block1(x_red_3 + x_red_2) - # print(x_dil2_1.shape ,x_red_3.shape ) - - x_red_padded = self.shortcut_pad(x_red_3, self.dil2_channels) - - x_dil2_2 = self.dil2block2(x_dil2_1 + x_red_padded) - x_dil2_3 = self.dil2block3(x_dil2_2 + x_dil2_1) - return x_dil2_3, x_dil2_2 - - def dilation4(self, x_dil2_3, x_dil2_2): - x_dil4_1 = self.dil4block1(x_dil2_3 + x_dil2_2) - x_dil2_padded = self.shortcut_pad(x_dil2_3, self.dil4_channels) - x_dil4_2 = self.dil4block2(x_dil4_1 + x_dil2_padded) - x_dil4_3 = self.dil4block3(x_dil4_2 + x_dil4_1) - return x_dil4_3 + x_dil4_2 - - def forward(self, x): - x_red_3, x_red_2 = self.red_forward(x) - x_dil2_3, x_dil2_2 = self.dilation2(x_red_3, x_red_2) - x_dil4 = self.dilation4(x_dil2_3, x_dil2_2) - y = self.conv_out(x_dil4) - - # y = interp_conv1.. conv2...(y) - # y1 = F.nn.Upsample..(y) - - return y - # return torch.sigmoid(y) - - def test(self): - k = 32 - x = torch.rand(1, self.in_channels, 256, k, k) - pred = self.forward(x) - target = torch.rand(1, self.classes, 256, k, k) - assert target.shape == pred.shape - print('High3DResnet ok!') - - -def test_all_modules(): - k = 128 - a = torch.rand(1, 16, 160, k, k) - m1 = ConvInit(in_channels=16) - y, _ = m1(a) - assert y.shape == a.shape, print(y.shape) - print("ConvInit OK") - - m2 = ConvRed(in_channels=16) - y = m2(a) - assert y.shape == a.shape, print(y.shape) - print("ConvRed OK") - - a = torch.rand(1, 32, 256, k, k) - m3 = DilatedConv2(in_channels=32) - y = m3(a) - assert y.shape == a.shape, print(y.shape) - print("DilatedConv2 OK") - - a = torch.rand(1, 48, 256, k, k) - m4 = DilatedConv4(in_channels=48) - y = m4(a) - assert y.shape == a.shape, print(y.shape) - print("DilatedConv4 OK") - - m4 = Conv1x1x1(in_channels=48, classes=33) - y = m4(a) - print(y.shape) - - -if __name__ == '__main__': - test_all_modules() - model = HighResNet3D(in_channels=1, classes=33) - model.test() \ No newline at end of file diff --git a/libs/models/ModelFactory.py b/libs/models/ModelFactory.py deleted file mode 100644 index 6bd8bfd..0000000 --- a/libs/models/ModelFactory.py +++ /dev/null @@ -1,34 +0,0 @@ -from torch import nn - - -class ModelFactory(nn.Module): - def __init__(self, model_name, num_classes, in_ch): - super(ModelFactory, self).__init__() - self.model_name = model_name - self.num_classes = num_classes - self.in_ch = in_ch - - def get(self): - if self.model_name == 'AttentionUnet3D': - from libs.models.AttentionUnet3D import Attention_UNet3D - return Attention_UNet3D(n_classes=self.num_classes, in_channels=self.in_ch) - elif self.model_name == 'VNetLight': - from libs.models.VNet import VNetLight - return VNetLight(n_classes=self.num_classes, in_channels=self.in_ch) - elif self.model_name == 'HighResNet3D': - from libs.models.HighResNet3D import HighResNet3D - return HighResNet3D(classes=self.num_classes, in_channels=self.in_ch) - elif self.model_name == 'Unet3D': - from libs.models.Unet3D import UNet3D - return UNet3D(n_classes=self.num_classes, in_channels=self.in_ch) - elif self.model_name == 'MedNeXt': - from libs.models.mednextv1.create_mednext_v1 import create_mednext_v1 - return create_mednext_v1(num_input_channels = self.in_ch, num_classes = self.num_classes, model_id = 'B', kernel_size = 3, deep_supervision = True) - elif self.model_name == 'ER_Net': - from libs.models.ER_Net import ER_Net - return ER_Net(classes= self.num_classes, channels=self.in_ch) - elif self.model_name == 'nonlocalUnet3D': - from libs.models.nonlocalUnet3D import unet_nonlocal_3D - return unet_nonlocal_3D(n_classes= self.num_classes, in_channels=self.in_ch) - else: - raise ValueError(f'Model {self.model_name} not found') diff --git a/libs/models/Unet3D.py b/libs/models/Unet3D.py deleted file mode 100644 index 448ca94..0000000 --- a/libs/models/Unet3D.py +++ /dev/null @@ -1,214 +0,0 @@ -import torch.nn as nn -import torch -from .BaseModelClass import BaseModel - - -class UNet3D(BaseModel): - """ - Implementations based on the Unet3D paper: https://arxiv.org/abs/1606.06650 - """ - - def __init__(self, in_channels, n_classes, base_n_filter=8): - super(UNet3D, self).__init__() - self.in_channels = in_channels - self.n_classes = n_classes - self.base_n_filter = base_n_filter - - self.lrelu = nn.LeakyReLU() - self.dropout3d = nn.Dropout3d(p=0.6) - self.upsacle = nn.Upsample(scale_factor=2, mode='nearest') - self.softmax = nn.Softmax(dim=1) - - self.conv3d_c1_1 = nn.Conv3d(self.in_channels, self.base_n_filter, kernel_size=3, stride=1, padding=1, - bias=False) - self.conv3d_c1_2 = nn.Conv3d(self.base_n_filter, self.base_n_filter, kernel_size=3, stride=1, padding=1, - bias=False) - self.lrelu_conv_c1 = self.lrelu_conv(self.base_n_filter, self.base_n_filter) - self.inorm3d_c1 = nn.InstanceNorm3d(self.base_n_filter) - - self.conv3d_c2 = nn.Conv3d(self.base_n_filter, self.base_n_filter * 2, kernel_size=3, stride=2, padding=1, - bias=False) - self.norm_lrelu_conv_c2 = self.norm_lrelu_conv(self.base_n_filter * 2, self.base_n_filter * 2) - self.inorm3d_c2 = nn.InstanceNorm3d(self.base_n_filter * 2) - - self.conv3d_c3 = nn.Conv3d(self.base_n_filter * 2, self.base_n_filter * 4, kernel_size=3, stride=2, padding=1, - bias=False) - self.norm_lrelu_conv_c3 = self.norm_lrelu_conv(self.base_n_filter * 4, self.base_n_filter * 4) - self.inorm3d_c3 = nn.InstanceNorm3d(self.base_n_filter * 4) - - self.conv3d_c4 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 8, kernel_size=3, stride=2, padding=1, - bias=False) - self.norm_lrelu_conv_c4 = self.norm_lrelu_conv(self.base_n_filter * 8, self.base_n_filter * 8) - self.inorm3d_c4 = nn.InstanceNorm3d(self.base_n_filter * 8) - - self.conv3d_c5 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 16, kernel_size=3, stride=2, padding=1, - bias=False) - self.norm_lrelu_conv_c5 = self.norm_lrelu_conv(self.base_n_filter * 16, self.base_n_filter * 16) - self.norm_lrelu_upscale_conv_norm_lrelu_l0 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 16, - self.base_n_filter * 8) - - self.conv3d_l0 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, - bias=False) - self.inorm3d_l0 = nn.InstanceNorm3d(self.base_n_filter * 8) - - self.conv_norm_lrelu_l1 = self.conv_norm_lrelu(self.base_n_filter * 16, self.base_n_filter * 16) - self.conv3d_l1 = nn.Conv3d(self.base_n_filter * 16, self.base_n_filter * 8, kernel_size=1, stride=1, padding=0, - bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l1 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 8, - self.base_n_filter * 4) - - self.conv_norm_lrelu_l2 = self.conv_norm_lrelu(self.base_n_filter * 8, self.base_n_filter * 8) - self.conv3d_l2 = nn.Conv3d(self.base_n_filter * 8, self.base_n_filter * 4, kernel_size=1, stride=1, padding=0, - bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l2 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 4, - self.base_n_filter * 2) - - self.conv_norm_lrelu_l3 = self.conv_norm_lrelu(self.base_n_filter * 4, self.base_n_filter * 4) - self.conv3d_l3 = nn.Conv3d(self.base_n_filter * 4, self.base_n_filter * 2, kernel_size=1, stride=1, padding=0, - bias=False) - self.norm_lrelu_upscale_conv_norm_lrelu_l3 = self.norm_lrelu_upscale_conv_norm_lrelu(self.base_n_filter * 2, - self.base_n_filter) - - self.conv_norm_lrelu_l4 = self.conv_norm_lrelu(self.base_n_filter * 2, self.base_n_filter * 2) - self.conv3d_l4 = nn.Conv3d(self.base_n_filter * 2, self.n_classes, kernel_size=1, stride=1, padding=0, - bias=False) - - self.ds2_1x1_conv3d = nn.Conv3d(self.base_n_filter * 8, self.n_classes, kernel_size=1, stride=1, padding=0, - bias=False) - self.ds3_1x1_conv3d = nn.Conv3d(self.base_n_filter * 4, self.n_classes, kernel_size=1, stride=1, padding=0, - bias=False) - self.sigmoid = nn.Sigmoid() - - def conv_norm_lrelu(self, feat_in, feat_out): - return nn.Sequential( - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU()) - - def norm_lrelu_conv(self, feat_in, feat_out): - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False)) - - def lrelu_conv(self, feat_in, feat_out): - return nn.Sequential( - nn.LeakyReLU(), - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False)) - - def norm_lrelu_upscale_conv_norm_lrelu(self, feat_in, feat_out): - return nn.Sequential( - nn.InstanceNorm3d(feat_in), - nn.LeakyReLU(), - nn.Upsample(scale_factor=2, mode='nearest'), - # should be feat_in*2 or feat_in - nn.Conv3d(feat_in, feat_out, kernel_size=3, stride=1, padding=1, bias=False), - nn.InstanceNorm3d(feat_out), - nn.LeakyReLU()) - - def forward(self, x): - # Level 1 context pathway - out = self.conv3d_c1_1(x) - residual_1 = out - out = self.lrelu(out) - out = self.conv3d_c1_2(out) - out = self.dropout3d(out) - out = self.lrelu_conv_c1(out) - # Element Wise Summation - out += residual_1 - context_1 = self.lrelu(out) - out = self.inorm3d_c1(out) - out = self.lrelu(out) - - # Level 2 context pathway - out = self.conv3d_c2(out) - residual_2 = out - out = self.norm_lrelu_conv_c2(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c2(out) - out += residual_2 - out = self.inorm3d_c2(out) - out = self.lrelu(out) - context_2 = out - - # Level 3 context pathway - out = self.conv3d_c3(out) - residual_3 = out - out = self.norm_lrelu_conv_c3(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c3(out) - out += residual_3 - out = self.inorm3d_c3(out) - out = self.lrelu(out) - context_3 = out - - # Level 4 context pathway - out = self.conv3d_c4(out) - residual_4 = out - out = self.norm_lrelu_conv_c4(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c4(out) - out += residual_4 - out = self.inorm3d_c4(out) - out = self.lrelu(out) - context_4 = out - - # Level 5 - out = self.conv3d_c5(out) - residual_5 = out - out = self.norm_lrelu_conv_c5(out) - out = self.dropout3d(out) - out = self.norm_lrelu_conv_c5(out) - out += residual_5 - out = self.norm_lrelu_upscale_conv_norm_lrelu_l0(out) - - out = self.conv3d_l0(out) - out = self.inorm3d_l0(out) - out = self.lrelu(out) - - # Level 1 localization pathway - out = torch.cat([out, context_4], dim=1) - out = self.conv_norm_lrelu_l1(out) - out = self.conv3d_l1(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l1(out) - - # Level 2 localization pathway - # print(out.shape) - # print(context_3.shape) - out = torch.cat([out, context_3], dim=1) - out = self.conv_norm_lrelu_l2(out) - ds2 = out - out = self.conv3d_l2(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l2(out) - - # Level 3 localization pathway - out = torch.cat([out, context_2], dim=1) - out = self.conv_norm_lrelu_l3(out) - ds3 = out - out = self.conv3d_l3(out) - out = self.norm_lrelu_upscale_conv_norm_lrelu_l3(out) - - # Level 4 localization pathway - out = torch.cat([out, context_1], dim=1) - out = self.conv_norm_lrelu_l4(out) - out_pred = self.conv3d_l4(out) - - ds2_1x1_conv = self.ds2_1x1_conv3d(ds2) - ds1_ds2_sum_upscale = self.upsacle(ds2_1x1_conv) - ds3_1x1_conv = self.ds3_1x1_conv3d(ds3) - ds1_ds2_sum_upscale_ds3_sum = ds1_ds2_sum_upscale + ds3_1x1_conv - ds1_ds2_sum_upscale_ds3_sum_upscale = self.upsacle(ds1_ds2_sum_upscale_ds3_sum) - - out = out_pred + ds1_ds2_sum_upscale_ds3_sum_upscale - seg_layer = out - return seg_layer - - def test(self, device='cpu'): - input_tensor = torch.rand(1, 2, 32, 32, 32) - ideal_out = torch.rand(1, self.n_classes, 32, 32, 32) - out = self.forward(input_tensor) - assert ideal_out.shape == out.shape - #summary(self.to(torch.device(device)), (2, 32, 32, 32), device='cpu') - # import torchsummaryX - # torchsummaryX.summary(self, input_tensor.to(device)) - print("Unet3D test is complete") diff --git a/libs/models/VNet.py b/libs/models/VNet.py deleted file mode 100644 index d3ad6b2..0000000 --- a/libs/models/VNet.py +++ /dev/null @@ -1,251 +0,0 @@ -import torch.nn as nn -import torch -from .BaseModelClass import BaseModel - -""" -Implementation of this model is borrowed and modified -(to support multi-channels and latest pytorch version) -from here: -https://github.com/Dawn90/V-Net.pytorch -""" - - -def passthrough(x, **kwargs): - return x - - -def ELUCons(elu, nchan): - if elu: - return nn.ELU(inplace=True) - else: - return nn.PReLU(nchan) - - -class LUConv(nn.Module): - def __init__(self, nchan, elu): - super(LUConv, self).__init__() - self.relu1 = ELUCons(elu, nchan) - self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) - - self.bn1 = torch.nn.BatchNorm3d(nchan) - - def forward(self, x): - out = self.relu1(self.bn1(self.conv1(x))) - return out - - -def _make_nConv(nchan, depth, elu): - layers = [] - for _ in range(depth): - layers.append(LUConv(nchan, elu)) - return nn.Sequential(*layers) - - -class InputTransition(nn.Module): - def __init__(self, in_channels, elu): - super(InputTransition, self).__init__() - self.num_features = 16 - self.in_channels = in_channels - - self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2) - - self.bn1 = torch.nn.BatchNorm3d(self.num_features) - - self.relu1 = ELUCons(elu, self.num_features) - - def forward(self, x): - out = self.conv1(x) - repeat_rate = int(self.num_features / self.in_channels) - out = self.bn1(out) - x16 = x.repeat(1, repeat_rate, 1, 1, 1) - return self.relu1(torch.add(out, x16)) - - -class DownTransition(nn.Module): - def __init__(self, inChans, nConvs, elu, dropout=False): - super(DownTransition, self).__init__() - outChans = 2 * inChans - self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2) - self.bn1 = torch.nn.BatchNorm3d(outChans) - - self.do1 = passthrough - self.relu1 = ELUCons(elu, outChans) - self.relu2 = ELUCons(elu, outChans) - if dropout: - self.do1 = nn.Dropout3d() - self.ops = _make_nConv(outChans, nConvs, elu) - - def forward(self, x): - down = self.relu1(self.bn1(self.down_conv(x))) - out = self.do1(down) - out = self.ops(out) - out = self.relu2(torch.add(out, down)) - return out - - -class UpTransition(nn.Module): - def __init__(self, inChans, outChans, nConvs, elu, dropout=False): - super(UpTransition, self).__init__() - self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2) - - self.bn1 = torch.nn.BatchNorm3d(outChans // 2) - self.do1 = passthrough - self.do2 = nn.Dropout3d() - self.relu1 = ELUCons(elu, outChans // 2) - self.relu2 = ELUCons(elu, outChans) - if dropout: - self.do1 = nn.Dropout3d() - self.ops = _make_nConv(outChans, nConvs, elu) - - def forward(self, x, skipx): - out = self.do1(x) - skipxdo = self.do2(skipx) - out = self.relu1(self.bn1(self.up_conv(out))) - xcat = torch.cat((out, skipxdo), 1) - out = self.ops(xcat) - out = self.relu2(torch.add(out, xcat)) - return out - - -class OutputTransition(nn.Module): - def __init__(self, in_channels, classes, elu): - super(OutputTransition, self).__init__() - self.classes = classes - self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2) - self.bn1 = torch.nn.BatchNorm3d(classes) - - self.conv2 = nn.Conv3d(classes, classes, kernel_size=1) - self.relu1 = ELUCons(elu, classes) - - def forward(self, x): - # convolve 32 down to channels as the desired classes - out = self.relu1(self.bn1(self.conv1(x))) - out = self.conv2(out) - return out - - -class VNet(BaseModel): - """ - Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797 - """ - - def __init__(self, elu=True, in_channels=1, classes=4): - super(VNet, self).__init__() - self.classes = classes - self.in_channels = in_channels - - self.in_tr = InputTransition(in_channels, elu=elu) - self.down_tr32 = DownTransition(16, 1, elu) - self.down_tr64 = DownTransition(32, 2, elu) - self.down_tr128 = DownTransition(64, 3, elu, dropout=True) - self.down_tr256 = DownTransition(128, 2, elu, dropout=True) - self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True) - self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True) - self.up_tr64 = UpTransition(128, 64, 1, elu) - self.up_tr32 = UpTransition(64, 32, 1, elu) - self.out_tr = OutputTransition(32, classes, elu) - - def forward(self, x): - out16 = self.in_tr(x) - out32 = self.down_tr32(out16) - out64 = self.down_tr64(out32) - out128 = self.down_tr128(out64) - out256 = self.down_tr256(out128) - out = self.up_tr256(out256, out128) - out = self.up_tr128(out, out64) - out = self.up_tr64(out, out32) - out = self.up_tr32(out, out16) - out = self.out_tr(out) - return out - - def test(self,device='cpu'): - input_tensor = torch.rand(1, self.in_channels, 32, 32, 32) - ideal_out = torch.rand(1, self.classes, 32, 32, 32) - out = self.forward(input_tensor) - assert ideal_out.shape == out.shape - summary(self.to(torch.device(device)), (self.in_channels, 32, 32, 32),device=device) - # import torchsummaryX - # torchsummaryX.summary(self, input_tensor.to(device)) - print("Vnet test is complete") - - -class VNetLight(BaseModel): - """ - A lighter version of Vnet that skips down_tr256 and up_tr256 in oreder to reduce time and space complexity - """ - - def __init__(self, elu=True, in_channels=1, n_classes=4): - super(VNetLight, self).__init__() - self.classes = n_classes - self.in_channels = in_channels - - self.in_tr = InputTransition(in_channels, elu) - self.down_tr32 = DownTransition(16, 1, elu) - self.down_tr64 = DownTransition(32, 2, elu) - self.down_tr128 = DownTransition(64, 3, elu, dropout=True) - self.up_tr128 = UpTransition(128, 128, 2, elu, dropout=True) - self.up_tr64 = UpTransition(128, 64, 1, elu) - self.up_tr32 = UpTransition(64, 32, 1, elu) - self.out_tr = OutputTransition(32, n_classes, elu) - - def forward(self, x): - out16 = self.in_tr(x) - out32 = self.down_tr32(out16) - out64 = self.down_tr64(out32) - out128 = self.down_tr128(out64) - out = self.up_tr128(out128, out64) - out = self.up_tr64(out, out32) - out = self.up_tr32(out, out16) - out = self.out_tr(out) - return out - - def test(self,device='cpu'): - input_tensor = torch.rand(1, self.in_channels, 32, 32, 32) - ideal_out = torch.rand(1, self.classes, 32, 32, 32) - out = self.forward(input_tensor) - assert ideal_out.shape == out.shape - summary(self.to(torch.device(device)), (self.in_channels, 32, 32, 32),device=device) - # import torchsummaryX - # torchsummaryX.summary(self, input_tensor.to(device)) - - print("Vnet light test is complete") - -class VNetLight2(BaseModel): - """ - A lighter version of Vnet that skips down_tr256 and up_tr256 in oreder to reduce time and space complexity - """ - - def __init__(self, elu=True, in_channels=1, classes=4): - super(VNetLight2, self).__init__() - self.classes = classes - self.in_channels = in_channels - - self.in_tr = InputTransition(in_channels, elu) - self.down_tr32 = DownTransition(8, 1, elu) - self.down_tr64 = DownTransition(16, 2, elu) - self.down_tr128 = DownTransition(32, 3, elu, dropout=True) - self.up_tr128 = UpTransition(64, 64, 2, elu, dropout=True) - self.up_tr64 = UpTransition(64, 32, 1, elu) - self.up_tr32 = UpTransition(32, 16, 1, elu) - self.out_tr = OutputTransition(16, classes, elu) - - def forward(self, x): - out16 = self.in_tr(x) - out32 = self.down_tr32(out16) - out64 = self.down_tr64(out32) - out128 = self.down_tr128(out64) - out = self.up_tr128(out128, out64) - out = self.up_tr64(out, out32) - out = self.up_tr32(out, out16) - out = self.out_tr(out) - return out - - def test(self, device='cpu'): - input_tensor = torch.rand(1, self.in_channels, 180, 96, 96) - ideal_out = torch.rand(1, self.classes, 180, 96, 96) - out = self.forward(input_tensor) - assert ideal_out.shape == out.shape - print("Vnet light test is complete") - -# m = VNet(in_channels=1,num_classes=2) -# m.test() \ No newline at end of file diff --git a/libs/models/mednextv1/MedNextV1.py b/libs/models/mednextv1/MedNextV1.py deleted file mode 100644 index 396db87..0000000 --- a/libs/models/mednextv1/MedNextV1.py +++ /dev/null @@ -1,388 +0,0 @@ -import torch -import torch.nn as nn -import torch.utils.checkpoint as checkpoint -from libs.models.mednextv1.blocks import * - -class MedNeXt(nn.Module): - - def __init__(self, - in_channels: int, - n_channels: int, - n_classes: int, - exp_r: int = 4, # Expansion ratio as in Swin Transformers - kernel_size: int = 7, # Ofcourse can test kernel_size - enc_kernel_size: int = None, - dec_kernel_size: int = None, - deep_supervision: bool = False, # Can be used to test deep supervision - do_res: bool = False, # Can be used to individually test residual connection - do_res_up_down: bool = False, # Additional 'res' connection on up and down convs - checkpoint_style: bool = None, # Either inside block or outside block - block_counts: list = [2,2,2,2,2,2,2,2,2], # Can be used to test staging ratio: - # [3,3,9,3] in Swin as opposed to [2,2,2,2,2] in nnUNet - norm_type = 'group', - ): - - super().__init__() - - self.do_ds = deep_supervision - assert checkpoint_style in [None, 'outside_block'] - self.inside_block_checkpointing = False - self.outside_block_checkpointing = False - if checkpoint_style == 'outside_block': - self.outside_block_checkpointing = True - - if kernel_size is not None: - enc_kernel_size = kernel_size - dec_kernel_size = kernel_size - - self.stem = nn.Conv3d(in_channels, n_channels, kernel_size=1) - if type(exp_r) == int: - exp_r = [exp_r for i in range(len(block_counts))] - - self.enc_block_0 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels, - out_channels=n_channels, - exp_r=exp_r[0], - kernel_size=enc_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[0])] - ) - - self.down_0 = MedNeXtDownBlock( - in_channels=n_channels, - out_channels=2*n_channels, - exp_r=exp_r[1], - kernel_size=enc_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.enc_block_1 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*2, - out_channels=n_channels*2, - exp_r=exp_r[1], - kernel_size=enc_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[1])] - ) - - self.down_1 = MedNeXtDownBlock( - in_channels=2*n_channels, - out_channels=4*n_channels, - exp_r=exp_r[2], - kernel_size=enc_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.enc_block_2 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*4, - out_channels=n_channels*4, - exp_r=exp_r[2], - kernel_size=enc_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[2])] - ) - - self.down_2 = MedNeXtDownBlock( - in_channels=4*n_channels, - out_channels=8*n_channels, - exp_r=exp_r[3], - kernel_size=enc_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.enc_block_3 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*8, - out_channels=n_channels*8, - exp_r=exp_r[3], - kernel_size=enc_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[3])] - ) - - self.down_3 = MedNeXtDownBlock( - in_channels=8*n_channels, - out_channels=16*n_channels, - exp_r=exp_r[4], - kernel_size=enc_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.bottleneck = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*16, - out_channels=n_channels*16, - exp_r=exp_r[4], - kernel_size=dec_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[4])] - ) - - self.up_3 = MedNeXtUpBlock( - in_channels=16*n_channels, - out_channels=8*n_channels, - exp_r=exp_r[5], - kernel_size=dec_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.dec_block_3 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*8, - out_channels=n_channels*8, - exp_r=exp_r[5], - kernel_size=dec_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[5])] - ) - - self.up_2 = MedNeXtUpBlock( - in_channels=8*n_channels, - out_channels=4*n_channels, - exp_r=exp_r[6], - kernel_size=dec_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.dec_block_2 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*4, - out_channels=n_channels*4, - exp_r=exp_r[6], - kernel_size=dec_kernel_size, - do_res=do_res, - norm_type=norm_type, - ) - for i in range(block_counts[6])] - ) - - self.up_1 = MedNeXtUpBlock( - in_channels=4*n_channels, - out_channels=2*n_channels, - exp_r=exp_r[7], - kernel_size=dec_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type, - ) - - self.dec_block_1 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels*2, - out_channels=n_channels*2, - exp_r=exp_r[7], - kernel_size=dec_kernel_size, - do_res=do_res, - norm_type=norm_type - ) - for i in range(block_counts[7])] - ) - - self.up_0 = MedNeXtUpBlock( - in_channels=2*n_channels, - out_channels=n_channels, - exp_r=exp_r[8], - kernel_size=dec_kernel_size, - do_res=do_res_up_down, - norm_type=norm_type - ) - - self.dec_block_0 = nn.Sequential(*[ - MedNeXtBlock( - in_channels=n_channels, - out_channels=n_channels, - exp_r=exp_r[8], - kernel_size=dec_kernel_size, - do_res=do_res, - norm_type=norm_type - ) - for i in range(block_counts[8])] - ) - - self.out_0 = OutBlock(in_channels=n_channels, n_classes=n_classes) - - # Used to fix PyTorch checkpointing bug - self.dummy_tensor = nn.Parameter(torch.tensor([1.]), requires_grad=True) - - if deep_supervision: - self.out_1 = OutBlock(in_channels=n_channels*2, n_classes=n_classes) - self.out_2 = OutBlock(in_channels=n_channels*4, n_classes=n_classes) - self.out_3 = OutBlock(in_channels=n_channels*8, n_classes=n_classes) - self.out_4 = OutBlock(in_channels=n_channels*16, n_classes=n_classes) - - self.block_counts = block_counts - - - def iterative_checkpoint(self, sequential_block, x): - """ - This simply forwards x through each block of the sequential_block while - using gradient_checkpointing. This implementation is designed to bypass - the following issue in PyTorch's gradient checkpointing: - https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/9 - """ - for l in sequential_block: - x = checkpoint.checkpoint(l, x, self.dummy_tensor) - return x - - - def forward(self, x): - - x = self.stem(x) - if self.outside_block_checkpointing: - x_res_0 = self.iterative_checkpoint(self.enc_block_0, x) - x = checkpoint.checkpoint(self.down_0, x_res_0, self.dummy_tensor) - x_res_1 = self.iterative_checkpoint(self.enc_block_1, x) - x = checkpoint.checkpoint(self.down_1, x_res_1, self.dummy_tensor) - x_res_2 = self.iterative_checkpoint(self.enc_block_2, x) - x = checkpoint.checkpoint(self.down_2, x_res_2, self.dummy_tensor) - x_res_3 = self.iterative_checkpoint(self.enc_block_3, x) - x = checkpoint.checkpoint(self.down_3, x_res_3, self.dummy_tensor) - - x = self.iterative_checkpoint(self.bottleneck, x) - if self.do_ds: - x_ds_4 = checkpoint.checkpoint(self.out_4, x, self.dummy_tensor) - - x_up_3 = checkpoint.checkpoint(self.up_3, x, self.dummy_tensor) - dec_x = x_res_3 + x_up_3 - x = self.iterative_checkpoint(self.dec_block_3, dec_x) - if self.do_ds: - x_ds_3 = checkpoint.checkpoint(self.out_3, x, self.dummy_tensor) - del x_res_3, x_up_3 - - x_up_2 = checkpoint.checkpoint(self.up_2, x, self.dummy_tensor) - dec_x = x_res_2 + x_up_2 - x = self.iterative_checkpoint(self.dec_block_2, dec_x) - if self.do_ds: - x_ds_2 = checkpoint.checkpoint(self.out_2, x, self.dummy_tensor) - del x_res_2, x_up_2 - - x_up_1 = checkpoint.checkpoint(self.up_1, x, self.dummy_tensor) - dec_x = x_res_1 + x_up_1 - x = self.iterative_checkpoint(self.dec_block_1, dec_x) - if self.do_ds: - x_ds_1 = checkpoint.checkpoint(self.out_1, x, self.dummy_tensor) - del x_res_1, x_up_1 - - x_up_0 = checkpoint.checkpoint(self.up_0, x, self.dummy_tensor) - dec_x = x_res_0 + x_up_0 - x = self.iterative_checkpoint(self.dec_block_0, dec_x) - del x_res_0, x_up_0, dec_x - - x = checkpoint.checkpoint(self.out_0, x, self.dummy_tensor) - - else: - x_res_0 = self.enc_block_0(x) - x = self.down_0(x_res_0) - x_res_1 = self.enc_block_1(x) - x = self.down_1(x_res_1) - x_res_2 = self.enc_block_2(x) - x = self.down_2(x_res_2) - x_res_3 = self.enc_block_3(x) - x = self.down_3(x_res_3) - - x = self.bottleneck(x) - if self.do_ds: - x_ds_4 = self.out_4(x) - - x_up_3 = self.up_3(x) - dec_x = x_res_3 + x_up_3 - x = self.dec_block_3(dec_x) - - if self.do_ds: - x_ds_3 = self.out_3(x) - del x_res_3, x_up_3 - - x_up_2 = self.up_2(x) - dec_x = x_res_2 + x_up_2 - x = self.dec_block_2(dec_x) - if self.do_ds: - x_ds_2 = self.out_2(x) - del x_res_2, x_up_2 - - x_up_1 = self.up_1(x) - dec_x = x_res_1 + x_up_1 - x = self.dec_block_1(dec_x) - if self.do_ds: - x_ds_1 = self.out_1(x) - del x_res_1, x_up_1 - - x_up_0 = self.up_0(x) - dec_x = x_res_0 + x_up_0 - x = self.dec_block_0(dec_x) - del x_res_0, x_up_0, dec_x - - x = self.out_0(x) - - if self.do_ds: - return [x, x_ds_1, x_ds_2, x_ds_3, x_ds_4] - else: - return x - - -if __name__ == "__main__": - - network = MedNeXt( - in_channels = 1, - n_channels = 32, - n_classes = 13, - exp_r=[2,3,4,4,4,4,4,3,2], # Expansion ratio as in Swin Transformers - # exp_r = 2, - kernel_size=3, # Can test kernel_size - deep_supervision=True, # Can be used to test deep supervision - do_res=True, # Can be used to individually test residual connection - do_res_up_down = True, - # block_counts = [2,2,2,2,2,2,2,2,2], - block_counts = [3,4,8,8,8,8,8,4,3], - checkpoint_style = None, - - ).cuda() - - # network = MedNeXt_RegularUpDown( - # in_channels = 1, - # n_channels = 32, - # n_classes = 13, - # exp_r=[2,3,4,4,4,4,4,3,2], # Expansion ratio as in Swin Transformers - # kernel_size=3, # Can test kernel_size - # deep_supervision=True, # Can be used to test deep supervision - # do_res=True, # Can be used to individually test residual connection - # block_counts = [2,2,2,2,2,2,2,2,2], - # - # ).cuda() - - def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) - - print(count_parameters(network)) - - from fvcore.nn import FlopCountAnalysis - from fvcore.nn import parameter_count_table - - # model = ResTranUnet(img_size=128, in_channels=1, num_classes=14, dummy=False).cuda() - x = torch.zeros((1,1,64,64,64), requires_grad=False).cuda() - flops = FlopCountAnalysis(network, x) - print(flops.total()) - - with torch.no_grad(): - print(network) - x = torch.zeros((2, 1, 128, 128, 128)).cuda() - print(network(x)[0].shape) diff --git a/libs/models/mednextv1/blocks.py b/libs/models/mednextv1/blocks.py deleted file mode 100644 index 623d6df..0000000 --- a/libs/models/mednextv1/blocks.py +++ /dev/null @@ -1,214 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MedNeXtBlock(nn.Module): - - def __init__(self, - in_channels:int, - out_channels:int, - exp_r:int=4, - kernel_size:int=7, - do_res:int=True, - norm_type:str = 'group', - n_groups:int or None = None, - ): - - super().__init__() - - self.do_res = do_res - - # First convolution layer with DepthWise Convolutions - self.conv1 = nn.Conv3d( - in_channels = in_channels, - out_channels = in_channels, - kernel_size = kernel_size, - stride = 1, - padding = kernel_size//2, - groups = in_channels if n_groups is None else n_groups, - ) - - # Normalization Layer. GroupNorm is used by default. - if norm_type=='group': - self.norm = nn.GroupNorm( - num_groups=in_channels, - num_channels=in_channels - ) - elif norm_type=='layer': - self.norm = LayerNorm( - normalized_shape=in_channels, - data_format='channels_first' - ) - - # Second convolution (Expansion) layer with Conv3D 1x1x1 - self.conv2 = nn.Conv3d( - in_channels = in_channels, - out_channels = exp_r*in_channels, - kernel_size = 1, - stride = 1, - padding = 0 - ) - - # GeLU activations - self.act = nn.GELU() - - # Third convolution (Compression) layer with Conv3D 1x1x1 - self.conv3 = nn.Conv3d( - in_channels = exp_r*in_channels, - out_channels = out_channels, - kernel_size = 1, - stride = 1, - padding = 0 - ) - - - def forward(self, x, dummy_tensor=None): - - x1 = x - x1 = self.conv1(x1) - x1 = self.act(self.conv2(self.norm(x1))) - x1 = self.conv3(x1) - if self.do_res: - x1 = x + x1 - return x1 - - -class MedNeXtDownBlock(MedNeXtBlock): - - def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, - do_res=False, norm_type = 'group'): - - super().__init__(in_channels, out_channels, exp_r, kernel_size, - do_res = False, norm_type = norm_type) - - self.resample_do_res = do_res - if do_res: - self.res_conv = nn.Conv3d( - in_channels = in_channels, - out_channels = out_channels, - kernel_size = 1, - stride = 2 - ) - - self.conv1 = nn.Conv3d( - in_channels = in_channels, - out_channels = in_channels, - kernel_size = kernel_size, - stride = 2, - padding = kernel_size//2, - groups = in_channels, - ) - - def forward(self, x, dummy_tensor=None): - - x1 = super().forward(x) - - if self.resample_do_res: - res = self.res_conv(x) - x1 = x1 + res - - return x1 - - -class MedNeXtUpBlock(MedNeXtBlock): - - def __init__(self, in_channels, out_channels, exp_r=4, kernel_size=7, - do_res=False, norm_type = 'group'): - super().__init__(in_channels, out_channels, exp_r, kernel_size, - do_res=False, norm_type = norm_type) - - self.resample_do_res = do_res - if do_res: - self.res_conv = nn.ConvTranspose3d( - in_channels = in_channels, - out_channels = out_channels, - kernel_size = 1, - stride = 2 - ) - - self.conv1 = nn.ConvTranspose3d( - in_channels = in_channels, - out_channels = in_channels, - kernel_size = kernel_size, - stride = 2, - padding = kernel_size//2, - groups = in_channels, - ) - - - def forward(self, x, dummy_tensor=None): - - x1 = super().forward(x) - # Asymmetry but necessary to match shape - x1 = torch.nn.functional.pad(x1, (1,0,1,0,1,0)) - - if self.resample_do_res: - res = self.res_conv(x) - res = torch.nn.functional.pad(res, (1,0,1,0,1,0)) - x1 = x1 + res - - return x1 - - -class OutBlock(nn.Module): - - def __init__(self, in_channels, n_classes): - super().__init__() - self.conv_out = nn.Conv3d(in_channels, n_classes, kernel_size=1) - - def forward(self, x, dummy_tensor=None): - return self.conv_out(x) - - -class LayerNorm(nn.Module): - """ LayerNorm that supports two data formats: channels_last (default) or channels_first. - The ordering of the dimensions in the inputs. channels_last corresponds to inputs with - shape (batch_size, height, width, channels) while channels_first corresponds to inputs - with shape (batch_size, channels, height, width). - """ - def __init__(self, normalized_shape, eps=1e-5, data_format="channels_last"): - super().__init__() - self.weight = nn.Parameter(torch.ones(normalized_shape)) # beta - self.bias = nn.Parameter(torch.zeros(normalized_shape)) # gamma - self.eps = eps - self.data_format = data_format - if self.data_format not in ["channels_last", "channels_first"]: - raise NotImplementedError - self.normalized_shape = (normalized_shape, ) - - def forward(self, x, dummy_tensor=False): - if self.data_format == "channels_last": - return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - elif self.data_format == "channels_first": - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None] - return x - - -if __name__ == "__main__": - - - # network = nnUNeXtBlock(in_channels=12, out_channels=12, do_res=False).cuda() - - # with torch.no_grad(): - # print(network) - # x = torch.zeros((2, 12, 8, 8, 8)).cuda() - # print(network(x).shape) - - # network = DownsampleBlock(in_channels=12, out_channels=24, do_res=False) - - # with torch.no_grad(): - # print(network) - # x = torch.zeros((2, 12, 128, 128, 128)) - # print(network(x).shape) - - network = MedNeXtBlock(in_channels=12, out_channels=12, do_res=True, norm_type='group').cuda() - # network = LayerNorm(normalized_shape=12, data_format='channels_last').cuda() - # network.eval() - with torch.no_grad(): - print(network) - x = torch.zeros((2, 12, 64, 64, 64)).cuda() - print(network(x).shape) diff --git a/libs/models/mednextv1/create_mednext_v1.py b/libs/models/mednextv1/create_mednext_v1.py deleted file mode 100644 index 0659d01..0000000 --- a/libs/models/mednextv1/create_mednext_v1.py +++ /dev/null @@ -1,83 +0,0 @@ -from libs.models.mednextv1.MedNextV1 import MedNeXt - -def create_mednextv1_small(num_input_channels, num_classes, kernel_size=3, ds=False): - - return MedNeXt( - in_channels = num_input_channels, - n_channels = 32, - n_classes = num_classes, - exp_r=2, - kernel_size=kernel_size, - deep_supervision=ds, - do_res=True, - do_res_up_down = True, - block_counts = [2,2,2,2,2,2,2,2,2] - ) - - -def create_mednextv1_base(num_input_channels, num_classes, kernel_size=3, ds=False): - - return MedNeXt( - in_channels = num_input_channels, - n_channels = 32, - n_classes = num_classes, - exp_r=[2,3,4,4,4,4,4,3,2], - kernel_size=kernel_size, - deep_supervision=ds, - do_res=True, - do_res_up_down = True, - block_counts = [2,2,2,2,2,2,2,2,2] - ) - - -def create_mednextv1_medium(num_input_channels, num_classes, kernel_size=3, ds=False): - - return MedNeXt( - in_channels = num_input_channels, - n_channels = 32, - n_classes = num_classes, - exp_r=[2,3,4,4,4,4,4,3,2], - kernel_size=kernel_size, - deep_supervision=ds, - do_res=True, - do_res_up_down = True, - block_counts = [3,4,4,4,4,4,4,4,3], - checkpoint_style = 'outside_block' - ) - - -def create_mednextv1_large(num_input_channels, num_classes, kernel_size=3, ds=False): - - return MedNeXt( - in_channels = num_input_channels, - n_channels = 32, - n_classes = num_classes, - exp_r=[3,4,8,8,8,8,8,4,3], - kernel_size=kernel_size, - deep_supervision=ds, - do_res=True, - do_res_up_down = True, - block_counts = [3,4,8,8,8,8,8,4,3], - checkpoint_style = 'outside_block' - ) - - -def create_mednext_v1(num_input_channels, num_classes, model_id, kernel_size=3, - deep_supervision=False): - - model_dict = { - 'S': create_mednextv1_small, - 'B': create_mednextv1_base, - 'M': create_mednextv1_medium, - 'L': create_mednextv1_large, - } - - return model_dict[model_id]( - num_input_channels, num_classes, kernel_size, deep_supervision - ) - - -if __name__ == "__main__": - - model = create_mednextv1_large(1, 3, 3, False) - print(model) \ No newline at end of file diff --git a/libs/models/nonlocalUnet3D.py b/libs/models/nonlocalUnet3D.py deleted file mode 100644 index e907c27..0000000 --- a/libs/models/nonlocalUnet3D.py +++ /dev/null @@ -1,450 +0,0 @@ -import math -import torch -import torch.nn as nn -from torch.nn import init -import torch.nn.functional as F - -class unet_nonlocal_3D(nn.Module): - - def __init__(self, feature_scale=4, n_classes=1, is_deconv=True, in_channels=1, is_batchnorm=True, - nonlocal_mode='embedded_gaussian', nonlocal_sf=4): - super(unet_nonlocal_3D, self).__init__() - self.is_deconv = is_deconv - self.in_channels = in_channels - self.is_batchnorm = is_batchnorm - self.feature_scale = feature_scale - - filters = [64, 128, 256, 512, 1024] - filters = [int(x / self.feature_scale) for x in filters] - - # downsampling - self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm) - self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 1)) - - self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm) - self.nonlocal2 = NONLocalBlock3D(in_channels=filters[1], inter_channels=filters[1] // 4, - sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) - self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 1)) - - self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm) - self.nonlocal3 = NONLocalBlock3D(in_channels=filters[2], inter_channels=filters[2] // 4, - sub_sample_factor=nonlocal_sf, mode=nonlocal_mode) - self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 1)) - - self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm) - self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 1)) - - self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm) - - # upsampling - self.up_concat4 = UnetUp3(filters[4], filters[3], self.is_deconv) - self.up_concat3 = UnetUp3(filters[3], filters[2], self.is_deconv) - self.up_concat2 = UnetUp3(filters[2], filters[1], self.is_deconv) - self.up_concat1 = UnetUp3(filters[1], filters[0], self.is_deconv) - - # final conv (without any concat) - self.final = nn.Conv3d(filters[0], n_classes, 1) - - # initialise weights - for m in self.modules(): - if isinstance(m, nn.Conv3d): - init_weights(m, init_type='kaiming') - elif isinstance(m, nn.BatchNorm3d): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - conv1 = self.conv1(inputs) - maxpool1 = self.maxpool1(conv1) - - conv2 = self.conv2(maxpool1) - nl2 = self.nonlocal2(conv2) - maxpool2 = self.maxpool2(nl2) - - conv3 = self.conv3(maxpool2) - nl3 = self.nonlocal3(conv3) - maxpool3 = self.maxpool3(nl3) - - conv4 = self.conv4(maxpool3) - maxpool4 = self.maxpool4(conv4) - - center = self.center(maxpool4) - up4 = self.up_concat4(conv4, center) - up3 = self.up_concat3(nl3, up4) - up2 = self.up_concat2(nl2, up3) - up1 = self.up_concat1(conv1, up2) - - final = self.final(up1) - - return final - - @staticmethod - def apply_argmax_softmax(pred): - log_p = F.softmax(pred, dim=1) - - return log_p - - -class UnetConv3(nn.Module): - def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)): - super(UnetConv3, self).__init__() - - if is_batchnorm: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.BatchNorm3d(out_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.BatchNorm3d(out_size), - nn.ReLU(inplace=True),) - else: - self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size), - nn.ReLU(inplace=True),) - self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size), - nn.ReLU(inplace=True),) - - # initialise the blocks - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, inputs): - outputs = self.conv1(inputs) - outputs = self.conv2(outputs) - return outputs - - -class UnetUp3(nn.Module): - def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True): - super(UnetUp3, self).__init__() - if is_deconv: - self.conv = UnetConv3(in_size, out_size, is_batchnorm) - self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0)) - else: - self.conv = UnetConv3(in_size+out_size, out_size, is_batchnorm) - self.up = nn.Upsample(scale_factor=(2, 2, 1), mode='trilinear') - - # initialise the blocks - for m in self.children(): - if m.__class__.__name__.find('UnetConv3') != -1: continue - init_weights(m, init_type='kaiming') - - def forward(self, inputs1, inputs2): - outputs2 = self.up(inputs2) - offset = outputs2.size()[2] - inputs1.size()[2] - padding = 2 * [offset // 2, offset // 2, 0] - outputs1 = F.pad(inputs1, padding) - return self.conv(torch.cat([outputs1, outputs2], 1)) - - -def init_weights(net, init_type='kaiming'): - #print('initialization method [%s]' % init_type) - if init_type == 'normal': - pass - elif init_type == 'xavier': - pass - elif init_type == 'kaiming': - net.apply(weights_init_kaiming) - elif init_type == 'orthogonal': - pass - else: - raise NotImplementedError('initialization method [%s] is not implemented' % init_type) - -def weights_init_kaiming(m): - classname = m.__class__.__name__ - #print(classname) - if classname.find('Conv') != -1: - init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('Linear') != -1: - init.kaiming_normal(m.weight.data, a=0, mode='fan_in') - elif classname.find('BatchNorm') != -1: - init.normal(m.weight.data, 1.0, 0.02) - init.constant(m.bias.data, 0.0) - - - -class _NonLocalBlockND(nn.Module): - def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded_gaussian', - sub_sample_factor=4, bn_layer=True): - super(_NonLocalBlockND, self).__init__() - - assert dimension in [1, 2, 3] - assert mode in ['embedded_gaussian', 'gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down'] - - # print('Dimension: %d, mode: %s' % (dimension, mode)) - - self.mode = mode - self.dimension = dimension - self.sub_sample_factor = sub_sample_factor if isinstance(sub_sample_factor, list) else [sub_sample_factor] - - self.in_channels = in_channels - self.inter_channels = inter_channels - - if self.inter_channels is None: - self.inter_channels = in_channels // 2 - if self.inter_channels == 0: - self.inter_channels = 1 - - if dimension == 3: - conv_nd = nn.Conv3d - max_pool = nn.MaxPool3d - bn = nn.BatchNorm3d - elif dimension == 2: - conv_nd = nn.Conv2d - max_pool = nn.MaxPool2d - bn = nn.BatchNorm2d - else: - conv_nd = nn.Conv1d - max_pool = nn.MaxPool1d - bn = nn.BatchNorm1d - - self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0) - - if bn_layer: - self.W = nn.Sequential( - conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, - kernel_size=1, stride=1, padding=0), - bn(self.in_channels) - ) - nn.init.constant(self.W[1].weight, 0) - nn.init.constant(self.W[1].bias, 0) - else: - self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, - kernel_size=1, stride=1, padding=0) - nn.init.constant(self.W.weight, 0) - nn.init.constant(self.W.bias, 0) - - self.theta = None - self.phi = None - - if mode in ['embedded_gaussian', 'dot_product', 'concatenation', 'concat_proper', 'concat_proper_down']: - self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0) - self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, - kernel_size=1, stride=1, padding=0) - - if mode in ['concatenation']: - self.wf_phi = nn.Linear(self.inter_channels, 1, bias=False) - self.wf_theta = nn.Linear(self.inter_channels, 1, bias=False) - elif mode in ['concat_proper', 'concat_proper_down']: - self.psi = nn.Conv2d(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, - padding=0, bias=True) - - if mode == 'embedded_gaussian': - self.operation_function = self._embedded_gaussian - elif mode == 'dot_product': - self.operation_function = self._dot_product - elif mode == 'gaussian': - self.operation_function = self._gaussian - elif mode == 'concatenation': - self.operation_function = self._concatenation - elif mode == 'concat_proper': - self.operation_function = self._concatenation_proper - elif mode == 'concat_proper_down': - self.operation_function = self._concatenation_proper_down - else: - raise NotImplementedError('Unknown operation function.') - - if any(ss > 1 for ss in self.sub_sample_factor): - self.g = nn.Sequential(self.g, max_pool(kernel_size=sub_sample_factor)) - if self.phi is None: - self.phi = max_pool(kernel_size=sub_sample_factor) - else: - self.phi = nn.Sequential(self.phi, max_pool(kernel_size=sub_sample_factor)) - if mode == 'concat_proper_down': - self.theta = nn.Sequential(self.theta, max_pool(kernel_size=sub_sample_factor)) - - # Initialise weights - for m in self.children(): - init_weights(m, init_type='kaiming') - - def forward(self, x): - ''' - :param x: (b, c, t, h, w) - :return: - ''' - - output = self.operation_function(x) - return output - - def _embedded_gaussian(self, x): - batch_size = x.size(0) - - # g=>(b, c, t, h, w)->(b, 0.5c, t, h, w)->(b, thw, 0.5c) - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - g_x = g_x.permute(0, 2, 1) - - # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) - # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) - # f=>(b, thw, 0.5c)dot(b, 0.5c, twh) = (b, thw, thw) - theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) - theta_x = theta_x.permute(0, 2, 1) - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) - f = torch.matmul(theta_x, phi_x) - f_div_C = F.softmax(f, dim=-1) - - # (b, thw, thw)dot(b, thw, 0.5c) = (b, thw, 0.5c)->(b, 0.5c, t, h, w)->(b, c, t, h, w) - y = torch.matmul(f_div_C, g_x) - y = y.permute(0, 2, 1).contiguous() - y = y.view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - z = W_y + x - - return z - - def _gaussian(self, x): - batch_size = x.size(0) - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - g_x = g_x.permute(0, 2, 1) - - theta_x = x.view(batch_size, self.in_channels, -1) - theta_x = theta_x.permute(0, 2, 1) - - if self.sub_sample_factor > 1: - phi_x = self.phi(x).view(batch_size, self.in_channels, -1) - else: - phi_x = x.view(batch_size, self.in_channels, -1) - - f = torch.matmul(theta_x, phi_x) - f_div_C = F.softmax(f, dim=-1) - - y = torch.matmul(f_div_C, g_x) - y = y.permute(0, 2, 1).contiguous() - y = y.view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - z = W_y + x - - return z - - def _dot_product(self, x): - batch_size = x.size(0) - - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - g_x = g_x.permute(0, 2, 1) - - theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) - theta_x = theta_x.permute(0, 2, 1) - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) - f = torch.matmul(theta_x, phi_x) - N = f.size(-1) - f_div_C = f / N - - y = torch.matmul(f_div_C, g_x) - y = y.permute(0, 2, 1).contiguous() - y = y.view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - z = W_y + x - - return z - - def _concatenation(self, x): - batch_size = x.size(0) - - # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - - # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw, 0.5c) - # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, thw/s**2, 0.5c) - theta_x = self.theta(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1).permute(0, 2, 1) - - # theta => (b, thw, 0.5c) -> (b, thw, 1) -> (b, 1, thw) -> (expand) (b, thw/s**2, thw) - # phi => (b, thw/s**2, 0.5c) -> (b, thw/s**2, 1) -> (expand) (b, thw/s**2, thw) - # f=> RELU[(b, thw/s**2, thw) + (b, thw/s**2, thw)] = (b, thw/s**2, thw) - f = self.wf_theta(theta_x).permute(0, 2, 1).repeat(1, phi_x.size(1), 1) + \ - self.wf_phi(phi_x).repeat(1, 1, theta_x.size(1)) - f = F.relu(f, inplace=True) - - # Normalise the relations - N = f.size(-1) - f_div_c = f / N - - # g(x_j) * f(x_j, x_i) - # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) - y = torch.matmul(g_x, f_div_c) - y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - z = W_y + x - - return z - - def _concatenation_proper(self, x): - batch_size = x.size(0) - - # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - - # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) - # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) - theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) - - # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) - # phi => (b, 0.5c, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) - # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) - f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ - phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) - f = F.relu(f, inplace=True) - - # psi -> W_psi^t * f -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) - f = torch.squeeze(self.psi(f), dim=1) - - # Normalise the relations - f_div_c = F.softmax(f, dim=1) - - # g(x_j) * f(x_j, x_i) - # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) - y = torch.matmul(g_x, f_div_c) - y = y.contiguous().view(batch_size, self.inter_channels, *x.size()[2:]) - W_y = self.W(y) - z = W_y + x - - return z - - def _concatenation_proper_down(self, x): - batch_size = x.size(0) - - # g=>(b, c, t, h, w)->(b, 0.5c, thw/s**2) - g_x = self.g(x).view(batch_size, self.inter_channels, -1) - - # theta=>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw) - # phi =>(b, c, t, h, w)[->(b, 0.5c, t, h, w)]->(b, 0.5c, thw/s**2) - theta_x = self.theta(x) - downsampled_size = theta_x.size() - theta_x = theta_x.view(batch_size, self.inter_channels, -1) - phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) - - # theta => (b, 0.5c, thw) -> (expand) (b, 0.5c, thw/s**2, thw) - # phi => (b, 0.5, thw/s**2) -> (expand) (b, 0.5c, thw/s**2, thw) - # f=> RELU[(b, 0.5c, thw/s**2, thw) + (b, 0.5c, thw/s**2, thw)] = (b, 0.5c, thw/s**2, thw) - f = theta_x.unsqueeze(dim=2).repeat(1,1,phi_x.size(2),1) + \ - phi_x.unsqueeze(dim=3).repeat(1,1,1,theta_x.size(2)) - f = F.relu(f, inplace=True) - - # psi -> W_psi^t * f -> (b, 0.5c, thw/s**2, thw) -> (b, 1, thw/s**2, thw) -> (b, thw/s**2, thw) - f = torch.squeeze(self.psi(f), dim=1) - - # Normalise the relations - f_div_c = F.softmax(f, dim=1) - - # g(x_j) * f(x_j, x_i) - # (b, 0.5c, thw/s**2) * (b, thw/s**2, thw) -> (b, 0.5c, thw) - y = torch.matmul(g_x, f_div_c) - y = y.contiguous().view(batch_size, self.inter_channels, *downsampled_size[2:]) - - # upsample the final featuremaps # (b,0.5c,t/s1,h/s2,w/s3) - y = F.upsample(y, size=x.size()[2:], mode='trilinear') - - # attention block output - W_y = self.W(y) - z = W_y + x - - return z - - -class NONLocalBlock3D(_NonLocalBlockND): - def __init__(self, in_channels, inter_channels=None, mode='embedded_gaussian', sub_sample_factor=2, bn_layer=True): - super(NONLocalBlock3D, self).__init__(in_channels, - inter_channels=inter_channels, - dimension=3, mode=mode, - sub_sample_factor=sub_sample_factor, - bn_layer=bn_layer) - \ No newline at end of file diff --git a/libs/optimizers/OptimizerFactory.py b/libs/optimizers/OptimizerFactory.py deleted file mode 100644 index aeac925..0000000 --- a/libs/optimizers/OptimizerFactory.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch.optim import Adam, SGD, AdamW - - -class OptimizerFactory(): - def __init__(self, name, params, lr): - super(OptimizerFactory, self).__init__() - self.name = name - self.params = params - self.lr = lr - - def get(self): - if self.name == 'Adam': - self.optimizer = Adam(params=self.params, lr=self.lr) - elif self.name == 'AdamW': - self.optimizer = AdamW(params=self.params, lr=self.lr) - elif self.name == 'SGD': - self.optimizer = SGD(params=self.params, lr=self.lr) - else: - raise ValueError(f'Unknown optimizer: {self.name}') - - self.optimizer.name = self.name - return self.optimizer - diff --git a/libs/schedulers/SchedulerFactory.py b/libs/schedulers/SchedulerFactory.py deleted file mode 100644 index 145fbf4..0000000 --- a/libs/schedulers/SchedulerFactory.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -import math -from torch.optim.lr_scheduler import MultiStepLR, ReduceLROnPlateau, _LRScheduler, LambdaLR - -class SchedulerFactory(): - def __init__(self, name, optimizer, **kwargs): - super(SchedulerFactory, self).__init__() - self.name = name - self.optimizer = optimizer - self.kwargs = kwargs - - def get(self): - if self.name == 'MultiStepLR': - self.kwargs = { - 'milestones': self.kwargs.get('milestones'), - 'gamma': self.kwargs.get('gamma', 0.1), - } - scheduler = MultiStepLR(self.optimizer, **self.kwargs) - elif self.name == 'LambdaLR': - scheduler = LambdaLR(self.optimizer, lr_lambda=lambda epoch: 0.999**epoch) - - elif self.name == 'Plateau': - self.kwargs = { - 'mode': self.kwargs.get('mode', None), - 'patience': self.kwargs.get('patience', None), - 'verbose': True, - } - scheduler = ReduceLROnPlateau(self.optimizer, **self.kwargs) - elif self.name == 'SGDR': - self.kwargs = { - 'T_0':150, - 'T_mult':1, - 'eta_max':0.1, - 'T_up':10, - 'gamma':0.5, - } - scheduler = CosineAnnealingWarmUpRestarts(self.optimizer, **self.kwargs) - else: - raise ValueError(f'Unknown scheduler: {self.name}') - scheduler.name = self.name - return scheduler - - -class CosineAnnealingWarmUpRestarts(_LRScheduler): - def __init__(self, optimizer, T_0, T_mult=1, eta_max=0.1, T_up=0, gamma=1., last_epoch=-1): - if T_0 <= 0 or not isinstance(T_0, int): - raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) - if T_mult < 1 or not isinstance(T_mult, int): - raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) - if T_up < 0 or not isinstance(T_up, int): - raise ValueError("Expected positive integer T_up, but got {}".format(T_up)) - self.T_0 = T_0 - self.T_mult = T_mult - self.base_eta_max = eta_max - self.eta_max = eta_max - self.T_up = T_up - self.T_i = T_0 - self.gamma = gamma - self.cycle = 0 - self.T_cur = last_epoch - super(CosineAnnealingWarmUpRestarts, self).__init__(optimizer, last_epoch) - - def get_lr(self): - if self.T_cur == -1: - return self.base_lrs - elif self.T_cur < self.T_up: - return [(self.eta_max - base_lr)*self.T_cur / self.T_up + base_lr for base_lr in self.base_lrs] - else: - return [base_lr + (self.eta_max - base_lr) * (1 + math.cos(math.pi * (self.T_cur-self.T_up) / (self.T_i - self.T_up))) / 2 - for base_lr in self.base_lrs] - - def step(self, epoch=None): - if epoch is None: - epoch = self.last_epoch + 1 - self.T_cur = self.T_cur + 1 - if self.T_cur >= self.T_i: - self.cycle += 1 - self.T_cur = self.T_cur - self.T_i - self.T_i = (self.T_i - self.T_up) * self.T_mult + self.T_up - else: - if epoch >= self.T_0: - if self.T_mult == 1: - self.T_cur = epoch % self.T_0 - self.cycle = epoch // self.T_0 - else: - n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) - self.cycle = n - self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) - self.T_i = self.T_0 * self.T_mult ** (n) - else: - self.T_i = self.T_0 - self.T_cur = epoch - - self.eta_max = self.base_eta_max * (self.gamma**self.cycle) - self.last_epoch = math.floor(epoch) - for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): - param_group['lr'] = lr \ No newline at end of file diff --git a/preproc.py b/preproc.py deleted file mode 100644 index fe47940..0000000 --- a/preproc.py +++ /dev/null @@ -1,121 +0,0 @@ -import sys -import os -import argparse -import logging -import logging.config -import shutil -import yaml -from hashlib import shake_256 -import time -from munch import Munch, munchify, unmunchify -from utils.AugmentFactory import * -from utils.TaskFactory import * -import pandas as pd - -def _nrrd_reader(path): - raw_data, _ = nrrd.read(path) - data = torch.from_numpy(raw_data).float() - affine = torch.eye(4, requires_grad=False) # Identity matrix(단위 행렬) - return data, affine - -def timehash(): - t = time.time() - t = str(t).encode() - h = shake_256(t) - h = h.hexdigest(5) # output len: 2*5=10 - return h.upper() - -if __name__ == '__main__': - logger = logging.getLogger() - logger.setLevel(logging.INFO) - - # Parse arguments - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("-c", "--config", default="configs/preprocessing.yaml", - help="the preprocessing config file to be used to run the experiment") - arg_parser.add_argument("--verbose", action='store_true', help="Log also to stdout") - args = arg_parser.parse_args() - - - - - # check if the config files exists - if not os.path.exists(args.config): - logging.info("Config file does not exist: {}".format(args.config)) - raise SystemExit - - # Munchify the dict to access entries with both dot notation and ['name'] - logging.info(f'Loading the config file...') - preproc = yaml.load(open(args.config, "r"), yaml.FullLoader) - preproc = munchify(preproc) - - - source_dir = preproc.source_dir - save_dir = preproc.save_dir - # set the title name with timehash - title = f'preprocessing_{timehash()}' - save_dir= os.path.join(save_dir, title) - os.makedirs(save_dir, exist_ok=True) - logging.info(f'source_dir: {source_dir}') - logging.info(f'save_dir: {save_dir}') - - preproc_processing = AugFactory(preproc.preprocessing).get_transform() - copy_preprocessing_path = os.path.join(save_dir, 'preprocessing.yaml') - if args.config is not None: - shutil.copy(args.config, copy_preprocessing_path) - - transform=preproc_processing - - logging.info(f'experiment title: {preproc.experiment.name}') - - # main - patient_data_list = os.listdir(source_dir) - subjects = [] - table = [] - for patient in patient_data_list: - # generate labels - ct_data_path = os.path.join(source_dir, patient, patient + '_IMG_CT.nrrd') - # mr_data_path = os.path.join(root, patient, patient + '_IMG_MR_T1.nrrd') - label_path = os.path.join(source_dir, patient, patient + f'_{preproc.experiment.name}.seg.nrrd') - if not os.path.isfile(ct_data_path): - raise ValueError(f'Missing CT data file for patient {patient} ({ct_data_path})') - # if not os.path.isfile(mr_data_path): - # raise ValueError(f'Missing MR_TI data file for patient {patient} ({mr_data_path})') - if not os.path.isfile(label_path): - raise ValueError(f'Missing LABEL file for patient {patient} ({label_path})') - - subject_dict = { - - 'patient': patient, - 'ct': tio.ScalarImage(ct_data_path), - # 'mr': tio.ScalarImage(mr_data_path), - 'label': tio.LabelMap(label_path), - } - - # preprocessing - subject = tio.Subject(**subject_dict) - transform_subject = transform(subject) - # save - os.makedirs(os.path.join(save_dir, patient), exist_ok=True) - ct_data_path = os.path.join(save_dir, patient, patient + '_IMG_CT.nrrd') - label_path = os.path.join(save_dir, patient, patient + f'_{preproc.experiment.name}.seg.nrrd') - - transformed_ct_data = transform_subject['ct'][tio.DATA].squeeze(0).numpy() - transformed_label_data = transform_subject['label'][tio.DATA].squeeze(0).numpy() - - nrrd.write(ct_data_path, transformed_ct_data) - nrrd.write(label_path, transformed_label_data) - - print(f"Saved {patient} patients") - subjects.append(tio.Subject(**subject_dict)) - - if preproc.check_preprocessing: - # check preprocessing - info = [patient, transformed_ct_data.shape, transformed_label_data.shape, np.unique(transformed_label_data)] - table.append(info) - - print(f"Completed {len(subjects)} patients") - df = pd.DataFrame(table, columns=['data_name', 'CT_shape', 'label_shape', 'label_unique']) - df.to_csv(os.path.join(save_dir, 'preprocessing.csv'), index=True) - - \ No newline at end of file diff --git a/train.py b/train.py deleted file mode 100644 index 61063e0..0000000 --- a/train.py +++ /dev/null @@ -1,218 +0,0 @@ -import sys -import os -import argparse -import logging -import logging.config -import shutil -import yaml -import pathlib -import builtins -import socket -import random -import time -import json - -import numpy as np -import torch -import torchio as tio -import torch.distributed as dist -import torch.utils.data as data - -from hashlib import shake_256 -from munch import Munch, munchify, unmunchify -from torch import nn -from os import path -from torch.backends import cudnn -from torch.utils.data import DistributedSampler -import wandb - -from utils.TaskFactory import * -from utils.AugmentFactory import * - - -# experiment name -def timehash(): - t = time.time() - t = str(t).encode() - h = shake_256(t) - h = h.hexdigest(5) # output len: 2*5=10 - return h.upper() - -def setup(seed): - torch.manual_seed(seed) - np.random.seed(seed) - random.seed(seed) - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.deterministic = True - -if __name__ == '__main__': - logger = logging.getLogger() - logger.setLevel(logging.INFO) - # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - # os.environ["CUDA_VISIBLE_DEVICES"]= "0" - # Parse arguments - arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("-c", "--config", default="configs/config.yaml", - help="the config file to be used to run the experiment") - arg_parser.add_argument("--verbose", action='store_true', help="Log also to stdout") - arg_parser.add_argument("--debug", default=False, action='store_true', help="debug, no wandb") - args = arg_parser.parse_args() - - # check if the config files exists - if not os.path.exists(args.config): - logging.info("Config file does not exist: {}".format(args.config)) - raise SystemExit - - # Munchify the dict to access entries with both dot notation and ['name'] - logging.info(f'Loading the config file...') - config = yaml.load(open(args.config, "r"), yaml.FullLoader) - config = munchify(config) - - # Setup to be deterministic - logging.info(f'setup to be deterministic') - setup(config.seed) - - # set the title name with timehash - config.title = f'{config.title}_{timehash()}' - - if args.debug: - os.environ['WANDB_DISABLED'] = 'true' - - # start wandb - logging.info(f'setup wandb log ...') - wandb.init( - project="HanSeg", - # entity=config.title, - config=unmunchify(config) - ) - # get run name - run_name = wandb.run.name - - # set run name - wandb.run.name = config.title - wandb.run.save() - - # check if augmentations is set and file exists - logging.info(f'loading augmentations') - augfile = config.data_loader.augmentations - if config.data_loader.augmentations is None: - aug = [] - elif not os.path.exists(config.data_loader.augmentations): - logging.warning(f'Augmentations file does not exist: {config.augmentations}') - aug = [] - else: - with open(config.data_loader.augmentations) as aug_file: - aug = yaml.load(aug_file, yaml.FullLoader) - config.data_loader.augmentations = AugFactory(aug).get_transform() - - logging.info(f'Instantiation of the experiment') - # pdb.set_trace() - experiment = TaskFactory(config, args.debug).get() - logging.info(f'experiment title: {experiment.config.title}') - - project_dir_title = os.path.join(experiment.config.project_dir, experiment.config.experiment.name, 'train', experiment.config.title) - os.makedirs(project_dir_title, exist_ok=True) - logging.info(f'project directory: {project_dir_title}') - - # Setup logger's handlers - file_handler = logging.FileHandler(os.path.join(project_dir_title, 'output.log')) - log_format = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') - file_handler.setFormatter(log_format) - logger.addHandler(file_handler) - - if args.verbose: - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(log_format) - logger.addHandler(stdout_handler) - - # Copy config file to project_dir, to be able to reproduce the experiment - copy_config_path = os.path.join(project_dir_title, 'config.yaml') - copy_augmentations_path = os.path.join(project_dir_title, 'augmentations.yaml') - - shutil.copy(args.config, copy_config_path) - - if not os.path.exists(experiment.config.data_loader.dataset): - logging.error("Dataset path does not exist: {}".format(experiment.config.data_loader.dataset)) - raise SystemExit - - # pre-calculate the checkpoints path - checkpoints_path = path.join(project_dir_title, 'checkpoints') - - if not os.path.exists(checkpoints_path): - os.makedirs(checkpoints_path) - - if experiment.config.trainer.reload and not os.path.exists(experiment.config.trainer.checkpoint): - logging.error(f'Checkpoint file does not exist: {experiment.config.trainer.checkpoint}') - raise SystemExit - - best_val = float('-inf') - best_test = { - 'value': float('-inf'), - 'epoch': -1 - } - - - # Train the model - if config.trainer.do_train: - logging.info('Training...') - assert experiment.epoch < config.trainer.epochs - for epoch in range(experiment.epoch, config.trainer.epochs+1): - epoch_train_loss, epoch_dice = experiment.train() - logging.info(f'Epoch {epoch} Train Dice: {epoch_dice}') - logging.info(f'Epoch {epoch} Train Loss: {epoch_train_loss}') - - val_dice = experiment.test(phase="Validation") - logging.info(f'Epoch {epoch} Val Dice: {val_dice}') - if val_dice < 1e-05 and experiment.epoch > 15: - logging.warning('WARNING: drop in performances detected.') - - optim_name = experiment.optimizer.name - sched_name = experiment.scheduler.name - - if experiment.scheduler is not None: - if optim_name == 'SGD' and sched_name == 'Plateau': - experiment.scheduler.step(val_dice) - elif sched_name == 'SGDR': - experiment.scheduler.step() - else: - experiment.scheduler.step(epoch) - - if epoch % 3 == 0: - test_dice = experiment.test(phase="Test") - logging.info(f'Epoch {epoch} Test Dice: {test_dice}') - - if test_dice > best_test['value']: - best_test['value'] = test_dice - best_test['epoch'] = epoch - - experiment.save('last.pth') - - if val_dice > best_val: - best_val = val_dice - experiment.save('best.pth') - - experiment.epoch += 1 - - logging.info(f''' - Best test Dice found: {best_test['value']} at epoch: {best_test['epoch']} - ''') - - # Test the model - if config.trainer.do_test: - logging.info('Testing the model...') - experiment.load() - test_dice = experiment.test(phase="Test") - logging.info(f'Test results Dice: {test_dice}') - - # Do the inference - if config.trainer.do_inference: - logging.info('Doing inference...') - experiment.load() - output_path = r'experiments/test' - output_path = os.path.join(output_path, config.title) - os.makedirs(output_path, exist_ok=True) - # Copy config file to project_dir, to be able to reproduce the experiment - copy_config_path = os.path.join(output_path, 'config.yaml') - shutil.copy(args.config, copy_config_path) - experiment.inference(output_path=output_path, phase="Test") - \ No newline at end of file diff --git a/utils/AugmentFactory.py b/utils/AugmentFactory.py deleted file mode 100644 index 14b4fb3..0000000 --- a/utils/AugmentFactory.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import torchio as tio - - -class AugFactory: - def __init__(self, aug_list): - self.aug_list = aug_list - self.transforms = self.factory(self.aug_list, []) - logging.info('Augmentations: {}'.format(self.aug_list)) - - def factory(self, auglist, transforms): - if auglist == None: return [] - for aug in auglist: - if aug == 'OneOf': - transforms.append(tio.OneOf(self.factory(auglist[aug], []))) - else: - try: - kwargs = {} - for param, value in auglist[aug].items(): - kwargs[param] = value - else: - t = getattr(tio, aug)(**kwargs) - transforms.append(t) - except: - raise Exception(f"this transform is not valid: {aug}") - return transforms - - def get_transform(self): - """ - return the transform object - :return: - """ - transf = tio.Compose(self.transforms) - return transf \ No newline at end of file diff --git a/utils/EvalFactory.py b/utils/EvalFactory.py deleted file mode 100644 index 2337760..0000000 --- a/utils/EvalFactory.py +++ /dev/null @@ -1,43 +0,0 @@ -from statistics import mean -import torch -import torch.nn.functional as F -import numpy as np -import SimpleITK as sitk -from skimage import metrics -import os -import pandas as pd -import zipfile -from monai.metrics import DiceMetric - -class Eval: - def __init__(self, classes): - self.eps = 1e-06 - self.classes=classes - self.dice_list = [] - - def reset_eval(self): - self.dice_list.clear() - - def calc_dice(self, hard_preds, gt): - # those are B 1 H W D - hard_preds_onehot = torch.nn.functional.one_hot(hard_preds, self.classes).permute(0, 4, 1, 2, 3) - gt_onehot = torch.nn.functional.one_hot(gt.squeeze().long(), num_classes=self.classes) - gt_onehot = gt_onehot.unsqueeze(0) - gt_onehot = torch.movedim(gt_onehot, -1, 1) - cal_dice = DiceMetric(include_background=False, reduction="mean", get_not_nans=True)(hard_preds_onehot, gt_onehot) - metric = cal_dice.mean().item() - return metric - - def add_dice(self, dice): - dice = dice.mean().item() - self.dice_list.append(dice) - - def compute_metrics(self, hard_preds, gt): - dice = self.calc_dice(hard_preds, gt) - self.dice_list.append(dice) - - - def mean_metric(self, phase): - dice = 0 if len(self.dice_list) == 0 else mean(self.dice_list) - self.reset_eval() - return dice diff --git a/utils/TaskFactory.py b/utils/TaskFactory.py deleted file mode 100644 index 5218eb7..0000000 --- a/utils/TaskFactory.py +++ /dev/null @@ -1,46 +0,0 @@ -from utils.TrainFactory import * - -class TaskFactory: - def __init__(self, config, debug=False): - self.name = config.experiment.name - self.config = config - self.debug = debug - - def get(self): - if self.name == 'Segmentation': - experiment = Segmentation(self.config, self.debug) - elif self.name == 'Generation': - experiment = Generation(self.config, self.debug) - elif self.name == 'Anchor': - experiment = Anchor(self.config, self.debug) - else: - raise ValueError(f'Experiment \'{self.name}\' not found') - return experiment - - - - -class Anchor(Experiment): - def __init__(self, config, debug=False): - self.debug = debug - self.train_loader = None - self.test_loader = None - self.val_loader = None - super().__init__(config, self.debug) - - -class Segmentation(Experiment): - def __init__(self, config, debug=False): - self.debug = debug - self.train_loader = None - self.test_loader = None - self.val_loader = None - super().__init__(config, self.debug) - -class Generation(Experiment): - def __init__(self, config, debug=False): - self.debug = debug - self.train_loader = None - self.test_loader = None - self.val_loader = None - super().__init__(config, self.debug) \ No newline at end of file diff --git a/utils/TrainFactory.py b/utils/TrainFactory.py deleted file mode 100644 index 94ca1e1..0000000 --- a/utils/TrainFactory.py +++ /dev/null @@ -1,318 +0,0 @@ -import sys -import os -import argparse -import logging -import logging.config -import yaml -import pathlib -import builtins -import socket -import time -import random -import numpy as np -import torch -import logging -import nrrd -import torchio as tio -import torch.distributed as dist -import torch.utils.data as data -import wandb -import torch.nn.functional as F -from torch import nn -from os import path -from torch.backends import cudnn -from tqdm import tqdm -from torch.utils.data import DataLoader - -from datasets.HaN import HaN -from libs.losses.LossFactory import LossFactory -from libs.losses.LossFactory import * -from libs.models.ModelFactory import ModelFactory -from libs.optimizers.OptimizerFactory import OptimizerFactory -from libs.schedulers.SchedulerFactory import SchedulerFactory -from utils.AugmentFactory import * -from utils.EvalFactory import Eval as Evaluator -from datasets.label_dict import LABEL_dict, Anchor_dict # from datasets/label_dict.py -eps = 1e-10 - - -class Experiment: - def __init__(self, config, debug=False): - self.config = config - self.debug = debug - self.epoch = 0 - self.metrics = {} - self.scaler = torch.cuda.amp.GradScaler() - - # num_classes = len(self.config.data_loader.labels) - # if 'Jaccard' in self.config.loss.name or num_classes == 2: - num_classes = len(Anchor_dict) if self.config.experiment.name == 'Anchor' else len(LABEL_dict) - self.num_classes = num_classes - # load model - model_name = self.config.model.name - in_ch = 2 if self.config.experiment.name == 'Generation' else 1 - - self.model = ModelFactory(model_name, num_classes, in_ch).get().cuda(self.config.device) - for m in self.model.modules(): - for child in m.children(): - if type(child) == torch.nn.BatchNorm3d: - m.eval() - - - wandb.watch(self.model, log_freq=10) - - # load optimizer - optim_name = self.config.optimizer.name - train_params = self.model.parameters() - lr = self.config.optimizer.learning_rate - - self.optimizer = OptimizerFactory(optim_name, train_params, lr).get() - - # load scheduler - sched_name = self.config.lr_scheduler.name - sched_milestones = self.config.lr_scheduler.get('milestones', None) - sched_gamma = self.config.lr_scheduler.get('factor', None) - - self.scheduler = SchedulerFactory( - sched_name, - self.optimizer, - milestones=sched_milestones, - gamma=sched_gamma, - mode='max', - verbose=True, - patience=10 - ).get() - - # load loss - self.loss = LossFactory(self.config.loss.name, classes=num_classes) - - # load evaluator - self.evaluator = Evaluator(classes=num_classes) - if self.config.data_loader.patch_loader: - tranform = tio.Compose([ - tio.CropOrPad(self.config.data_loader.resize_shape, padding_mode=4), - # self.config.data_loader.preprocessing, - self.config.data_loader.augmentations]) - else: - tranform = tio.Compose([ - tio.Resize(self.config.data_loader.resize_shape), - # self.config.data_loader.preprocessing, - self.config.data_loader.augmentations]) - - self.train_dataset = HaN( - config = self.config, - splits='train', - transform=tranform, - sampler=self.config.data_loader.sampler_type - ) - self.val_dataset = HaN( - config = self.config, - transform=tranform, - splits='val', - # transform=self.config.data_loader.preprocessing, - ) - self.test_dataset = HaN( - config = self.config, - transform=tranform, - splits='test', - # transform=self.config.data_loader.preprocessing, - ) - - # queue start loading when used, not when instantiated - self.train_loader = self.train_dataset.get_loader(self.config.data_loader) - self.val_loader = self.val_dataset.get_loader(self.config.data_loader) - self.test_loader = self.test_dataset.get_loader(self.config.data_loader) - - if self.config.trainer.reload: - self.load() - - def save(self, name): - if '.pth' not in name: - name = name + '.pth' - path = os.path.join(self.config.project_dir, self.config.experiment.name, 'train', self.config.title, 'checkpoints', name) - logging.info(f'Saving checkpoint at {path}') - state = { - 'title': self.config.title, - 'epoch': self.epoch, - 'state_dict': self.model.state_dict(), - 'optimizer': self.optimizer.state_dict(), - 'metrics': self.metrics, - } - torch.save(state, path) - - def load(self): - path = self.config.trainer.checkpoint - logging.info(f'Loading checkpoint from {path}') - state = torch.load(path) - - if 'title' in state.keys(): - # check that the title headers (without the hash) is the same - self_title_header = self.config.title[:-11] - load_title_header = state['title'][:-11] - if self_title_header == load_title_header: - self.config.title = state['title'] - self.optimizer.load_state_dict(state['optimizer']) - self.model.load_state_dict(state['state_dict']) - self.epoch = state['epoch'] + 1 - - if 'metrics' in state.keys(): - self.metrics = state['metrics'] - - def extract_data_from_feature(self, feature): - ct_volume = feature['ct'][tio.DATA].float().cuda(self.config.device) - # mr_volume = feature['mr'][tio.DATA].float().cuda() - gt = feature['label'][tio.DATA].float().cuda(self.config.device) - # volume = volume/255. # normalization - return ct_volume, gt - - def train(self): - - # self.num_classes - - self.model.train() - self.evaluator.reset_eval() - - data_loader = self.train_loader - losses = [] - for i, d in tqdm(enumerate(data_loader), total=len(data_loader), desc=f'Train epoch {str(self.epoch)}'): - images, gt = self.extract_data_from_feature(d) - - self.optimizer.zero_grad() - with torch.cuda.amp.autocast(): - preds = self.model(images) # pred shape B, C(N), H, W, D - preds_soft = F.softmax(preds, dim=1) - # 이미 여기서 thread: [23,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size -> patch base 일 경우 - # gt shape B, C(N), H, W, D and C(N) -> C(N) is one-hot encoded - gt_onehot = F.one_hot(gt.squeeze(0).long(), num_classes=self.num_classes).permute(0, 4, 1, 2, 3) - assert preds_soft.ndim == gt_onehot.ndim, f'Gt and output dimensions are not the same before loss. {preds_soft.ndim} vs {gt_onehot.ndim}' - # ignore background - preds_soft = preds_soft[:, 1:, ...] - gt_onehot = gt_onehot[:, 1:, ...] - - - if self.loss.names[0] == 'Dice3DLoss': - loss, dice = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot) - else: - loss = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot).cuda(self.config.device) - dice = compute_per_channel_dice(preds_soft, gt_onehot) - try: - losses.append(loss.item()) - except Exception as e: - print(e) - print(loss) - print(loss.item()) - sys.exit() - - self.scaler.scale(loss).backward() - # loss.backward() - self.scaler.step(self.optimizer) - # self.optimizer.step() - self.scaler.update() - - # hard_preds = torch.argmax(preds_soft, dim=1) - # self.evaluator.compute_metrics(hard_preds, gt) - self.evaluator.add_dice(dice=dice) - - epoch_train_loss = sum(losses) / len(losses) - epoch_dice = self.evaluator.mean_metric(phase='Train') - - self.metrics['Train'] = { - 'dice': epoch_dice, - } - - wandb.log({ - f'Epoch': self.epoch, - f'Train/Loss': epoch_train_loss, - f'Train/Dice': epoch_dice, - f'Train/Lr': self.optimizer.param_groups[0]['lr'] - }) - - return epoch_train_loss, epoch_dice - - def test(self, phase): - - self.model.eval() - - with torch.no_grad(): - torch.cuda.empty_cache() - self.evaluator.reset_eval() - losses = [] - - if phase == 'Test': - data_loader = self.test_loader - elif phase == 'Validation': - data_loader = self.val_loader - - for i, d in tqdm(enumerate(data_loader), total=len(data_loader), desc=f'{phase} epoch {str(self.epoch)}'): - images, gt = self.extract_data_from_feature(d) - - preds = self.model(images) - preds_soft = F.softmax(preds, dim=1) - gt_onehot = F.one_hot(gt.squeeze(0).long(), num_classes=self.num_classes).permute(0, 4, 1, 2, 3) - assert preds_soft.ndim == gt_onehot.ndim, f'Gt and output dimensions are not the same before loss. {preds_soft.ndim} vs {gt_onehot.ndim}' - # ignore background - preds_soft = preds_soft[:, 1:, ...] - gt_onehot = gt_onehot[:, 1:, ...] - - if self.loss.names[0] == 'Dice3DLoss': - loss, dice = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot) - else: - dice = compute_per_channel_dice(preds_soft, gt_onehot) - loss = self.loss.losses[self.loss.names[0]](preds_soft, gt_onehot).cuda(self.config.device) - losses.append(loss.item()) - # self.evaluator.compute_metrics(output, gt) - self.evaluator.add_dice(dice=dice) - - epoch_loss = sum(losses) / len(losses) - epoch_dice = self.evaluator.mean_metric(phase=phase) - - wandb.log({ - f'Epoch': self.epoch, - f'{phase}/Loss': epoch_loss, - f'{phase}/Dice': epoch_dice, - }) - - return epoch_dice - - - def inference(self, output_path, phase='Test'): - self.model.eval() - with torch.no_grad(): - torch.cuda.empty_cache() - - if phase == 'Test': - dataset = self.test_dataset - elif phase == 'Validation': - dataset = self.val_dataset - elif phase == 'Train': - dataset = self.train_dataset - - for i, subject in tqdm(enumerate(dataset), total=len(dataset), desc=f'{phase} epoch {str(self.epoch)}'): - os.makedirs(output_path, exist_ok=True) - file_path = os.path.join(output_path, subject.patient+'_pred.seg.nrrd') - # final_shape = subject.data.data[0].shape - if os.path.exists(file_path) and False: - logging.info(f'skipping {subject.patient}...') - continue - - sampler = tio.inference.GridSampler( - subject, - self.config.data_loader.patch_shape, - 0 - ) - loader = DataLoader(sampler, batch_size=self.config.data_loader.batch_size) - aggregator = tio.inference.GridAggregator(sampler, overlap_mode='hann') - - - for j, patch in enumerate(loader): - images = patch['ct'][tio.DATA].float().cuda(self.config.device) - preds = self.model(images) - aggregator.add_batch(preds, patch[tio.LOCATION]) - - output = aggregator.get_output_tensor() - output_soft = F.softmax(output, dim=1) - hard_output = torch.argmax(output_soft, dim=0) - # hard_output = hard_output.squeeze(0) - output = hard_output.detach().cpu().numpy() - nrrd.write(file_path, np.uint8(output)) - logging.info(f'patient {subject.patient} completed, {file_path}.') \ No newline at end of file