From 2aa83bc8e1f2dc80b1ddcfbebf77f2060a7f5bd1 Mon Sep 17 00:00:00 2001
From: Christian Herz
Date: Wed, 30 Jun 2021 16:06:59 -0400
Subject: [PATCH] ENH: added MONAILabel app for tricuspid valve segmentation
 (issue #1)

- currently no training is supported

ref: https://github.com/Project-MONAI/MONAILabel/issues/154
---
 .../segmentation_tricuspid_valve/README.md    |   3 +
 .../segmentation_tricuspid_valve/info.yaml    |  22 +++
 .../lib/__init__.py                           |   4 +
 .../lib/activelearning.py                     |  26 +++
 .../segmentation_tricuspid_valve/lib/infer.py |  69 +++++++
 .../lib/transforms.py                         | 131 +++++++++++++
 .../segmentation_tricuspid_valve/lib/vnet.py  | 183 ++++++++++++++++++
 .../segmentation_tricuspid_valve/main.py      |  62 ++++++
 .../requirements.txt                          |   0
 .../segmentation_tricuspid_valve/test.py      |   4 +
 10 files changed, 504 insertions(+)
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/README.md
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/info.yaml
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/lib/__init__.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/lib/infer.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/lib/vnet.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/main.py
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/requirements.txt
 create mode 100644 MONAILabel-app/segmentation_tricuspid_valve/test.py

diff --git a/MONAILabel-app/segmentation_tricuspid_valve/README.md b/MONAILabel-app/segmentation_tricuspid_valve/README.md
new file mode 100644
index 0000000..aedf5af
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/README.md
@@ -0,0 +1,3 @@
+# Segmentation - Tricuspid Valve from 3DE
+
+## Overview
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/info.yaml b/MONAILabel-app/segmentation_tricuspid_valve/info.yaml
new file mode 100644
index 0000000..0f14bcb
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/info.yaml
@@ -0,0 +1,22 @@
+---
+version: 1
+name: Segmentation Tricuspid Valve
+description: MONAI Label App for segmentation of the tricuspid valve from 3DE images
+dimension: 3
+labels:
+  - anterior
+  - posterior
+  - septal
+config:
+  infer:
+    device: cuda
+  train:
+    name: model_01
+    pretrained: True
+    device: cuda
+    amp: true
+    lr: 0.02
+    epochs: 200
+    val_split: 0.1
+    train_batch_size: 8
+    val_batch_size: 8
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/lib/__init__.py b/MONAILabel-app/segmentation_tricuspid_valve/lib/__init__.py
new file mode 100644
index 0000000..7fed96a
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/lib/__init__.py
@@ -0,0 +1,4 @@
+from .activelearning import MyStrategy
+from .infer import MyInfer
+# from .train import MyTrain
+from .vnet import VNet
\ No newline at end of file
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py b/MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py
new file mode 100644
index 0000000..a5f55cb
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/lib/activelearning.py
@@ -0,0 +1,26 @@
+import logging
+
+from monailabel.interfaces import Datastore
+from monailabel.interfaces.tasks import Strategy
+
+logger = logging.getLogger(__name__)
+
+
+class MyStrategy(Strategy):
+    """
+    Consider implementing a first strategy for active learning
+    """
+
+    def __init__(self):
super().__init__("Get First Sample") + + def __call__(self, request, datastore: Datastore): + images = datastore.get_unlabeled_images() + if not len(images): + return None + + images.sort() + image = images[0] + + logger.info(f"First: Selected Image: {image}") + return image diff --git a/MONAILabel-app/segmentation_tricuspid_valve/lib/infer.py b/MONAILabel-app/segmentation_tricuspid_valve/lib/infer.py new file mode 100644 index 0000000..6d04bbb --- /dev/null +++ b/MONAILabel-app/segmentation_tricuspid_valve/lib/infer.py @@ -0,0 +1,69 @@ +from monai.inferers import SimpleInferer +from monai.engines.utils import CommonKeys as Keys + +from monai.transforms import ( + AddChanneld, + LoadImaged, + ToTensord, + ScaleIntensityd, + AsDiscreted, + ConcatItemsd, + ToNumpyd, + SqueezeDimd +) + +from monailabel.utils.others.post import Restored +from monailabel.interfaces.tasks import InferTask, InferType + +from .transforms import DistanceTransformd + + +class MyInfer(InferTask): + """ + This provides Inference Engine for pre-trained tricuspid valve segmentation (VNet) model. + """ + + def __init__( + self, + path, + network=None, + type=InferType.SEGMENTATION, + labels=("anterior", "posterior", "septal"), + dimension=3, + description="A pre-trained model for volumetric (3D) segmentation of tricuspid valve from 3DE image", + ): + super().__init__( + path=path, + network=network, + type=type, + labels=labels, + dimension=dimension, + description=description, + ) + + def pre_transforms(self): + all_keys = [Keys.IMAGE, Keys.LABEL] + return [ + LoadImaged(keys=all_keys, reader="NibabelReader"), + AddChanneld(keys=all_keys), + DistanceTransformd(keys=[Keys.LABEL]), + ScaleIntensityd( + keys=[Keys.IMAGE], + minv=0.0, + maxv=1.0 + ), + ToTensord(keys=all_keys), + ConcatItemsd(keys=all_keys, name=Keys.IMAGE, dim=0) + ] + + def inferer(self): + return SimpleInferer() + + def post_transforms(self): + return [ + AddChanneld(keys="pred"), + AsDiscreted(keys="pred", argmax=True), + SqueezeDimd(keys="pred", dim=0), + ToNumpyd(keys="pred"), + Restored(keys="pred", ref_image="image"), + ] \ No newline at end of file diff --git a/MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py b/MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py new file mode 100644 index 0000000..5ccb8ae --- /dev/null +++ b/MONAILabel-app/segmentation_tricuspid_valve/lib/transforms.py @@ -0,0 +1,131 @@ +import logging +from monai.transforms import MapTransform +import SimpleITK as sitk +import numpy as np +import torch + +logger = logging.getLogger(__name__) + + +def simplex(t, axis: int = 1) -> bool: + import torch + _sum = t.sum(axis).type(torch.float32) + _ones = torch.ones_like(_sum, dtype=torch.float32) + return torch.allclose(_sum, _ones) + + +def is_one_hot(t, axis=1) -> bool: + return simplex(t, axis) and sset(t, [0, 1]) + + +def sset(a, sub) -> bool: + return uniq(a).issubset(sub) + + +def uniq(a) -> set: + import torch + return set(torch.unique(a.cpu()).numpy()) + + +class OneHotTransform(object): + + @classmethod + def run(cls, data): + if len(data.shape) == 4: + assert data.shape[0] == 1 + data = data[0] + + n_classes = (len(np.unique(data))) + assert n_classes > 1, f"{cls.__name__}: Not enough unique pixel values found in data." + assert n_classes < 10, f"{cls.__name__}: Too many unique pixel values found in data." 
+
+        w, h, d = data.shape
+        res = np.stack([data == c for c in range(n_classes)], axis=0).astype(np.int32)
+        assert res.shape == (n_classes, w, h, d)
+        assert np.all(res.sum(axis=0) == 1)
+        return res
+
+    def __init__(self, fields):
+        self.fields = fields
+
+    def __call__(self, data):
+        for field in self.fields:
+            data[field] = self.run(data[field])
+            assert np.isfinite(data[field]).all()
+        return data
+
+
+class OneHotTransformd(MapTransform):
+
+    def __init__(self, keys):
+        super(OneHotTransformd, self).__init__(keys)
+
+    def __call__(self, data):
+        for key in self.keys:
+            one_hot = OneHotTransform.run(data[key])
+            assert np.isfinite(one_hot).all()
+            assert np.any(one_hot)
+
+            data[key] = one_hot.astype(np.float32)
+        return data
+
+
+class DistanceTransform(object):
+    """ Create distance map on the fly for labels
+    """
+
+    METHODS = {
+        "SDM": sitk.SignedMaurerDistanceMapImageFilter,
+        "EDM": sitk.DanielssonDistanceMapImageFilter
+    }
+    DEFAULT_METHOD = "SDM"
+
+    @classmethod
+    def get_distance_map(cls, data, method=DEFAULT_METHOD):
+        image = sitk.GetImageFromArray(data.astype(np.int16))
+        distanceMapFilter = cls.METHODS[method]()
+        distanceMapFilter.SetUseImageSpacing(True)
+        distanceMapFilter.SetSquaredDistance(False)
+        out = distanceMapFilter.Execute(image)
+        return sitk.GetArrayFromImage(out)
+
+    def __init__(self, fields, method=DEFAULT_METHOD):
+        self.fields = fields
+        self.computationMethod = method
+
+    def __call__(self, data):
+        for field in self.fields:
+            d = data[field]
+            assert is_one_hot(torch.Tensor(d), axis=0)
+            # NB: skipping computation of background distance map
+            d = d[1:, ...]
+            assert d.shape[0] > 0
+            data[field] = np.stack([
+                self.get_distance_map(d[ch].astype(np.float32), self.computationMethod) for ch in range(d.shape[0])],
+                axis=0)
+            assert np.isfinite(data[field]).all()
+        return data
+
+
+class DistanceTransformd(MapTransform):
+
+    def one_hot_to_dist(self, input_array):
+        assert is_one_hot(torch.Tensor(input_array), axis=0)
+        out = np.stack(
+            [DistanceTransform.get_distance_map(input_array[ch].astype(np.float32),
+                                                method=self.method) for ch in range(input_array.shape[0])], axis=0)
+        return out
+
+    def __init__(self, keys, method=DistanceTransform.DEFAULT_METHOD):
+        super(DistanceTransformd, self).__init__(keys)
+        self.method = method
+
+    def __call__(self, data):
+        for key in self.keys:
+            one_hot = OneHotTransform.run(data[key])
+            assert np.isfinite(one_hot).all()
+            assert np.any(one_hot)
+
+            result_np = self.one_hot_to_dist(one_hot).astype(np.float32)
+            data[key] = result_np[1:, ...]
+        return data
\ No newline at end of file
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/lib/vnet.py b/MONAILabel-app/segmentation_tricuspid_valve/lib/vnet.py
new file mode 100644
index 0000000..ede2cb7
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/lib/vnet.py
@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+
+
+class ResidualConvBlock(nn.Module):
+
+    def __init__(self, n_stages, n_filters_in, n_filters_out, normalization='batchnorm', activation="ReLU",
+                 kernel_size=3, padding=1,
+                 expand_chan=False):
+        super(ResidualConvBlock, self).__init__()
+
+        activation = getattr(nn.modules.activation, activation)
+
+        self.expand_chan = expand_chan
+        if self.expand_chan:
+            ops = [nn.Conv3d(n_filters_in, n_filters_out, 1)]
+            if normalization == 'batchnorm':
+                ops.append(nn.BatchNorm3d(n_filters_out))
+            ops.append(activation())
+            self.conv_expan = nn.Sequential(*ops)
+
+        ops = []
+        for i in range(n_stages):
+            if normalization != 'none':
+                ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size, padding=padding))
+                if normalization == 'batchnorm':
+                    ops.append(nn.BatchNorm3d(n_filters_out))
+            else:
+                ops.append(nn.Conv3d(n_filters_in, n_filters_out, kernel_size, padding=padding))
+
+            ops.append(activation(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        if self.expand_chan:
+            x = self.conv(x) + self.conv_expan(x)
+        else:
+            x = (self.conv(x) + x)
+        return x
+
+
+class DownsamplingConvBlock(nn.Module):
+
+    def __init__(self, n_filters_in, n_filters_out, normalization='batchnorm', activation="ReLU", stride=2, padding=0):
+        super(DownsamplingConvBlock, self).__init__()
+
+        activation = getattr(nn.modules.activation, activation)
+
+        ops = []
+        if normalization != 'none':
+            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+            if normalization == 'batchnorm':
+                ops.append(nn.BatchNorm3d(n_filters_out))
+        else:
+            ops.append(nn.Conv3d(n_filters_in, n_filters_out, stride, padding=padding, stride=stride))
+
+        ops.append(activation(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class UpsamplingDeconvBlock(nn.Module):
+
+    def __init__(self, n_filters_in, n_filters_out, normalization='batchnorm', activation="ReLU", stride=2):
+        super(UpsamplingDeconvBlock, self).__init__()
+
+        activation = getattr(nn.modules.activation, activation)
+
+        ops = []
+        if normalization != 'none':
+            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+            if normalization == 'batchnorm':
+                ops.append(nn.BatchNorm3d(n_filters_out))
+        else:
+            ops.append(nn.ConvTranspose3d(n_filters_in, n_filters_out, stride, padding=0, stride=stride))
+
+        ops.append(activation(inplace=True))
+
+        self.conv = nn.Sequential(*ops)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class VNETEncoder(nn.Module):
+
+    def __init__(self, n_channels, n_filters=16, depth=5, normalization='batchnorm', activation="ReLU", kernel_size=3,
+                 padding=1):
+
+        super(VNETEncoder, self).__init__()
+
+        self.depth = depth
+
+        def _pow(l):
+            return 2 ** l
+
+        for level in range(self.depth):
+            n_repeats = max(min(level + 1, 3), 1)
+
+            if level == 0:
+                temp = ResidualConvBlock(n_repeats, n_channels, n_filters, normalization, activation, kernel_size,
+                                         padding, expand_chan=n_channels > 1)
+                setattr(self, "block_{}_enc".format(level + 1), temp)
+            else:
+                temp = ResidualConvBlock(n_repeats, n_filters * _pow(level), n_filters * _pow(level), normalization,
+                                         activation, kernel_size, padding)
+                setattr(self, "block_{}_enc".format(level + 1), temp)
+
+            if level < self.depth - 1:
+                temp = DownsamplingConvBlock(n_filters * _pow(level), n_filters * _pow(level + 1), normalization,
+                                             activation)
+                setattr(self, "block_{}_dw".format(level + 1), temp)
+
+    def forward(self, x):
+        encoder = dict()
+        for level in range(1, self.depth + 1):
+            x = x if level == 1 else encoder["x{}_dw".format(level - 1)]
+            encoder["x{}_enc".format(level)] = getattr(self, "block_{}_enc".format(level))(x)
+            if level < self.depth:
+                encoder["x{}_dw".format(level)] = getattr(self, "block_{}_dw".format(level))(
+                    encoder["x{}_enc".format(level)])
+        return encoder
+
+
+class VNETDecoder(nn.Module):
+
+    def __init__(self, n_classes, n_filters=16, depth=5, normalization='batchnorm', activation="ReLU", kernel_size=3,
+                 padding=1):
+        super(VNETDecoder, self).__init__()
+
+        self.depth = depth
+
+        def _pow(l):
+            return 2 ** l
+
+        for level in range(self.depth, 0, -1):
+            n_repeats = max(min(level, 3), 1)
+
+            if level < self.depth:
+                temp = ResidualConvBlock(n_repeats, n_filters * _pow(level - 1), n_filters * _pow(level - 1),
+                                         normalization, activation, kernel_size, padding)
+                setattr(self, "block_{}_dec".format(level), temp)
+
+            if level > 1:
+                temp = UpsamplingDeconvBlock(n_filters * _pow(level - 1), n_filters * _pow(level - 2), normalization,
+                                             activation)
+                setattr(self, "block_{}_up".format(level), temp)
+
+        self.out_conv = nn.Conv3d(n_filters, n_classes, kernel_size, padding=padding)
+
+    def forward(self, encoder):
+        decoder = dict()
+        x = None
+        for level in range(self.depth, 0, -1):
+            x = encoder["x{}_enc".format(level)] if level == self.depth else decoder["x{}_up".format(level + 1)]
+            if level < self.depth:
+                x = getattr(self, "block_{}_dec".format(level))(x)
+
+            if level > 1:
+                x = getattr(self, "block_{}_up".format(level))(x)
+                decoder["x{}_up".format(level)] = x + encoder["x{}_enc".format(level - 1)]
+
+        out_logits = self.out_conv(x)
+        return torch.nn.functional.softmax(out_logits.reshape(out_logits.size(0), out_logits.size(1), -1), dim=1).view_as(out_logits)
+
+
+class VNet(nn.Module):
+
+    def __init__(self, n_channels, n_classes, n_filters=16, depth=5, normalization='batchnorm', activation="ReLU",
+                 kernel_size=3, padding=1):
+        super(VNet, self).__init__()
+        self.encoder = VNETEncoder(n_channels, n_filters, depth, normalization, activation, kernel_size, padding)
+        self.decoder = VNETDecoder(n_classes, n_filters, depth, normalization, activation, kernel_size, padding)
+
+    def forward(self, x):
+        encoder = self.encoder(x)
+        return self.decoder(encoder)
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/main.py b/MONAILabel-app/segmentation_tricuspid_valve/main.py
new file mode 100644
index 0000000..e69733e
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/main.py
@@ -0,0 +1,62 @@
+import json
+import logging
+import os
+
+from lib import MyInfer, MyStrategy, VNet
+
+from monailabel.interfaces import MONAILabelApp
+from monailabel.utils.activelearning import Random
+
+logger = logging.getLogger(__name__)
+
+
+class MyApp(MONAILabelApp):
+    def __init__(self, app_dir, studies):
+        self.model_dir = os.path.join(app_dir, "model")
+        # TODO: depending on selected model, a different network needs to be selected
+        self.network = VNet(
+            n_channels=2,
+            n_classes=4,
+            n_filters=16,
+            normalization="batchnorm"
+        )
+
+        self.pretrained_model = os.path.join(self.model_dir, "segmentation_tricuspid_valve.pt")
+        self.final_model = os.path.join(self.model_dir, "final.pt")
+        self.train_stats_path = os.path.join(self.model_dir, "train_stats.json")
+
+        path = [self.pretrained_model, self.final_model]
+        infers = {
+            "segmentation_tricuspid_valve": MyInfer(path, self.network),
+        }
+
+        strategies = {
+            "random": Random(),
+            "first": MyStrategy(),
+        }
+
+        resources = [
+            (
+                self.pretrained_model,
+                # "https://api.ngc.nvidia.com/v2/models/nvidia/med"
+                # "/clara_pt_liver_and_tumor_ct_segmentation/versions/1/files/models/model.pt",
+            ),
+        ]
+
+        super().__init__(
+            app_dir=app_dir,
+            studies=studies,
+            infers=infers,
+            strategies=strategies,
+            resources=resources,
+        )
+
+    def train(self, request):
+        pass
+
+    def train_stats(self):
+
+        if os.path.exists(self.train_stats_path):
+            with open(self.train_stats_path, "r") as fc:
+                return json.load(fc)
+        return super().train_stats()
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/requirements.txt b/MONAILabel-app/segmentation_tricuspid_valve/requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/MONAILabel-app/segmentation_tricuspid_valve/test.py b/MONAILabel-app/segmentation_tricuspid_valve/test.py
new file mode 100644
index 0000000..0d5234a
--- /dev/null
+++ b/MONAILabel-app/segmentation_tricuspid_valve/test.py
@@ -0,0 +1,4 @@
+from monailabel.interfaces.test import test_main
+
+if __name__ == "__main__":
+    test_main()